1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "arith-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation Conversion
25 //===----------------------------------------------------------------------===//
26 
namespace {

/// Converts a multi-element (composite) arith.constant operation to
/// spv.Constant. Single-element shaped constants are handled by
/// ConstantScalarOpPattern instead.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts a scalar (or single-element shaped) arith.constant operation to
/// spv.Constant.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to SPIR-V ops using the GLSL extended instruction set.
///
/// This cannot be merged into the generic unary/binary op pattern because of
/// Vulkan restrictions on spv.SRem and spv.SMod; the remainder is emulated
/// instead.
struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to SPIR-V ops using the OpenCL extended instruction
/// set.
struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts bitwise operations to SPIR-V operations. This is a dedicated
/// pattern rather than the generic elementwise binary-op pattern because
/// SPIR-V uses different operations depending on the operand type: boolean
/// operands map to `SPIRVLogicalOp`, all other operands map to
/// `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to spv.BitwiseXor when the operands are NOT i1 or a
/// vector of i1.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to spv.LogicalNotEqual when the operands are i1 or a
/// vector of i1.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
/// i1.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.extui to spv.Select if the type of source is i1 or vector of
/// i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
/// i1.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith type-casting operations to the given SPIR-V operation
/// (non-boolean source and result types only).
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on non-i1 operands to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts ordered/unordered floating-point comparison operations to SPIR-V
/// ops.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check to SPIR-V ops. This pattern requires
/// Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check to SPIR-V ops. This pattern does not
/// require additional capability.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.select to spv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
204 
205 //===----------------------------------------------------------------------===//
206 // Conversion Helpers
207 //===----------------------------------------------------------------------===//
208 
209 /// Converts the given `srcAttr` into a boolean attribute if it holds an
210 /// integral value. Returns null attribute if conversion fails.
211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
212   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
213     return boolAttr;
214   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
215     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
216   return BoolAttr();
217 }
218 
219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
220 /// Returns null attribute if conversion fails.
221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
222                                       Builder builder) {
223   // If the source number uses less active bits than the target bitwidth, then
224   // it should be safe to convert.
225   if (srcAttr.getValue().isIntN(dstType.getWidth()))
226     return builder.getIntegerAttr(dstType, srcAttr.getInt());
227 
228   // XXX: Try again by interpreting the source number as a signed value.
229   // Although integers in the standard dialect are signless, they can represent
230   // a signed number. It's the operation decides how to interpret. This is
231   // dangerous, but it seems there is no good way of handling this if we still
232   // want to change the bitwidth. Emit a message at least.
233   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
234     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
235     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
236                             << dstAttr << "' for type '" << dstType << "'\n");
237     return dstAttr;
238   }
239 
240   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
241                           << "' illegal: cannot fit into target type '"
242                           << dstType << "'\n");
243   return IntegerAttr();
244 }
245 
246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
249                                   Builder builder) {
250   // Only support converting to float for now.
251   if (!dstType.isF32())
252     return FloatAttr();
253 
254   // Try to convert the source floating-point number to single precision.
255   APFloat dstVal = srcAttr.getValue();
256   bool losesInfo = false;
257   APFloat::opStatus status =
258       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
259   if (status != APFloat::opOK || losesInfo) {
260     LLVM_DEBUG(llvm::dbgs()
261                << srcAttr << " illegal: cannot fit into converted type '"
262                << dstType << "'\n");
263     return FloatAttr();
264   }
265 
266   return builder.getF32FloatAttr(dstVal.convertToFloat());
267 }
268 
269 /// Returns true if the given `type` is a boolean scalar or vector type.
270 static bool isBoolScalarOrVector(Type type) {
271   if (type.isInteger(1))
272     return true;
273   if (auto vecType = type.dyn_cast<VectorType>())
274     return vecType.getElementType().isInteger(1);
275   return false;
276 }
277 
278 /// Returns true if scalar/vector type `a` and `b` have the same number of
279 /// bitwidth.
280 static bool hasSameBitwidth(Type a, Type b) {
281   auto getNumBitwidth = [](Type type) {
282     unsigned bw = 0;
283     if (type.isIntOrFloat())
284       bw = type.getIntOrFloatBitWidth();
285     else if (auto vecType = type.dyn_cast<VectorType>())
286       bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
287     return bw;
288   };
289   unsigned aBW = getNumBitwidth(a);
290   unsigned bBW = getNumBitwidth(b);
291   return aBW != 0 && bBW != 0 && aBW == bBW;
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // ConstantOp with composite type
296 //===----------------------------------------------------------------------===//
297 
LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
    arith::ConstantOp constOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Only multi-element shaped constants are handled here; single-element
  // shaped constants are accepted by ConstantScalarOpPattern instead.
  auto srcType = constOp.getType().dyn_cast<ShapedType>();
  if (!srcType || srcType.getNumElements() == 1)
    return failure();

  // arith.constant should only have vector or tensor types.
  assert((srcType.isa<VectorType, RankedTensorType>()));

  auto dstType = getTypeConverter()->convertType(srcType);
  if (!dstType)
    return failure();

  // Only dense elements attributes are supported. Note: despite its name,
  // `dstElementsAttr` initially holds the *source* elements; it is reshaped
  // and/or rebuilt with converted element values below.
  auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
  if (!dstElementsAttr)
    return failure();

  ShapedType dstAttrType = dstElementsAttr.getType();

  // If the composite type has more than one dimension, perform linearization.
  // Only ranked tensors are flattened to 1-D; multi-dimensional vectors are
  // rejected.
  if (srcType.getRank() > 1) {
    if (srcType.isa<RankedTensorType>()) {
      dstAttrType = RankedTensorType::get(srcType.getNumElements(),
                                          srcType.getElementType());
      dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
    } else {
      // TODO: add support for large vectors.
      return failure();
    }
  }

  Type srcElemType = srcType.getElementType();
  Type dstElemType;
  // Tensor types are converted to SPIR-V array types; vector types are
  // converted to SPIR-V vector/array types.
  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
    dstElemType = arrayType.getElementType();
  else
    dstElemType = dstType.cast<VectorType>().getElementType();

  // If the source and destination element types are different, perform
  // attribute conversion element by element; any element that cannot be
  // converted makes the whole pattern fail.
  if (srcElemType != dstElemType) {
    SmallVector<Attribute, 8> elements;
    if (srcElemType.isa<FloatType>()) {
      for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
        FloatAttr dstAttr =
            convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
        if (!dstAttr)
          return failure();
        elements.push_back(dstAttr);
      }
    } else if (srcElemType.isInteger(1)) {
      // Converting i1 elements to a different element type is not supported.
      return failure();
    } else {
      for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
        IntegerAttr dstAttr = convertIntegerAttr(
            srcAttr, dstElemType.cast<IntegerType>(), rewriter);
        if (!dstAttr)
          return failure();
        elements.push_back(dstAttr);
      }
    }

    // Unfortunately, we cannot use dialect-specific types for element
    // attributes; element attributes only work with builtin types. So we need
    // to prepare another converted builtin type for the destination elements
    // attribute.
    if (dstAttrType.isa<RankedTensorType>())
      dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
    else
      dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);

    dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
  }

  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
                                                 dstElementsAttr);
  return success();
}
379 
380 //===----------------------------------------------------------------------===//
381 // ConstantOp with scalar type
382 //===----------------------------------------------------------------------===//
383 
384 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
385     arith::ConstantOp constOp, OpAdaptor adaptor,
386     ConversionPatternRewriter &rewriter) const {
387   Type srcType = constOp.getType();
388   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
389     if (shapedType.getNumElements() != 1)
390       return failure();
391     srcType = shapedType.getElementType();
392   }
393   if (!srcType.isIntOrIndexOrFloat())
394     return failure();
395 
396   Attribute cstAttr = constOp.getValue();
397   if (cstAttr.getType().isa<ShapedType>())
398     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
399 
400   Type dstType = getTypeConverter()->convertType(srcType);
401   if (!dstType)
402     return failure();
403 
404   // Floating-point types.
405   if (srcType.isa<FloatType>()) {
406     auto srcAttr = cstAttr.cast<FloatAttr>();
407     auto dstAttr = srcAttr;
408 
409     // Floating-point types not supported in the target environment are all
410     // converted to float type.
411     if (srcType != dstType) {
412       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
413       if (!dstAttr)
414         return failure();
415     }
416 
417     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
418     return success();
419   }
420 
421   // Bool type.
422   if (srcType.isInteger(1)) {
423     // arith.constant can use 0/1 instead of true/false for i1 values. We need
424     // to handle that here.
425     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
426     if (!dstAttr)
427       return failure();
428     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
429     return success();
430   }
431 
432   // IndexType or IntegerType. Index values are converted to 32-bit integer
433   // values when converting to SPIR-V.
434   auto srcAttr = cstAttr.cast<IntegerAttr>();
435   auto dstAttr =
436       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
437   if (!dstAttr)
438     return failure();
439   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
440   return success();
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // RemSIOpGLSLPattern
445 //===----------------------------------------------------------------------===//
446 
447 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
448 /// the sign of `signOperand`.
449 ///
450 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
451 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
452 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
453 /// if either operand can be negative. Emulate it via spv.UMod.
454 template <typename SignedAbsOp>
455 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
456                                     Value signOperand, OpBuilder &builder) {
457   assert(lhs.getType() == rhs.getType());
458   assert(lhs == signOperand || rhs == signOperand);
459 
460   Type type = lhs.getType();
461 
462   // Calculate the remainder with spv.UMod.
463   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
464   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
465   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
466 
467   // Fix the sign.
468   Value isPositive;
469   if (lhs == signOperand)
470     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
471   else
472     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
473   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
474   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
475 }
476 
477 LogicalResult
478 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
479                                     ConversionPatternRewriter &rewriter) const {
480   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
481       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
482       adaptor.getOperands()[0], rewriter);
483   rewriter.replaceOp(op, result);
484 
485   return success();
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // RemSIOpOCLPattern
490 //===----------------------------------------------------------------------===//
491 
492 LogicalResult
493 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
494                                    ConversionPatternRewriter &rewriter) const {
495   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
496       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
497       adaptor.getOperands()[0], rewriter);
498   rewriter.replaceOp(op, result);
499 
500   return success();
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // BitwiseOpPattern
505 //===----------------------------------------------------------------------===//
506 
507 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
508 LogicalResult
509 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
510     Op op, typename Op::Adaptor adaptor,
511     ConversionPatternRewriter &rewriter) const {
512   assert(adaptor.getOperands().size() == 2);
513   auto dstType =
514       this->getTypeConverter()->convertType(op.getResult().getType());
515   if (!dstType)
516     return failure();
517   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
518     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
519                                                          adaptor.getOperands());
520   } else {
521     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
522                                                          adaptor.getOperands());
523   }
524   return success();
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // XOrIOpLogicalPattern
529 //===----------------------------------------------------------------------===//
530 
531 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
532     arith::XOrIOp op, OpAdaptor adaptor,
533     ConversionPatternRewriter &rewriter) const {
534   assert(adaptor.getOperands().size() == 2);
535 
536   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
537     return failure();
538 
539   auto dstType = getTypeConverter()->convertType(op.getType());
540   if (!dstType)
541     return failure();
542   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
543                                                    adaptor.getOperands());
544 
545   return success();
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // XOrIOpBooleanPattern
550 //===----------------------------------------------------------------------===//
551 
552 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
553     arith::XOrIOp op, OpAdaptor adaptor,
554     ConversionPatternRewriter &rewriter) const {
555   assert(adaptor.getOperands().size() == 2);
556 
557   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
558     return failure();
559 
560   auto dstType = getTypeConverter()->convertType(op.getType());
561   if (!dstType)
562     return failure();
563   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
564                                                         adaptor.getOperands());
565   return success();
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // UIToFPI1Pattern
570 //===----------------------------------------------------------------------===//
571 
572 LogicalResult
573 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
574                                  ConversionPatternRewriter &rewriter) const {
575   auto srcType = adaptor.getOperands().front().getType();
576   if (!isBoolScalarOrVector(srcType))
577     return failure();
578 
579   auto dstType =
580       this->getTypeConverter()->convertType(op.getResult().getType());
581   Location loc = op.getLoc();
582   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
583   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
584   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
585       op, dstType, adaptor.getOperands().front(), one, zero);
586   return success();
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // ExtUII1Pattern
591 //===----------------------------------------------------------------------===//
592 
593 LogicalResult
594 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
595                                 ConversionPatternRewriter &rewriter) const {
596   auto srcType = adaptor.getOperands().front().getType();
597   if (!isBoolScalarOrVector(srcType))
598     return failure();
599 
600   auto dstType =
601       this->getTypeConverter()->convertType(op.getResult().getType());
602   Location loc = op.getLoc();
603   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
604   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
605   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
606       op, dstType, adaptor.getOperands().front(), one, zero);
607   return success();
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // TruncII1Pattern
612 //===----------------------------------------------------------------------===//
613 
614 LogicalResult
615 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
616                                  ConversionPatternRewriter &rewriter) const {
617   auto dstType =
618       this->getTypeConverter()->convertType(op.getResult().getType());
619   if (!isBoolScalarOrVector(dstType))
620     return failure();
621 
622   Location loc = op.getLoc();
623   auto srcType = adaptor.getOperands().front().getType();
624   // Check if (x & 1) == 1.
625   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
626   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
627       loc, srcType, adaptor.getOperands()[0], mask);
628   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
629 
630   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
631   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
632   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
633   return success();
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // TypeCastingOpPattern
638 //===----------------------------------------------------------------------===//
639 
640 template <typename Op, typename SPIRVOp>
641 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
642     Op op, typename Op::Adaptor adaptor,
643     ConversionPatternRewriter &rewriter) const {
644   assert(adaptor.getOperands().size() == 1);
645   auto srcType = adaptor.getOperands().front().getType();
646   auto dstType =
647       this->getTypeConverter()->convertType(op.getResult().getType());
648   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
649     return failure();
650   if (dstType == srcType) {
651     // Due to type conversion, we are seeing the same source and target type.
652     // Then we can just erase this operation by forwarding its operand.
653     rewriter.replaceOp(op, adaptor.getOperands().front());
654   } else {
655     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
656                                                   adaptor.getOperands());
657   }
658   return success();
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // CmpIOpBooleanPattern
663 //===----------------------------------------------------------------------===//
664 
665 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
666     arith::CmpIOp op, OpAdaptor adaptor,
667     ConversionPatternRewriter &rewriter) const {
668   Type operandType = op.getLhs().getType();
669   if (!isBoolScalarOrVector(operandType))
670     return failure();
671 
672   switch (op.getPredicate()) {
673 #define DISPATCH(cmpPredicate, spirvOp)                                        \
674   case cmpPredicate: {                                                         \
675     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
676                                          adaptor.getRhs());                    \
677     return success();                                                          \
678   }
679 
680     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
681     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
682 
683 #undef DISPATCH
684   default:;
685   }
686   return failure();
687 }
688 
689 //===----------------------------------------------------------------------===//
690 // CmpIOpPattern
691 //===----------------------------------------------------------------------===//
692 
/// Lowers arith.cmpi on non-boolean operands to the matching SPIR-V integer
/// comparison op, dispatching on the predicate.
///
/// Unsigned SPIR-V comparisons are refused (with an emitted error) when type
/// conversion changed the operand type to one with a different total
/// bitwidth, because no bitwidth emulation exists for them yet.
/// NOTE(review): emitting an error from a conversion pattern makes the failure
/// user-visible even if another pattern could have matched; consider
/// rewriter.notifyMatchFailure instead — confirm intended behavior.
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  // Boolean operands are handled by CmpIOpBooleanPattern.
  Type srcType = op.getLhs().getType();
  if (isBoolScalarOrVector(srcType))
    return failure();
  Type dstType = getTypeConverter()->convertType(srcType);
  if (!dstType)
    return failure();

  switch (op.getPredicate()) {
  // DISPATCH(pred, op): for an op carrying the spirv UnsignedOp trait, reject
  // conversions that changed the bitwidth; otherwise replace with `spirvOp`.
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        srcType != dstType && !hasSameBitwidth(srcType, dstType)) {            \
      return op.emitError(                                                     \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
                                         adaptor.getRhs());                    \
    return success();

    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}
730 
731 //===----------------------------------------------------------------------===//
732 // CmpFOpPattern
733 //===----------------------------------------------------------------------===//
734 
735 LogicalResult
736 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
737                                ConversionPatternRewriter &rewriter) const {
738   switch (op.getPredicate()) {
739 #define DISPATCH(cmpPredicate, spirvOp)                                        \
740   case cmpPredicate:                                                           \
741     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
742                                          adaptor.getRhs());                    \
743     return success();
744 
745     // Ordered.
746     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
747     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
748     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
749     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
750     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
751     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
752     // Unordered.
753     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
754     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
755     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
756     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
757     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
758     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
759 
760 #undef DISPATCH
761 
762   default:
763     break;
764   }
765   return failure();
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // CmpFOpNanKernelPattern
770 //===----------------------------------------------------------------------===//
771 
772 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
773     arith::CmpFOp op, OpAdaptor adaptor,
774     ConversionPatternRewriter &rewriter) const {
775   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
776     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
777                                                   adaptor.getRhs());
778     return success();
779   }
780 
781   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
782     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
783                                                     adaptor.getRhs());
784     return success();
785   }
786 
787   return failure();
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // CmpFOpNanNonePattern
792 //===----------------------------------------------------------------------===//
793 
794 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
795     arith::CmpFOp op, OpAdaptor adaptor,
796     ConversionPatternRewriter &rewriter) const {
797   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
798       op.getPredicate() != arith::CmpFPredicate::UNO)
799     return failure();
800 
801   Location loc = op.getLoc();
802 
803   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
804   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
805 
806   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
807   if (op.getPredicate() == arith::CmpFPredicate::ORD)
808     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
809 
810   rewriter.replaceOp(op, replace);
811   return success();
812 }
813 
814 //===----------------------------------------------------------------------===//
815 // SelectOpPattern
816 //===----------------------------------------------------------------------===//
817 
818 LogicalResult
819 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
820                                  ConversionPatternRewriter &rewriter) const {
821   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
822                                                adaptor.getTrueValue(),
823                                                adaptor.getFalseValue());
824   return success();
825 }
826 
827 //===----------------------------------------------------------------------===//
828 // Pattern Population
829 //===----------------------------------------------------------------------===//
830 
831 void mlir::arith::populateArithmeticToSPIRVPatterns(
832     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
833   // clang-format off
834   patterns.add<
835     ConstantCompositeOpPattern,
836     ConstantScalarOpPattern,
837     spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
838     spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
839     spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
840     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
841     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
842     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
843     RemSIOpGLSLPattern, RemSIOpOCLPattern,
844     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
845     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
846     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
847     spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
848     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
849     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
850     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
851     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
852     spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
853     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
854     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
855     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
856     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
857     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
858     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
859     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
860     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
861     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
862     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
863     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
864     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
865     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
866     CmpIOpBooleanPattern, CmpIOpPattern,
867     CmpFOpNanNonePattern, CmpFOpPattern,
868     SelectOpPattern,
869 
870     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
871     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
872     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
873     spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
874     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
875     spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>
876   >(typeConverter, patterns.getContext());
877   // clang-format on
878 
879   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
880   // capability is available.
881   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
882                                        /*benefit=*/2);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // Pass Definition
887 //===----------------------------------------------------------------------===//
888 
889 namespace {
890 struct ConvertArithmeticToSPIRVPass
891     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
892   void runOnOperation() override {
893     auto module = getOperation();
894     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
895     auto target = SPIRVConversionTarget::get(targetAttr);
896 
897     SPIRVTypeConverter::Options options;
898     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
899     SPIRVTypeConverter typeConverter(targetAttr, options);
900 
901     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
902     // in patterns for other dialects.
903     auto addUnrealizedCast = [](OpBuilder &builder, Type type,
904                                 ValueRange inputs, Location loc) {
905       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
906       return Optional<Value>(cast.getResult(0));
907     };
908     typeConverter.addSourceMaterialization(addUnrealizedCast);
909     typeConverter.addTargetMaterialization(addUnrealizedCast);
910     target->addLegalOp<UnrealizedConversionCastOp>();
911 
912     RewritePatternSet patterns(&getContext());
913     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
914 
915     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
916       signalPassFailure();
917   }
918 };
919 } // namespace
920 
921 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
922   return std::make_unique<ConvertArithmeticToSPIRVPass>();
923 }
924