1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/Support/Debug.h"
19
20 #define DEBUG_TYPE "arith-to-spirv-pattern"
21
22 using namespace mlir;
23
24 //===----------------------------------------------------------------------===//
25 // Operation Conversion
26 //===----------------------------------------------------------------------===//
27
namespace {

/// Converts an arith.constant whose type is a vector or tensor with more than
/// one element to spv.Constant. Single-element shapes are left to
/// ConstantScalarOpPattern.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts a scalar arith.constant (or a vector/tensor constant holding
/// exactly one element) to a scalar spv.Constant.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to GLSL SPIR-V ops (using spv.GLSAbs).
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spv.SRem and spv.SMod.
struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to OpenCL SPIR-V ops (using spv.CLSAbs).
struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts bitwise operations to SPIR-V operations. This is a separate
/// pattern from the generic binary op pattern because SPIR-V uses different
/// operations (`SPIRVLogicalOp`) when the operands are boolean values. For
/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to spv.BitwiseXor if the type of source is neither i1
/// nor a vector of i1.
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to spv.LogicalNotEqual if the type of source is i1 or
/// vector of i1.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
/// i1.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.extui to spv.Select if the type of source is i1 or vector of
/// i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
/// i1.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts type-casting arith operations to the SPIR-V operation `SPIRVOp`.
/// Casts whose source or result is i1 (or a vector of i1) are rejected here;
/// the dedicated *I1 patterns above handle some of those.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on non-i1 operands to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts ordered/unordered floating-point comparison operations to SPIR-V
/// ops.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check (ORD/UNO predicates) to spv.Ordered /
/// spv.Unordered. This pattern requires Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN check (ORD/UNO predicates) to SPIR-V ops. This
/// pattern does not require additional capability.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.select to spv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
205
206 //===----------------------------------------------------------------------===//
207 // Conversion Helpers
208 //===----------------------------------------------------------------------===//
209
210 /// Converts the given `srcAttr` into a boolean attribute if it holds an
211 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)212 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
213 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
214 return boolAttr;
215 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
216 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
217 return BoolAttr();
218 }
219
220 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
221 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)222 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
223 Builder builder) {
224 // If the source number uses less active bits than the target bitwidth, then
225 // it should be safe to convert.
226 if (srcAttr.getValue().isIntN(dstType.getWidth()))
227 return builder.getIntegerAttr(dstType, srcAttr.getInt());
228
229 // XXX: Try again by interpreting the source number as a signed value.
230 // Although integers in the standard dialect are signless, they can represent
231 // a signed number. It's the operation decides how to interpret. This is
232 // dangerous, but it seems there is no good way of handling this if we still
233 // want to change the bitwidth. Emit a message at least.
234 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
235 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
236 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
237 << dstAttr << "' for type '" << dstType << "'\n");
238 return dstAttr;
239 }
240
241 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
242 << "' illegal: cannot fit into target type '"
243 << dstType << "'\n");
244 return IntegerAttr();
245 }
246
247 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
248 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)249 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
250 Builder builder) {
251 // Only support converting to float for now.
252 if (!dstType.isF32())
253 return FloatAttr();
254
255 // Try to convert the source floating-point number to single precision.
256 APFloat dstVal = srcAttr.getValue();
257 bool losesInfo = false;
258 APFloat::opStatus status =
259 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
260 if (status != APFloat::opOK || losesInfo) {
261 LLVM_DEBUG(llvm::dbgs()
262 << srcAttr << " illegal: cannot fit into converted type '"
263 << dstType << "'\n");
264 return FloatAttr();
265 }
266
267 return builder.getF32FloatAttr(dstVal.convertToFloat());
268 }
269
270 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)271 static bool isBoolScalarOrVector(Type type) {
272 if (type.isInteger(1))
273 return true;
274 if (auto vecType = type.dyn_cast<VectorType>())
275 return vecType.getElementType().isInteger(1);
276 return false;
277 }
278
279 /// Returns true if scalar/vector type `a` and `b` have the same number of
280 /// bitwidth.
hasSameBitwidth(Type a,Type b)281 static bool hasSameBitwidth(Type a, Type b) {
282 auto getNumBitwidth = [](Type type) {
283 unsigned bw = 0;
284 if (type.isIntOrFloat())
285 bw = type.getIntOrFloatBitWidth();
286 else if (auto vecType = type.dyn_cast<VectorType>())
287 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
288 return bw;
289 };
290 unsigned aBW = getNumBitwidth(a);
291 unsigned bBW = getNumBitwidth(b);
292 return aBW != 0 && bBW != 0 && aBW == bBW;
293 }
294
295 //===----------------------------------------------------------------------===//
296 // ConstantOp with composite type
297 //===----------------------------------------------------------------------===//
298
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const299 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
300 arith::ConstantOp constOp, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) const {
302 auto srcType = constOp.getType().dyn_cast<ShapedType>();
303 if (!srcType || srcType.getNumElements() == 1)
304 return failure();
305
306 // arith.constant should only have vector or tenor types.
307 assert((srcType.isa<VectorType, RankedTensorType>()));
308
309 auto dstType = getTypeConverter()->convertType(srcType);
310 if (!dstType)
311 return failure();
312
313 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
314 if (!dstElementsAttr)
315 return failure();
316
317 ShapedType dstAttrType = dstElementsAttr.getType();
318
319 // If the composite type has more than one dimensions, perform linearization.
320 if (srcType.getRank() > 1) {
321 if (srcType.isa<RankedTensorType>()) {
322 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
323 srcType.getElementType());
324 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
325 } else {
326 // TODO: add support for large vectors.
327 return failure();
328 }
329 }
330
331 Type srcElemType = srcType.getElementType();
332 Type dstElemType;
333 // Tensor types are converted to SPIR-V array types; vector types are
334 // converted to SPIR-V vector/array types.
335 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
336 dstElemType = arrayType.getElementType();
337 else
338 dstElemType = dstType.cast<VectorType>().getElementType();
339
340 // If the source and destination element types are different, perform
341 // attribute conversion.
342 if (srcElemType != dstElemType) {
343 SmallVector<Attribute, 8> elements;
344 if (srcElemType.isa<FloatType>()) {
345 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
346 FloatAttr dstAttr =
347 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
348 if (!dstAttr)
349 return failure();
350 elements.push_back(dstAttr);
351 }
352 } else if (srcElemType.isInteger(1)) {
353 return failure();
354 } else {
355 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
356 IntegerAttr dstAttr = convertIntegerAttr(
357 srcAttr, dstElemType.cast<IntegerType>(), rewriter);
358 if (!dstAttr)
359 return failure();
360 elements.push_back(dstAttr);
361 }
362 }
363
364 // Unfortunately, we cannot use dialect-specific types for element
365 // attributes; element attributes only works with builtin types. So we need
366 // to prepare another converted builtin types for the destination elements
367 // attribute.
368 if (dstAttrType.isa<RankedTensorType>())
369 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
370 else
371 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
372
373 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
374 }
375
376 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
377 dstElementsAttr);
378 return success();
379 }
380
381 //===----------------------------------------------------------------------===//
382 // ConstantOp with scalar type
383 //===----------------------------------------------------------------------===//
384
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const385 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386 arith::ConstantOp constOp, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const {
388 Type srcType = constOp.getType();
389 if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
390 if (shapedType.getNumElements() != 1)
391 return failure();
392 srcType = shapedType.getElementType();
393 }
394 if (!srcType.isIntOrIndexOrFloat())
395 return failure();
396
397 Attribute cstAttr = constOp.getValue();
398 if (cstAttr.getType().isa<ShapedType>())
399 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
400
401 Type dstType = getTypeConverter()->convertType(srcType);
402 if (!dstType)
403 return failure();
404
405 // Floating-point types.
406 if (srcType.isa<FloatType>()) {
407 auto srcAttr = cstAttr.cast<FloatAttr>();
408 auto dstAttr = srcAttr;
409
410 // Floating-point types not supported in the target environment are all
411 // converted to float type.
412 if (srcType != dstType) {
413 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
414 if (!dstAttr)
415 return failure();
416 }
417
418 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419 return success();
420 }
421
422 // Bool type.
423 if (srcType.isInteger(1)) {
424 // arith.constant can use 0/1 instead of true/false for i1 values. We need
425 // to handle that here.
426 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
427 if (!dstAttr)
428 return failure();
429 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
430 return success();
431 }
432
433 // IndexType or IntegerType. Index values are converted to 32-bit integer
434 // values when converting to SPIR-V.
435 auto srcAttr = cstAttr.cast<IntegerAttr>();
436 auto dstAttr =
437 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
438 if (!dstAttr)
439 return failure();
440 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
441 return success();
442 }
443
444 //===----------------------------------------------------------------------===//
445 // RemSIOpGLPattern
446 //===----------------------------------------------------------------------===//
447
448 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
449 /// the sign of `signOperand`.
450 ///
451 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
452 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
453 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
454 /// if either operand can be negative. Emulate it via spv.UMod.
455 template <typename SignedAbsOp>
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)456 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
457 Value signOperand, OpBuilder &builder) {
458 assert(lhs.getType() == rhs.getType());
459 assert(lhs == signOperand || rhs == signOperand);
460
461 Type type = lhs.getType();
462
463 // Calculate the remainder with spv.UMod.
464 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
465 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
466 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
467
468 // Fix the sign.
469 Value isPositive;
470 if (lhs == signOperand)
471 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
472 else
473 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
474 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
475 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
476 }
477
478 LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const479 RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
480 ConversionPatternRewriter &rewriter) const {
481 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
482 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
483 adaptor.getOperands()[0], rewriter);
484 rewriter.replaceOp(op, result);
485
486 return success();
487 }
488
489 //===----------------------------------------------------------------------===//
490 // RemSIOpCLPattern
491 //===----------------------------------------------------------------------===//
492
493 LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const494 RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
495 ConversionPatternRewriter &rewriter) const {
496 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
497 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
498 adaptor.getOperands()[0], rewriter);
499 rewriter.replaceOp(op, result);
500
501 return success();
502 }
503
504 //===----------------------------------------------------------------------===//
505 // BitwiseOpPattern
506 //===----------------------------------------------------------------------===//
507
508 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
509 LogicalResult
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const510 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
511 Op op, typename Op::Adaptor adaptor,
512 ConversionPatternRewriter &rewriter) const {
513 assert(adaptor.getOperands().size() == 2);
514 auto dstType =
515 this->getTypeConverter()->convertType(op.getResult().getType());
516 if (!dstType)
517 return failure();
518 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
519 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
520 adaptor.getOperands());
521 } else {
522 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
523 adaptor.getOperands());
524 }
525 return success();
526 }
527
528 //===----------------------------------------------------------------------===//
529 // XOrIOpLogicalPattern
530 //===----------------------------------------------------------------------===//
531
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const532 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
533 arith::XOrIOp op, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter) const {
535 assert(adaptor.getOperands().size() == 2);
536
537 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
538 return failure();
539
540 auto dstType = getTypeConverter()->convertType(op.getType());
541 if (!dstType)
542 return failure();
543 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
544 adaptor.getOperands());
545
546 return success();
547 }
548
549 //===----------------------------------------------------------------------===//
550 // XOrIOpBooleanPattern
551 //===----------------------------------------------------------------------===//
552
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const553 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
554 arith::XOrIOp op, OpAdaptor adaptor,
555 ConversionPatternRewriter &rewriter) const {
556 assert(adaptor.getOperands().size() == 2);
557
558 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
559 return failure();
560
561 auto dstType = getTypeConverter()->convertType(op.getType());
562 if (!dstType)
563 return failure();
564 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
565 adaptor.getOperands());
566 return success();
567 }
568
569 //===----------------------------------------------------------------------===//
570 // UIToFPI1Pattern
571 //===----------------------------------------------------------------------===//
572
573 LogicalResult
matchAndRewrite(arith::UIToFPOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const574 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
575 ConversionPatternRewriter &rewriter) const {
576 auto srcType = adaptor.getOperands().front().getType();
577 if (!isBoolScalarOrVector(srcType))
578 return failure();
579
580 auto dstType =
581 this->getTypeConverter()->convertType(op.getResult().getType());
582 Location loc = op.getLoc();
583 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
584 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
585 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
586 op, dstType, adaptor.getOperands().front(), one, zero);
587 return success();
588 }
589
590 //===----------------------------------------------------------------------===//
591 // ExtUII1Pattern
592 //===----------------------------------------------------------------------===//
593
594 LogicalResult
matchAndRewrite(arith::ExtUIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const595 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
596 ConversionPatternRewriter &rewriter) const {
597 auto srcType = adaptor.getOperands().front().getType();
598 if (!isBoolScalarOrVector(srcType))
599 return failure();
600
601 auto dstType =
602 this->getTypeConverter()->convertType(op.getResult().getType());
603 Location loc = op.getLoc();
604 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
605 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
606 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
607 op, dstType, adaptor.getOperands().front(), one, zero);
608 return success();
609 }
610
611 //===----------------------------------------------------------------------===//
612 // TruncII1Pattern
613 //===----------------------------------------------------------------------===//
614
615 LogicalResult
matchAndRewrite(arith::TruncIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const616 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
617 ConversionPatternRewriter &rewriter) const {
618 auto dstType =
619 this->getTypeConverter()->convertType(op.getResult().getType());
620 if (!isBoolScalarOrVector(dstType))
621 return failure();
622
623 Location loc = op.getLoc();
624 auto srcType = adaptor.getOperands().front().getType();
625 // Check if (x & 1) == 1.
626 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
627 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
628 loc, srcType, adaptor.getOperands()[0], mask);
629 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
630
631 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
632 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
633 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
634 return success();
635 }
636
637 //===----------------------------------------------------------------------===//
638 // TypeCastingOpPattern
639 //===----------------------------------------------------------------------===//
640
641 template <typename Op, typename SPIRVOp>
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const642 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
643 Op op, typename Op::Adaptor adaptor,
644 ConversionPatternRewriter &rewriter) const {
645 assert(adaptor.getOperands().size() == 1);
646 auto srcType = adaptor.getOperands().front().getType();
647 auto dstType =
648 this->getTypeConverter()->convertType(op.getResult().getType());
649 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
650 return failure();
651 if (dstType == srcType) {
652 // Due to type conversion, we are seeing the same source and target type.
653 // Then we can just erase this operation by forwarding its operand.
654 rewriter.replaceOp(op, adaptor.getOperands().front());
655 } else {
656 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
657 adaptor.getOperands());
658 }
659 return success();
660 }
661
662 //===----------------------------------------------------------------------===//
663 // CmpIOpBooleanPattern
664 //===----------------------------------------------------------------------===//
665
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const666 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
667 arith::CmpIOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter) const {
669 Type srcType = op.getLhs().getType();
670 if (!isBoolScalarOrVector(srcType))
671 return failure();
672 Type dstType = getTypeConverter()->convertType(srcType);
673 if (!dstType)
674 return failure();
675
676 switch (op.getPredicate()) {
677 case arith::CmpIPredicate::eq: {
678 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
679 adaptor.getRhs());
680 return success();
681 }
682 case arith::CmpIPredicate::ne: {
683 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
684 adaptor.getRhs());
685 return success();
686 }
687 case arith::CmpIPredicate::uge:
688 case arith::CmpIPredicate::ugt:
689 case arith::CmpIPredicate::ule:
690 case arith::CmpIPredicate::ult: {
691 // There are no direct corresponding instructions in SPIR-V for such cases.
692 // Extend them to 32-bit and do comparision then.
693 Type type = rewriter.getI32Type();
694 if (auto vectorType = dstType.dyn_cast<VectorType>())
695 type = VectorType::get(vectorType.getShape(), type);
696 auto extLhs =
697 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
698 auto extRhs =
699 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
700
701 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
702 extRhs);
703 return success();
704 }
705 default:
706 break;
707 }
708 return failure();
709 }
710
711 //===----------------------------------------------------------------------===//
712 // CmpIOpPattern
713 //===----------------------------------------------------------------------===//
714
715 LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const716 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
717 ConversionPatternRewriter &rewriter) const {
718 Type srcType = op.getLhs().getType();
719 if (isBoolScalarOrVector(srcType))
720 return failure();
721 Type dstType = getTypeConverter()->convertType(srcType);
722 if (!dstType)
723 return failure();
724
725 switch (op.getPredicate()) {
726 #define DISPATCH(cmpPredicate, spirvOp) \
727 case cmpPredicate: \
728 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
729 srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \
730 return op.emitError( \
731 "bitwidth emulation is not implemented yet on unsigned op"); \
732 } \
733 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
734 adaptor.getRhs()); \
735 return success();
736
737 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
738 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
739 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
740 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
741 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
742 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
743 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
744 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
745 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
746 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
747
748 #undef DISPATCH
749 }
750 return failure();
751 }
752
753 //===----------------------------------------------------------------------===//
754 // CmpFOpPattern
755 //===----------------------------------------------------------------------===//
756
757 LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const758 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter) const {
760 switch (op.getPredicate()) {
761 #define DISPATCH(cmpPredicate, spirvOp) \
762 case cmpPredicate: \
763 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
764 adaptor.getRhs()); \
765 return success();
766
767 // Ordered.
768 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
769 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
770 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
771 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
772 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
773 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
774 // Unordered.
775 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
776 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
777 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
778 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
779 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
780 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
781
782 #undef DISPATCH
783
784 default:
785 break;
786 }
787 return failure();
788 }
789
790 //===----------------------------------------------------------------------===//
791 // CmpFOpNanKernelPattern
792 //===----------------------------------------------------------------------===//
793
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const794 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
795 arith::CmpFOp op, OpAdaptor adaptor,
796 ConversionPatternRewriter &rewriter) const {
797 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
798 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
799 adaptor.getRhs());
800 return success();
801 }
802
803 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
804 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
805 adaptor.getRhs());
806 return success();
807 }
808
809 return failure();
810 }
811
812 //===----------------------------------------------------------------------===//
813 // CmpFOpNanNonePattern
814 //===----------------------------------------------------------------------===//
815
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const816 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
817 arith::CmpFOp op, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const {
819 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
820 op.getPredicate() != arith::CmpFPredicate::UNO)
821 return failure();
822
823 Location loc = op.getLoc();
824
825 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
826 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
827
828 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
829 if (op.getPredicate() == arith::CmpFPredicate::ORD)
830 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
831
832 rewriter.replaceOp(op, replace);
833 return success();
834 }
835
836 //===----------------------------------------------------------------------===//
837 // SelectOpPattern
838 //===----------------------------------------------------------------------===//
839
840 LogicalResult
matchAndRewrite(arith::SelectOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const841 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
842 ConversionPatternRewriter &rewriter) const {
843 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
844 adaptor.getTrueValue(),
845 adaptor.getFalseValue());
846 return success();
847 }
848
849 //===----------------------------------------------------------------------===//
850 // Pattern Population
851 //===----------------------------------------------------------------------===//
852
populateArithmeticToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)853 void mlir::arith::populateArithmeticToSPIRVPatterns(
854 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
855 // clang-format off
856 patterns.add<
857 ConstantCompositeOpPattern,
858 ConstantScalarOpPattern,
859 spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
860 spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
861 spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
862 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
863 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
864 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
865 RemSIOpGLPattern, RemSIOpCLPattern,
866 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
867 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
868 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
869 spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
870 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
871 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
872 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
873 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
874 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
875 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
876 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
877 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
878 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
879 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
880 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
881 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
882 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
883 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
884 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
885 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
886 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
887 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
888 CmpIOpBooleanPattern, CmpIOpPattern,
889 CmpFOpNanNonePattern, CmpFOpPattern,
890 SelectOpPattern,
891
892 spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
893 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
894 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
895 spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLFMinOp>,
896 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
897 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
898 >(typeConverter, patterns.getContext());
899 // clang-format on
900
901 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
902 // capability is available.
903 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
904 /*benefit=*/2);
905 }
906
907 //===----------------------------------------------------------------------===//
908 // Pass Definition
909 //===----------------------------------------------------------------------===//
910
911 namespace {
912 struct ConvertArithmeticToSPIRVPass
913 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
runOnOperation__anon52142abf0311::ConvertArithmeticToSPIRVPass914 void runOnOperation() override {
915 auto module = getOperation();
916 auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
917 auto target = SPIRVConversionTarget::get(targetAttr);
918
919 SPIRVTypeConverter::Options options;
920 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
921 SPIRVTypeConverter typeConverter(targetAttr, options);
922
923 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
924 // in patterns for other dialects.
925 auto addUnrealizedCast = [](OpBuilder &builder, Type type,
926 ValueRange inputs, Location loc) {
927 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
928 return Optional<Value>(cast.getResult(0));
929 };
930 typeConverter.addSourceMaterialization(addUnrealizedCast);
931 typeConverter.addTargetMaterialization(addUnrealizedCast);
932 target->addLegalOp<UnrealizedConversionCastOp>();
933
934 RewritePatternSet patterns(&getContext());
935 arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
936
937 if (failed(applyPartialConversion(module, *target, std::move(patterns))))
938 signalPassFailure();
939 }
940 };
941 } // namespace
942
createConvertArithmeticToSPIRVPass()943 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
944 return std::make_unique<ConvertArithmeticToSPIRVPass>();
945 }
946