1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "arith-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation Conversion
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 
30 /// Converts composite arith.constant operation to spv.Constant.
31 struct ConstantCompositeOpPattern final
32     : public OpConversionPattern<arith::ConstantOp> {
33   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
34 
35   LogicalResult
36   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
37                   ConversionPatternRewriter &rewriter) const override;
38 };
39 
40 /// Converts scalar arith.constant operation to spv.Constant.
41 struct ConstantScalarOpPattern final
42     : public OpConversionPattern<arith::ConstantOp> {
43   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
44 
45   LogicalResult
46   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
47                   ConversionPatternRewriter &rewriter) const override;
48 };
49 
50 /// Converts arith.remsi to GLSL SPIR-V ops.
51 ///
52 /// This cannot be merged into the template unary/binary pattern due to Vulkan
53 /// restrictions over spv.SRem and spv.SMod.
54 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
55   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
56 
57   LogicalResult
58   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
59                   ConversionPatternRewriter &rewriter) const override;
60 };
61 
62 /// Converts arith.remsi to OpenCL SPIR-V ops.
63 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
64   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
65 
66   LogicalResult
67   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
68                   ConversionPatternRewriter &rewriter) const override;
69 };
70 
71 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
72 /// other than the BinaryOpPatternPattern because if the operands are boolean
73 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
74 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
75 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
76 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
77   using OpConversionPattern<Op>::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
81                   ConversionPatternRewriter &rewriter) const override;
82 };
83 
84 /// Converts arith.xori to SPIR-V operations.
85 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
86   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
87 
88   LogicalResult
89   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
90                   ConversionPatternRewriter &rewriter) const override;
91 };
92 
93 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
94 /// vector of i1.
95 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
96   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
97 
98   LogicalResult
99   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
100                   ConversionPatternRewriter &rewriter) const override;
101 };
102 
103 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
104 /// i1.
105 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
106   using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
107 
108   LogicalResult
109   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
110                   ConversionPatternRewriter &rewriter) const override;
111 };
112 
113 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
114 /// i1.
115 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
116   using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
117 
118   LogicalResult
119   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
120                   ConversionPatternRewriter &rewriter) const override;
121 };
122 
123 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
124 /// i1.
125 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
126   using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
127 
128   LogicalResult
129   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
130                   ConversionPatternRewriter &rewriter) const override;
131 };
132 
133 /// Converts type-casting standard operations to SPIR-V operations.
134 template <typename Op, typename SPIRVOp>
135 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
136   using OpConversionPattern<Op>::OpConversionPattern;
137 
138   LogicalResult
139   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
140                   ConversionPatternRewriter &rewriter) const override;
141 };
142 
143 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
144 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
145 public:
146   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
147 
148   LogicalResult
149   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
150                   ConversionPatternRewriter &rewriter) const override;
151 };
152 
153 /// Converts integer compare operation to SPIR-V ops.
154 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
155 public:
156   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
157 
158   LogicalResult
159   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
160                   ConversionPatternRewriter &rewriter) const override;
161 };
162 
163 /// Converts floating-point comparison operations to SPIR-V ops.
164 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
165 public:
166   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
167 
168   LogicalResult
169   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
170                   ConversionPatternRewriter &rewriter) const override;
171 };
172 
173 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
174 /// Kernel capability.
175 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
176 public:
177   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
178 
179   LogicalResult
180   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
181                   ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
185 /// require additional capability.
186 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
187 public:
188   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
189 
190   LogicalResult
191   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
192                   ConversionPatternRewriter &rewriter) const override;
193 };
194 
195 /// Converts arith.select to spv.Select.
196 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
197 public:
198   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
199   LogicalResult
200   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
201                   ConversionPatternRewriter &rewriter) const override;
202 };
203 
204 } // namespace
205 
206 //===----------------------------------------------------------------------===//
207 // Conversion Helpers
208 //===----------------------------------------------------------------------===//
209 
210 /// Converts the given `srcAttr` into a boolean attribute if it holds an
211 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)212 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
213   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
214     return boolAttr;
215   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
216     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
217   return BoolAttr();
218 }
219 
220 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
221 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)222 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
223                                       Builder builder) {
224   // If the source number uses less active bits than the target bitwidth, then
225   // it should be safe to convert.
226   if (srcAttr.getValue().isIntN(dstType.getWidth()))
227     return builder.getIntegerAttr(dstType, srcAttr.getInt());
228 
229   // XXX: Try again by interpreting the source number as a signed value.
230   // Although integers in the standard dialect are signless, they can represent
231   // a signed number. It's the operation decides how to interpret. This is
232   // dangerous, but it seems there is no good way of handling this if we still
233   // want to change the bitwidth. Emit a message at least.
234   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
235     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
236     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
237                             << dstAttr << "' for type '" << dstType << "'\n");
238     return dstAttr;
239   }
240 
241   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
242                           << "' illegal: cannot fit into target type '"
243                           << dstType << "'\n");
244   return IntegerAttr();
245 }
246 
247 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
248 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)249 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
250                                   Builder builder) {
251   // Only support converting to float for now.
252   if (!dstType.isF32())
253     return FloatAttr();
254 
255   // Try to convert the source floating-point number to single precision.
256   APFloat dstVal = srcAttr.getValue();
257   bool losesInfo = false;
258   APFloat::opStatus status =
259       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
260   if (status != APFloat::opOK || losesInfo) {
261     LLVM_DEBUG(llvm::dbgs()
262                << srcAttr << " illegal: cannot fit into converted type '"
263                << dstType << "'\n");
264     return FloatAttr();
265   }
266 
267   return builder.getF32FloatAttr(dstVal.convertToFloat());
268 }
269 
270 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)271 static bool isBoolScalarOrVector(Type type) {
272   if (type.isInteger(1))
273     return true;
274   if (auto vecType = type.dyn_cast<VectorType>())
275     return vecType.getElementType().isInteger(1);
276   return false;
277 }
278 
279 /// Returns true if scalar/vector type `a` and `b` have the same number of
280 /// bitwidth.
hasSameBitwidth(Type a,Type b)281 static bool hasSameBitwidth(Type a, Type b) {
282   auto getNumBitwidth = [](Type type) {
283     unsigned bw = 0;
284     if (type.isIntOrFloat())
285       bw = type.getIntOrFloatBitWidth();
286     else if (auto vecType = type.dyn_cast<VectorType>())
287       bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
288     return bw;
289   };
290   unsigned aBW = getNumBitwidth(a);
291   unsigned bBW = getNumBitwidth(b);
292   return aBW != 0 && bBW != 0 && aBW == bBW;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // ConstantOp with composite type
297 //===----------------------------------------------------------------------===//
298 
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const299 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
300     arith::ConstantOp constOp, OpAdaptor adaptor,
301     ConversionPatternRewriter &rewriter) const {
302   auto srcType = constOp.getType().dyn_cast<ShapedType>();
303   if (!srcType || srcType.getNumElements() == 1)
304     return failure();
305 
306   // arith.constant should only have vector or tenor types.
307   assert((srcType.isa<VectorType, RankedTensorType>()));
308 
309   auto dstType = getTypeConverter()->convertType(srcType);
310   if (!dstType)
311     return failure();
312 
313   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
314   if (!dstElementsAttr)
315     return failure();
316 
317   ShapedType dstAttrType = dstElementsAttr.getType();
318 
319   // If the composite type has more than one dimensions, perform linearization.
320   if (srcType.getRank() > 1) {
321     if (srcType.isa<RankedTensorType>()) {
322       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
323                                           srcType.getElementType());
324       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
325     } else {
326       // TODO: add support for large vectors.
327       return failure();
328     }
329   }
330 
331   Type srcElemType = srcType.getElementType();
332   Type dstElemType;
333   // Tensor types are converted to SPIR-V array types; vector types are
334   // converted to SPIR-V vector/array types.
335   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
336     dstElemType = arrayType.getElementType();
337   else
338     dstElemType = dstType.cast<VectorType>().getElementType();
339 
340   // If the source and destination element types are different, perform
341   // attribute conversion.
342   if (srcElemType != dstElemType) {
343     SmallVector<Attribute, 8> elements;
344     if (srcElemType.isa<FloatType>()) {
345       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
346         FloatAttr dstAttr =
347             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
348         if (!dstAttr)
349           return failure();
350         elements.push_back(dstAttr);
351       }
352     } else if (srcElemType.isInteger(1)) {
353       return failure();
354     } else {
355       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
356         IntegerAttr dstAttr = convertIntegerAttr(
357             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
358         if (!dstAttr)
359           return failure();
360         elements.push_back(dstAttr);
361       }
362     }
363 
364     // Unfortunately, we cannot use dialect-specific types for element
365     // attributes; element attributes only works with builtin types. So we need
366     // to prepare another converted builtin types for the destination elements
367     // attribute.
368     if (dstAttrType.isa<RankedTensorType>())
369       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
370     else
371       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
372 
373     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
374   }
375 
376   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
377                                                  dstElementsAttr);
378   return success();
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // ConstantOp with scalar type
383 //===----------------------------------------------------------------------===//
384 
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const385 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386     arith::ConstantOp constOp, OpAdaptor adaptor,
387     ConversionPatternRewriter &rewriter) const {
388   Type srcType = constOp.getType();
389   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
390     if (shapedType.getNumElements() != 1)
391       return failure();
392     srcType = shapedType.getElementType();
393   }
394   if (!srcType.isIntOrIndexOrFloat())
395     return failure();
396 
397   Attribute cstAttr = constOp.getValue();
398   if (cstAttr.getType().isa<ShapedType>())
399     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
400 
401   Type dstType = getTypeConverter()->convertType(srcType);
402   if (!dstType)
403     return failure();
404 
405   // Floating-point types.
406   if (srcType.isa<FloatType>()) {
407     auto srcAttr = cstAttr.cast<FloatAttr>();
408     auto dstAttr = srcAttr;
409 
410     // Floating-point types not supported in the target environment are all
411     // converted to float type.
412     if (srcType != dstType) {
413       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
414       if (!dstAttr)
415         return failure();
416     }
417 
418     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419     return success();
420   }
421 
422   // Bool type.
423   if (srcType.isInteger(1)) {
424     // arith.constant can use 0/1 instead of true/false for i1 values. We need
425     // to handle that here.
426     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
427     if (!dstAttr)
428       return failure();
429     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
430     return success();
431   }
432 
433   // IndexType or IntegerType. Index values are converted to 32-bit integer
434   // values when converting to SPIR-V.
435   auto srcAttr = cstAttr.cast<IntegerAttr>();
436   auto dstAttr =
437       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
438   if (!dstAttr)
439     return failure();
440   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
441   return success();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // RemSIOpGLPattern
446 //===----------------------------------------------------------------------===//
447 
448 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
449 /// the sign of `signOperand`.
450 ///
451 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
452 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
453 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
454 /// if either operand can be negative. Emulate it via spv.UMod.
455 template <typename SignedAbsOp>
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)456 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
457                                     Value signOperand, OpBuilder &builder) {
458   assert(lhs.getType() == rhs.getType());
459   assert(lhs == signOperand || rhs == signOperand);
460 
461   Type type = lhs.getType();
462 
463   // Calculate the remainder with spv.UMod.
464   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
465   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
466   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
467 
468   // Fix the sign.
469   Value isPositive;
470   if (lhs == signOperand)
471     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
472   else
473     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
474   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
475   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
476 }
477 
478 LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const479 RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
480                                   ConversionPatternRewriter &rewriter) const {
481   Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
482       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
483       adaptor.getOperands()[0], rewriter);
484   rewriter.replaceOp(op, result);
485 
486   return success();
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // RemSIOpCLPattern
491 //===----------------------------------------------------------------------===//
492 
493 LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const494 RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
495                                   ConversionPatternRewriter &rewriter) const {
496   Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
497       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
498       adaptor.getOperands()[0], rewriter);
499   rewriter.replaceOp(op, result);
500 
501   return success();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // BitwiseOpPattern
506 //===----------------------------------------------------------------------===//
507 
508 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
509 LogicalResult
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const510 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
511     Op op, typename Op::Adaptor adaptor,
512     ConversionPatternRewriter &rewriter) const {
513   assert(adaptor.getOperands().size() == 2);
514   auto dstType =
515       this->getTypeConverter()->convertType(op.getResult().getType());
516   if (!dstType)
517     return failure();
518   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
519     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
520                                                          adaptor.getOperands());
521   } else {
522     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
523                                                          adaptor.getOperands());
524   }
525   return success();
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // XOrIOpLogicalPattern
530 //===----------------------------------------------------------------------===//
531 
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const532 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
533     arith::XOrIOp op, OpAdaptor adaptor,
534     ConversionPatternRewriter &rewriter) const {
535   assert(adaptor.getOperands().size() == 2);
536 
537   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
538     return failure();
539 
540   auto dstType = getTypeConverter()->convertType(op.getType());
541   if (!dstType)
542     return failure();
543   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
544                                                    adaptor.getOperands());
545 
546   return success();
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOpBooleanPattern
551 //===----------------------------------------------------------------------===//
552 
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const553 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
554     arith::XOrIOp op, OpAdaptor adaptor,
555     ConversionPatternRewriter &rewriter) const {
556   assert(adaptor.getOperands().size() == 2);
557 
558   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
559     return failure();
560 
561   auto dstType = getTypeConverter()->convertType(op.getType());
562   if (!dstType)
563     return failure();
564   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
565                                                         adaptor.getOperands());
566   return success();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // UIToFPI1Pattern
571 //===----------------------------------------------------------------------===//
572 
573 LogicalResult
matchAndRewrite(arith::UIToFPOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const574 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
575                                  ConversionPatternRewriter &rewriter) const {
576   auto srcType = adaptor.getOperands().front().getType();
577   if (!isBoolScalarOrVector(srcType))
578     return failure();
579 
580   auto dstType =
581       this->getTypeConverter()->convertType(op.getResult().getType());
582   Location loc = op.getLoc();
583   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
584   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
585   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
586       op, dstType, adaptor.getOperands().front(), one, zero);
587   return success();
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // ExtUII1Pattern
592 //===----------------------------------------------------------------------===//
593 
594 LogicalResult
matchAndRewrite(arith::ExtUIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const595 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
596                                 ConversionPatternRewriter &rewriter) const {
597   auto srcType = adaptor.getOperands().front().getType();
598   if (!isBoolScalarOrVector(srcType))
599     return failure();
600 
601   auto dstType =
602       this->getTypeConverter()->convertType(op.getResult().getType());
603   Location loc = op.getLoc();
604   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
605   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
606   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
607       op, dstType, adaptor.getOperands().front(), one, zero);
608   return success();
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // TruncII1Pattern
613 //===----------------------------------------------------------------------===//
614 
615 LogicalResult
matchAndRewrite(arith::TruncIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const616 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
617                                  ConversionPatternRewriter &rewriter) const {
618   auto dstType =
619       this->getTypeConverter()->convertType(op.getResult().getType());
620   if (!isBoolScalarOrVector(dstType))
621     return failure();
622 
623   Location loc = op.getLoc();
624   auto srcType = adaptor.getOperands().front().getType();
625   // Check if (x & 1) == 1.
626   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
627   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
628       loc, srcType, adaptor.getOperands()[0], mask);
629   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
630 
631   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
632   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
633   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
634   return success();
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // TypeCastingOpPattern
639 //===----------------------------------------------------------------------===//
640 
641 template <typename Op, typename SPIRVOp>
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const642 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
643     Op op, typename Op::Adaptor adaptor,
644     ConversionPatternRewriter &rewriter) const {
645   assert(adaptor.getOperands().size() == 1);
646   auto srcType = adaptor.getOperands().front().getType();
647   auto dstType =
648       this->getTypeConverter()->convertType(op.getResult().getType());
649   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
650     return failure();
651   if (dstType == srcType) {
652     // Due to type conversion, we are seeing the same source and target type.
653     // Then we can just erase this operation by forwarding its operand.
654     rewriter.replaceOp(op, adaptor.getOperands().front());
655   } else {
656     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
657                                                   adaptor.getOperands());
658   }
659   return success();
660 }
661 
662 //===----------------------------------------------------------------------===//
663 // CmpIOpBooleanPattern
664 //===----------------------------------------------------------------------===//
665 
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const666 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
667     arith::CmpIOp op, OpAdaptor adaptor,
668     ConversionPatternRewriter &rewriter) const {
669   Type srcType = op.getLhs().getType();
670   if (!isBoolScalarOrVector(srcType))
671     return failure();
672   Type dstType = getTypeConverter()->convertType(srcType);
673   if (!dstType)
674     return failure();
675 
676   switch (op.getPredicate()) {
677   case arith::CmpIPredicate::eq: {
678     rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
679                                                        adaptor.getRhs());
680     return success();
681   }
682   case arith::CmpIPredicate::ne: {
683     rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
684                                                           adaptor.getRhs());
685     return success();
686   }
687   case arith::CmpIPredicate::uge:
688   case arith::CmpIPredicate::ugt:
689   case arith::CmpIPredicate::ule:
690   case arith::CmpIPredicate::ult: {
691     // There are no direct corresponding instructions in SPIR-V for such cases.
692     // Extend them to 32-bit and do comparision then.
693     Type type = rewriter.getI32Type();
694     if (auto vectorType = dstType.dyn_cast<VectorType>())
695       type = VectorType::get(vectorType.getShape(), type);
696     auto extLhs =
697         rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
698     auto extRhs =
699         rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
700 
701     rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
702                                                extRhs);
703     return success();
704   }
705   default:
706     break;
707   }
708   return failure();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // CmpIOpPattern
713 //===----------------------------------------------------------------------===//
714 
715 LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const716 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
717                                ConversionPatternRewriter &rewriter) const {
718   Type srcType = op.getLhs().getType();
719   if (isBoolScalarOrVector(srcType))
720     return failure();
721   Type dstType = getTypeConverter()->convertType(srcType);
722   if (!dstType)
723     return failure();
724 
725   switch (op.getPredicate()) {
726 #define DISPATCH(cmpPredicate, spirvOp)                                        \
727   case cmpPredicate:                                                           \
728     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
729         srcType != dstType && !hasSameBitwidth(srcType, dstType)) {            \
730       return op.emitError(                                                     \
731           "bitwidth emulation is not implemented yet on unsigned op");         \
732     }                                                                          \
733     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
734                                          adaptor.getRhs());                    \
735     return success();
736 
737     DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
738     DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
739     DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
740     DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
741     DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
742     DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
743     DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
744     DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
745     DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
746     DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
747 
748 #undef DISPATCH
749   }
750   return failure();
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // CmpFOpPattern
755 //===----------------------------------------------------------------------===//
756 
757 LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const758 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
759                                ConversionPatternRewriter &rewriter) const {
760   switch (op.getPredicate()) {
761 #define DISPATCH(cmpPredicate, spirvOp)                                        \
762   case cmpPredicate:                                                           \
763     rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
764                                          adaptor.getRhs());                    \
765     return success();
766 
767     // Ordered.
768     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
769     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
770     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
771     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
772     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
773     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
774     // Unordered.
775     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
776     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
777     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
778     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
779     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
780     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
781 
782 #undef DISPATCH
783 
784   default:
785     break;
786   }
787   return failure();
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // CmpFOpNanKernelPattern
792 //===----------------------------------------------------------------------===//
793 
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const794 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
795     arith::CmpFOp op, OpAdaptor adaptor,
796     ConversionPatternRewriter &rewriter) const {
797   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
798     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
799                                                   adaptor.getRhs());
800     return success();
801   }
802 
803   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
804     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
805                                                     adaptor.getRhs());
806     return success();
807   }
808 
809   return failure();
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // CmpFOpNanNonePattern
814 //===----------------------------------------------------------------------===//
815 
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const816 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
817     arith::CmpFOp op, OpAdaptor adaptor,
818     ConversionPatternRewriter &rewriter) const {
819   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
820       op.getPredicate() != arith::CmpFPredicate::UNO)
821     return failure();
822 
823   Location loc = op.getLoc();
824 
825   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
826   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
827 
828   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
829   if (op.getPredicate() == arith::CmpFPredicate::ORD)
830     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
831 
832   rewriter.replaceOp(op, replace);
833   return success();
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // SelectOpPattern
838 //===----------------------------------------------------------------------===//
839 
840 LogicalResult
matchAndRewrite(arith::SelectOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const841 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
842                                  ConversionPatternRewriter &rewriter) const {
843   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
844                                                adaptor.getTrueValue(),
845                                                adaptor.getFalseValue());
846   return success();
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // Pattern Population
851 //===----------------------------------------------------------------------===//
852 
populateArithmeticToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)853 void mlir::arith::populateArithmeticToSPIRVPatterns(
854     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
855   // clang-format off
856   patterns.add<
857     ConstantCompositeOpPattern,
858     ConstantScalarOpPattern,
859     spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
860     spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
861     spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
862     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
863     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
864     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
865     RemSIOpGLPattern, RemSIOpCLPattern,
866     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
867     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
868     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
869     spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
870     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
871     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
872     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
873     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
874     spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
875     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
876     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
877     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
878     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
879     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
880     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
881     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
882     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
883     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
884     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
885     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
886     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
887     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
888     CmpIOpBooleanPattern, CmpIOpPattern,
889     CmpFOpNanNonePattern, CmpFOpPattern,
890     SelectOpPattern,
891 
892     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
893     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
894     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
895     spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLFMinOp>,
896     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
897     spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
898   >(typeConverter, patterns.getContext());
899   // clang-format on
900 
901   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
902   // capability is available.
903   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
904                                        /*benefit=*/2);
905 }
906 
907 //===----------------------------------------------------------------------===//
908 // Pass Definition
909 //===----------------------------------------------------------------------===//
910 
911 namespace {
912 struct ConvertArithmeticToSPIRVPass
913     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
runOnOperation__anon52142abf0311::ConvertArithmeticToSPIRVPass914   void runOnOperation() override {
915     auto module = getOperation();
916     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
917     auto target = SPIRVConversionTarget::get(targetAttr);
918 
919     SPIRVTypeConverter::Options options;
920     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
921     SPIRVTypeConverter typeConverter(targetAttr, options);
922 
923     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
924     // in patterns for other dialects.
925     auto addUnrealizedCast = [](OpBuilder &builder, Type type,
926                                 ValueRange inputs, Location loc) {
927       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
928       return Optional<Value>(cast.getResult(0));
929     };
930     typeConverter.addSourceMaterialization(addUnrealizedCast);
931     typeConverter.addTargetMaterialization(addUnrealizedCast);
932     target->addLegalOp<UnrealizedConversionCastOp>();
933 
934     RewritePatternSet patterns(&getContext());
935     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
936 
937     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
938       signalPassFailure();
939   }
940 };
941 } // namespace
942 
createConvertArithmeticToSPIRVPass()943 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
944   return std::make_unique<ConvertArithmeticToSPIRVPass>();
945 }
946