1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "arith-to-spirv-pattern"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Operation Conversion
24 //===----------------------------------------------------------------------===//
25 
namespace {

/// Converts composite arith.constant operation to spv.Constant.
///
/// Matches constants whose type is a ShapedType (vector or ranked tensor) and
/// whose value is a DenseElementsAttr.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts scalar arith.constant operation to spv.Constant.
///
/// Matches integer, index, and floating-point scalar constants.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to GLSL SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spv.SRem and spv.SMod. The remainder is instead emulated
/// with absolute values (spv.GLSL.SAbs) and spv.UMod.
struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to OpenCL SPIR-V ops.
///
/// Uses the same emulation as RemSIOpGLSLPattern, but takes absolute values
/// with spv.OCL.s_abs.
struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts bitwise operations to SPIR-V operations. This is a separate
/// pattern rather than the generic unary/binary op pattern because SPIR-V uses
/// different operations when the operands are boolean values
/// (`SPIRVLogicalOp`). Non-boolean operands use `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori on non-boolean (not i1 / vector of i1) operands to
/// spv.BitwiseXor.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
/// vector of i1. Boolean xor is lowered to spv.LogicalNotEqual.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
/// i1. The select chooses between the constants one and zero of the converted
/// result type.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.extui to spv.Select if the type of source is i1 or vector of
/// i1. The select chooses between the constants one and zero of the converted
/// result type.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
/// i1. The lowest bit of the source decides the selected value.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts type-casting operations to SPIR-V operations. Boolean source or
/// result types are not matched. When type conversion makes the source and
/// result types equal, the cast is folded away.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on i1 type operands to SPIR-V ops.
/// Only the `eq` and `ne` predicates are supported.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation to SPIR-V ops. Operands of i1 type are
/// left to CmpIOpBooleanPattern.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating-point comparison operations to SPIR-V ops. The ORD/UNO
/// predicates are handled by the CmpFOpNan* patterns instead.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check to SPIR-V ops. This pattern requires
/// Kernel capability. It lowers ORD/UNO to spv.Ordered/spv.Unordered.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check to SPIR-V ops. This pattern does not
/// require additional capability. It lowers ORD/UNO with spv.IsNan and
/// logical ops.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
194 
195 //===----------------------------------------------------------------------===//
196 // Conversion Helpers
197 //===----------------------------------------------------------------------===//
198 
199 /// Converts the given `srcAttr` into a boolean attribute if it holds an
200 /// integral value. Returns null attribute if conversion fails.
201 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
202   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
203     return boolAttr;
204   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
205     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
206   return BoolAttr();
207 }
208 
209 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
210 /// Returns null attribute if conversion fails.
211 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
212                                       Builder builder) {
213   // If the source number uses less active bits than the target bitwidth, then
214   // it should be safe to convert.
215   if (srcAttr.getValue().isIntN(dstType.getWidth()))
216     return builder.getIntegerAttr(dstType, srcAttr.getInt());
217 
218   // XXX: Try again by interpreting the source number as a signed value.
219   // Although integers in the standard dialect are signless, they can represent
220   // a signed number. It's the operation decides how to interpret. This is
221   // dangerous, but it seems there is no good way of handling this if we still
222   // want to change the bitwidth. Emit a message at least.
223   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
224     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
225     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
226                             << dstAttr << "' for type '" << dstType << "'\n");
227     return dstAttr;
228   }
229 
230   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
231                           << "' illegal: cannot fit into target type '"
232                           << dstType << "'\n");
233   return IntegerAttr();
234 }
235 
236 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
237 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
238 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
239                                   Builder builder) {
240   // Only support converting to float for now.
241   if (!dstType.isF32())
242     return FloatAttr();
243 
244   // Try to convert the source floating-point number to single precision.
245   APFloat dstVal = srcAttr.getValue();
246   bool losesInfo = false;
247   APFloat::opStatus status =
248       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
249   if (status != APFloat::opOK || losesInfo) {
250     LLVM_DEBUG(llvm::dbgs()
251                << srcAttr << " illegal: cannot fit into converted type '"
252                << dstType << "'\n");
253     return FloatAttr();
254   }
255 
256   return builder.getF32FloatAttr(dstVal.convertToFloat());
257 }
258 
259 /// Returns true if the given `type` is a boolean scalar or vector type.
260 static bool isBoolScalarOrVector(Type type) {
261   if (type.isInteger(1))
262     return true;
263   if (auto vecType = type.dyn_cast<VectorType>())
264     return vecType.getElementType().isInteger(1);
265   return false;
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // ConstantOp with composite type
270 //===----------------------------------------------------------------------===//
271 
272 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273     arith::ConstantOp constOp, OpAdaptor adaptor,
274     ConversionPatternRewriter &rewriter) const {
275   auto srcType = constOp.getType().dyn_cast<ShapedType>();
276   if (!srcType)
277     return failure();
278 
279   // arith.constant should only have vector or tenor types.
280   assert((srcType.isa<VectorType, RankedTensorType>()));
281 
282   auto dstType = getTypeConverter()->convertType(srcType);
283   if (!dstType)
284     return failure();
285 
286   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
287   ShapedType dstAttrType = dstElementsAttr.getType();
288   if (!dstElementsAttr)
289     return failure();
290 
291   // If the composite type has more than one dimensions, perform linearization.
292   if (srcType.getRank() > 1) {
293     if (srcType.isa<RankedTensorType>()) {
294       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
295                                           srcType.getElementType());
296       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
297     } else {
298       // TODO: add support for large vectors.
299       return failure();
300     }
301   }
302 
303   Type srcElemType = srcType.getElementType();
304   Type dstElemType;
305   // Tensor types are converted to SPIR-V array types; vector types are
306   // converted to SPIR-V vector/array types.
307   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
308     dstElemType = arrayType.getElementType();
309   else
310     dstElemType = dstType.cast<VectorType>().getElementType();
311 
312   // If the source and destination element types are different, perform
313   // attribute conversion.
314   if (srcElemType != dstElemType) {
315     SmallVector<Attribute, 8> elements;
316     if (srcElemType.isa<FloatType>()) {
317       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
318         FloatAttr dstAttr =
319             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
320         if (!dstAttr)
321           return failure();
322         elements.push_back(dstAttr);
323       }
324     } else if (srcElemType.isInteger(1)) {
325       return failure();
326     } else {
327       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
328         IntegerAttr dstAttr = convertIntegerAttr(
329             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
330         if (!dstAttr)
331           return failure();
332         elements.push_back(dstAttr);
333       }
334     }
335 
336     // Unfortunately, we cannot use dialect-specific types for element
337     // attributes; element attributes only works with builtin types. So we need
338     // to prepare another converted builtin types for the destination elements
339     // attribute.
340     if (dstAttrType.isa<RankedTensorType>())
341       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
342     else
343       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
344 
345     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
346   }
347 
348   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
349                                                  dstElementsAttr);
350   return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // ConstantOp with scalar type
355 //===----------------------------------------------------------------------===//
356 
357 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358     arith::ConstantOp constOp, OpAdaptor adaptor,
359     ConversionPatternRewriter &rewriter) const {
360   Type srcType = constOp.getType();
361   if (!srcType.isIntOrIndexOrFloat())
362     return failure();
363 
364   Type dstType = getTypeConverter()->convertType(srcType);
365   if (!dstType)
366     return failure();
367 
368   // Floating-point types.
369   if (srcType.isa<FloatType>()) {
370     auto srcAttr = constOp.getValue().cast<FloatAttr>();
371     auto dstAttr = srcAttr;
372 
373     // Floating-point types not supported in the target environment are all
374     // converted to float type.
375     if (srcType != dstType) {
376       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
377       if (!dstAttr)
378         return failure();
379     }
380 
381     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
382     return success();
383   }
384 
385   // Bool type.
386   if (srcType.isInteger(1)) {
387     // arith.constant can use 0/1 instead of true/false for i1 values. We need
388     // to handle that here.
389     auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter);
390     if (!dstAttr)
391       return failure();
392     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
393     return success();
394   }
395 
396   // IndexType or IntegerType. Index values are converted to 32-bit integer
397   // values when converting to SPIR-V.
398   auto srcAttr = constOp.getValue().cast<IntegerAttr>();
399   auto dstAttr =
400       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
401   if (!dstAttr)
402     return failure();
403   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
404   return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // RemSIOpGLSLPattern
409 //===----------------------------------------------------------------------===//
410 
411 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
412 /// the sign of `signOperand`.
413 ///
414 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
415 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
416 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
417 /// if either operand can be negative. Emulate it via spv.UMod.
418 template <typename SignedAbsOp>
419 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
420                                     Value signOperand, OpBuilder &builder) {
421   assert(lhs.getType() == rhs.getType());
422   assert(lhs == signOperand || rhs == signOperand);
423 
424   Type type = lhs.getType();
425 
426   // Calculate the remainder with spv.UMod.
427   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
428   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
429   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
430 
431   // Fix the sign.
432   Value isPositive;
433   if (lhs == signOperand)
434     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
435   else
436     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
437   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
438   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
439 }
440 
441 LogicalResult
442 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
443                                     ConversionPatternRewriter &rewriter) const {
444   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
445       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
446       adaptor.getOperands()[0], rewriter);
447   rewriter.replaceOp(op, result);
448 
449   return success();
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // RemSIOpOCLPattern
454 //===----------------------------------------------------------------------===//
455 
456 LogicalResult
457 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
458                                    ConversionPatternRewriter &rewriter) const {
459   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
460       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
461       adaptor.getOperands()[0], rewriter);
462   rewriter.replaceOp(op, result);
463 
464   return success();
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // BitwiseOpPattern
469 //===----------------------------------------------------------------------===//
470 
471 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
472 LogicalResult
473 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
474     Op op, typename Op::Adaptor adaptor,
475     ConversionPatternRewriter &rewriter) const {
476   assert(adaptor.getOperands().size() == 2);
477   auto dstType =
478       this->getTypeConverter()->convertType(op.getResult().getType());
479   if (!dstType)
480     return failure();
481   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
482     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
483                                                          adaptor.getOperands());
484   } else {
485     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
486                                                          adaptor.getOperands());
487   }
488   return success();
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // XOrIOpLogicalPattern
493 //===----------------------------------------------------------------------===//
494 
495 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
496     arith::XOrIOp op, OpAdaptor adaptor,
497     ConversionPatternRewriter &rewriter) const {
498   assert(adaptor.getOperands().size() == 2);
499 
500   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
501     return failure();
502 
503   auto dstType = getTypeConverter()->convertType(op.getType());
504   if (!dstType)
505     return failure();
506   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
507                                                    adaptor.getOperands());
508 
509   return success();
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // XOrIOpBooleanPattern
514 //===----------------------------------------------------------------------===//
515 
516 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
517     arith::XOrIOp op, OpAdaptor adaptor,
518     ConversionPatternRewriter &rewriter) const {
519   assert(adaptor.getOperands().size() == 2);
520 
521   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
522     return failure();
523 
524   auto dstType = getTypeConverter()->convertType(op.getType());
525   if (!dstType)
526     return failure();
527   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
528                                                         adaptor.getOperands());
529   return success();
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // UIToFPI1Pattern
534 //===----------------------------------------------------------------------===//
535 
536 LogicalResult
537 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
538                                  ConversionPatternRewriter &rewriter) const {
539   auto srcType = adaptor.getOperands().front().getType();
540   if (!isBoolScalarOrVector(srcType))
541     return failure();
542 
543   auto dstType =
544       this->getTypeConverter()->convertType(op.getResult().getType());
545   Location loc = op.getLoc();
546   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
547   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
548   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
549       op, dstType, adaptor.getOperands().front(), one, zero);
550   return success();
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // ExtUII1Pattern
555 //===----------------------------------------------------------------------===//
556 
557 LogicalResult
558 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
559                                 ConversionPatternRewriter &rewriter) const {
560   auto srcType = adaptor.getOperands().front().getType();
561   if (!isBoolScalarOrVector(srcType))
562     return failure();
563 
564   auto dstType =
565       this->getTypeConverter()->convertType(op.getResult().getType());
566   Location loc = op.getLoc();
567   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
568   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
569   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
570       op, dstType, adaptor.getOperands().front(), one, zero);
571   return success();
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // TruncII1Pattern
576 //===----------------------------------------------------------------------===//
577 
578 LogicalResult
579 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
580                                  ConversionPatternRewriter &rewriter) const {
581   auto dstType =
582       this->getTypeConverter()->convertType(op.getResult().getType());
583   if (!isBoolScalarOrVector(dstType))
584     return failure();
585 
586   Location loc = op.getLoc();
587   auto srcType = adaptor.getOperands().front().getType();
588   // Check if (x & 1) == 1.
589   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
590   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
591       loc, srcType, adaptor.getOperands()[0], mask);
592   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
593 
594   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
595   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
596   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
597   return success();
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // TypeCastingOpPattern
602 //===----------------------------------------------------------------------===//
603 
604 template <typename Op, typename SPIRVOp>
605 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
606     Op op, typename Op::Adaptor adaptor,
607     ConversionPatternRewriter &rewriter) const {
608   assert(adaptor.getOperands().size() == 1);
609   auto srcType = adaptor.getOperands().front().getType();
610   auto dstType =
611       this->getTypeConverter()->convertType(op.getResult().getType());
612   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
613     return failure();
614   if (dstType == srcType) {
615     // Due to type conversion, we are seeing the same source and target type.
616     // Then we can just erase this operation by forwarding its operand.
617     rewriter.replaceOp(op, adaptor.getOperands().front());
618   } else {
619     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
620                                                   adaptor.getOperands());
621   }
622   return success();
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // CmpIOpBooleanPattern
627 //===----------------------------------------------------------------------===//
628 
629 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
630     arith::CmpIOp op, OpAdaptor adaptor,
631     ConversionPatternRewriter &rewriter) const {
632   Type operandType = op.getLhs().getType();
633   if (!isBoolScalarOrVector(operandType))
634     return failure();
635 
636   switch (op.getPredicate()) {
637 #define DISPATCH(cmpPredicate, spirvOp)                                        \
638   case cmpPredicate:                                                           \
639     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
640                                          adaptor.getLhs(), adaptor.getRhs());  \
641     return success();
642 
643     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
644     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
645 
646 #undef DISPATCH
647   default:;
648   }
649   return failure();
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // CmpIOpPattern
654 //===----------------------------------------------------------------------===//
655 
/// Lowers arith.cmpi on non-boolean operands to the SPIR-V integer comparison
/// op selected by the predicate (see the DISPATCH list below).
///
/// For SPIR-V ops carrying OpTrait::spirv::UnsignedOp, an error is emitted and
/// the pattern fails when the operand type differs from what the type
/// converter returns for it, because bitwidth emulation is not implemented for
/// unsigned comparisons. i1 operands are left to CmpIOpBooleanPattern.
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  Type operandType = op.getLhs().getType();
  // Boolean comparisons are handled by CmpIOpBooleanPattern.
  if (isBoolScalarOrVector(operandType))
    return failure();

  // Each DISPATCH expands to a `case` that (for unsigned SPIR-V ops only)
  // checks whether type conversion would change the operand type, then
  // replaces `op` with the given SPIR-V op. The trailing `return failure()` is
  // reached only for a predicate without a case.
  switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return op.emitError(                                                     \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
                                         adaptor.getLhs(), adaptor.getRhs());  \
    return success();

    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}
690 
691 //===----------------------------------------------------------------------===//
692 // CmpFOpPattern
693 //===----------------------------------------------------------------------===//
694 
695 LogicalResult
696 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
697                                ConversionPatternRewriter &rewriter) const {
698   switch (op.getPredicate()) {
699 #define DISPATCH(cmpPredicate, spirvOp)                                        \
700   case cmpPredicate:                                                           \
701     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
702                                          adaptor.getLhs(), adaptor.getRhs());  \
703     return success();
704 
705     // Ordered.
706     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
707     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
708     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
709     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
710     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
711     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
712     // Unordered.
713     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
714     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
715     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
716     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
717     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
718     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
719 
720 #undef DISPATCH
721 
722   default:
723     break;
724   }
725   return failure();
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // CmpFOpNanKernelPattern
730 //===----------------------------------------------------------------------===//
731 
732 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
733     arith::CmpFOp op, OpAdaptor adaptor,
734     ConversionPatternRewriter &rewriter) const {
735   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
736     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
737                                                   adaptor.getRhs());
738     return success();
739   }
740 
741   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
742     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
743                                                     adaptor.getRhs());
744     return success();
745   }
746 
747   return failure();
748 }
749 
750 //===----------------------------------------------------------------------===//
751 // CmpFOpNanNonePattern
752 //===----------------------------------------------------------------------===//
753 
754 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
755     arith::CmpFOp op, OpAdaptor adaptor,
756     ConversionPatternRewriter &rewriter) const {
757   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
758       op.getPredicate() != arith::CmpFPredicate::UNO)
759     return failure();
760 
761   Location loc = op.getLoc();
762 
763   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
764   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
765 
766   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
767   if (op.getPredicate() == arith::CmpFPredicate::ORD)
768     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
769 
770   rewriter.replaceOp(op, replace);
771   return success();
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // Pattern Population
776 //===----------------------------------------------------------------------===//
777 
778 void mlir::arith::populateArithmeticToSPIRVPatterns(
779     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
780   // clang-format off
781   patterns.add<
782     ConstantCompositeOpPattern,
783     ConstantScalarOpPattern,
784     spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
785     spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
786     spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
787     spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
788     spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
789     spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
790     RemSIOpGLSLPattern, RemSIOpOCLPattern,
791     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
792     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
793     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
794     spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
795     spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
796     spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
797     spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
798     spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
799     spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
800     spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
801     spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
802     spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
803     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
804     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
805     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
806     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
807     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
808     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
809     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
810     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
811     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
812     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
813     CmpIOpBooleanPattern, CmpIOpPattern,
814     CmpFOpNanNonePattern, CmpFOpPattern
815   >(typeConverter, patterns.getContext());
816   // clang-format on
817 
818   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
819   // capability is available.
820   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
821                                        /*benefit=*/2);
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Pass Definition
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 struct ConvertArithmeticToSPIRVPass
830     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
831   void runOnFunction() override {
832     auto module = getOperation()->getParentOfType<ModuleOp>();
833     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
834     auto target = SPIRVConversionTarget::get(targetAttr);
835 
836     SPIRVTypeConverter::Options options;
837     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
838     SPIRVTypeConverter typeConverter(targetAttr, options);
839 
840     RewritePatternSet patterns(&getContext());
841     mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
842 
843     if (failed(applyPartialConversion(getOperation(), *target,
844                                       std::move(patterns))))
845       signalPassFailure();
846   }
847 };
848 } // namespace
849 
850 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
851   return std::make_unique<ConvertArithmeticToSPIRVPass>();
852 }
853