1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "arith-to-spirv-pattern" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation Conversion 24 //===----------------------------------------------------------------------===// 25 26 namespace { 27 28 /// Converts composite arith.constant operation to spv.Constant. 29 struct ConstantCompositeOpPattern final 30 : public OpConversionPattern<arith::ConstantOp> { 31 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 32 33 LogicalResult 34 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override; 36 }; 37 38 /// Converts scalar arith.constant operation to spv.Constant. 39 struct ConstantScalarOpPattern final 40 : public OpConversionPattern<arith::ConstantOp> { 41 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override; 46 }; 47 48 /// Converts arith.remsi to GLSL SPIR-V ops. 49 /// 50 /// This cannot be merged into the template unary/binary pattern due to Vulkan 51 /// restrictions over spv.SRem and spv.SMod. 52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 53 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 57 ConversionPatternRewriter &rewriter) const override; 58 }; 59 60 /// Converts arith.remsi to OpenCL SPIR-V ops. 61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 62 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override; 67 }; 68 69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 70 /// other than the BinaryOpPatternPattern because if the operands are boolean 71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 74 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 75 using OpConversionPattern<Op>::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override; 80 }; 81 82 /// Converts arith.xori to SPIR-V operations. 83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 84 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override; 89 }; 90 91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 92 /// vector of i1. 93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 94 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override; 99 }; 100 101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 102 /// i1. 103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 104 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 112 /// i1. 113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 114 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 115 116 LogicalResult 117 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 118 ConversionPatternRewriter &rewriter) const override; 119 }; 120 121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 122 /// i1. 123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 124 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 128 ConversionPatternRewriter &rewriter) const override; 129 }; 130 131 /// Converts type-casting standard operations to SPIR-V operations. 132 template <typename Op, typename SPIRVOp> 133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 134 using OpConversionPattern<Op>::OpConversionPattern; 135 136 LogicalResult 137 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 138 ConversionPatternRewriter &rewriter) const override; 139 }; 140 141 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 143 public: 144 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 145 146 LogicalResult 147 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter) const override; 149 }; 150 151 /// Converts integer compare operation to SPIR-V ops. 152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 153 public: 154 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 155 156 LogicalResult 157 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override; 159 }; 160 161 /// Converts floating-point comparison operations to SPIR-V ops. 162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 163 public: 164 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 165 166 LogicalResult 167 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override; 169 }; 170 171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 172 /// Kernel capability. 173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 174 public: 175 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 176 177 LogicalResult 178 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override; 180 }; 181 182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 183 /// require additional capability. 184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 185 public: 186 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 187 188 LogicalResult 189 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 190 ConversionPatternRewriter &rewriter) const override; 191 }; 192 193 } // namespace 194 195 //===----------------------------------------------------------------------===// 196 // Conversion Helpers 197 //===----------------------------------------------------------------------===// 198 199 /// Converts the given `srcAttr` into a boolean attribute if it holds an 200 /// integral value. Returns null attribute if conversion fails. 201 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 202 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 203 return boolAttr; 204 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 205 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 206 return BoolAttr(); 207 } 208 209 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 210 /// Returns null attribute if conversion fails. 211 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 212 Builder builder) { 213 // If the source number uses less active bits than the target bitwidth, then 214 // it should be safe to convert. 215 if (srcAttr.getValue().isIntN(dstType.getWidth())) 216 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 217 218 // XXX: Try again by interpreting the source number as a signed value. 219 // Although integers in the standard dialect are signless, they can represent 220 // a signed number. It's the operation decides how to interpret. This is 221 // dangerous, but it seems there is no good way of handling this if we still 222 // want to change the bitwidth. Emit a message at least. 223 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 224 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 225 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 226 << dstAttr << "' for type '" << dstType << "'\n"); 227 return dstAttr; 228 } 229 230 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 231 << "' illegal: cannot fit into target type '" 232 << dstType << "'\n"); 233 return IntegerAttr(); 234 } 235 236 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 237 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 238 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 239 Builder builder) { 240 // Only support converting to float for now. 241 if (!dstType.isF32()) 242 return FloatAttr(); 243 244 // Try to convert the source floating-point number to single precision. 245 APFloat dstVal = srcAttr.getValue(); 246 bool losesInfo = false; 247 APFloat::opStatus status = 248 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 249 if (status != APFloat::opOK || losesInfo) { 250 LLVM_DEBUG(llvm::dbgs() 251 << srcAttr << " illegal: cannot fit into converted type '" 252 << dstType << "'\n"); 253 return FloatAttr(); 254 } 255 256 return builder.getF32FloatAttr(dstVal.convertToFloat()); 257 } 258 259 /// Returns true if the given `type` is a boolean scalar or vector type. 260 static bool isBoolScalarOrVector(Type type) { 261 if (type.isInteger(1)) 262 return true; 263 if (auto vecType = type.dyn_cast<VectorType>()) 264 return vecType.getElementType().isInteger(1); 265 return false; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // ConstantOp with composite type 270 //===----------------------------------------------------------------------===// 271 272 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 273 arith::ConstantOp constOp, OpAdaptor adaptor, 274 ConversionPatternRewriter &rewriter) const { 275 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 276 if (!srcType || srcType.getNumElements() == 1) 277 return failure(); 278 279 // arith.constant should only have vector or tenor types. 280 assert((srcType.isa<VectorType, RankedTensorType>())); 281 282 auto dstType = getTypeConverter()->convertType(srcType); 283 if (!dstType) 284 return failure(); 285 286 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 287 ShapedType dstAttrType = dstElementsAttr.getType(); 288 if (!dstElementsAttr) 289 return failure(); 290 291 // If the composite type has more than one dimensions, perform linearization. 292 if (srcType.getRank() > 1) { 293 if (srcType.isa<RankedTensorType>()) { 294 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 295 srcType.getElementType()); 296 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 297 } else { 298 // TODO: add support for large vectors. 299 return failure(); 300 } 301 } 302 303 Type srcElemType = srcType.getElementType(); 304 Type dstElemType; 305 // Tensor types are converted to SPIR-V array types; vector types are 306 // converted to SPIR-V vector/array types. 307 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 308 dstElemType = arrayType.getElementType(); 309 else 310 dstElemType = dstType.cast<VectorType>().getElementType(); 311 312 // If the source and destination element types are different, perform 313 // attribute conversion. 314 if (srcElemType != dstElemType) { 315 SmallVector<Attribute, 8> elements; 316 if (srcElemType.isa<FloatType>()) { 317 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 318 FloatAttr dstAttr = 319 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 320 if (!dstAttr) 321 return failure(); 322 elements.push_back(dstAttr); 323 } 324 } else if (srcElemType.isInteger(1)) { 325 return failure(); 326 } else { 327 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 328 IntegerAttr dstAttr = convertIntegerAttr( 329 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 330 if (!dstAttr) 331 return failure(); 332 elements.push_back(dstAttr); 333 } 334 } 335 336 // Unfortunately, we cannot use dialect-specific types for element 337 // attributes; element attributes only works with builtin types. So we need 338 // to prepare another converted builtin types for the destination elements 339 // attribute. 340 if (dstAttrType.isa<RankedTensorType>()) 341 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 342 else 343 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 344 345 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 346 } 347 348 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 349 dstElementsAttr); 350 return success(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // ConstantOp with scalar type 355 //===----------------------------------------------------------------------===// 356 357 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 358 arith::ConstantOp constOp, OpAdaptor adaptor, 359 ConversionPatternRewriter &rewriter) const { 360 Type srcType = constOp.getType(); 361 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 362 if (shapedType.getNumElements() != 1) 363 return failure(); 364 srcType = shapedType.getElementType(); 365 } 366 if (!srcType.isIntOrIndexOrFloat()) 367 return failure(); 368 369 Attribute cstAttr = constOp.getValue(); 370 if (cstAttr.getType().isa<ShapedType>()) 371 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 372 373 Type dstType = getTypeConverter()->convertType(srcType); 374 if (!dstType) 375 return failure(); 376 377 // Floating-point types. 378 if (srcType.isa<FloatType>()) { 379 auto srcAttr = cstAttr.cast<FloatAttr>(); 380 auto dstAttr = srcAttr; 381 382 // Floating-point types not supported in the target environment are all 383 // converted to float type. 384 if (srcType != dstType) { 385 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 386 if (!dstAttr) 387 return failure(); 388 } 389 390 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 391 return success(); 392 } 393 394 // Bool type. 395 if (srcType.isInteger(1)) { 396 // arith.constant can use 0/1 instead of true/false for i1 values. We need 397 // to handle that here. 398 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 399 if (!dstAttr) 400 return failure(); 401 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 402 return success(); 403 } 404 405 // IndexType or IntegerType. Index values are converted to 32-bit integer 406 // values when converting to SPIR-V. 407 auto srcAttr = cstAttr.cast<IntegerAttr>(); 408 auto dstAttr = 409 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 410 if (!dstAttr) 411 return failure(); 412 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 413 return success(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // RemSIOpGLSLPattern 418 //===----------------------------------------------------------------------===// 419 420 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 421 /// the sign of `signOperand`. 422 /// 423 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 424 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 425 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 426 /// if either operand can be negative. Emulate it via spv.UMod. 427 template <typename SignedAbsOp> 428 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 429 Value signOperand, OpBuilder &builder) { 430 assert(lhs.getType() == rhs.getType()); 431 assert(lhs == signOperand || rhs == signOperand); 432 433 Type type = lhs.getType(); 434 435 // Calculate the remainder with spv.UMod. 436 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 437 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 438 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 439 440 // Fix the sign. 441 Value isPositive; 442 if (lhs == signOperand) 443 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 444 else 445 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 446 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 447 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 448 } 449 450 LogicalResult 451 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 452 ConversionPatternRewriter &rewriter) const { 453 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 454 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 455 adaptor.getOperands()[0], rewriter); 456 rewriter.replaceOp(op, result); 457 458 return success(); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // RemSIOpOCLPattern 463 //===----------------------------------------------------------------------===// 464 465 LogicalResult 466 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 467 ConversionPatternRewriter &rewriter) const { 468 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 469 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 470 adaptor.getOperands()[0], rewriter); 471 rewriter.replaceOp(op, result); 472 473 return success(); 474 } 475 476 //===----------------------------------------------------------------------===// 477 // BitwiseOpPattern 478 //===----------------------------------------------------------------------===// 479 480 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 481 LogicalResult 482 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 483 Op op, typename Op::Adaptor adaptor, 484 ConversionPatternRewriter &rewriter) const { 485 assert(adaptor.getOperands().size() == 2); 486 auto dstType = 487 this->getTypeConverter()->convertType(op.getResult().getType()); 488 if (!dstType) 489 return failure(); 490 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 491 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 492 adaptor.getOperands()); 493 } else { 494 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 495 adaptor.getOperands()); 496 } 497 return success(); 498 } 499 500 //===----------------------------------------------------------------------===// 501 // XOrIOpLogicalPattern 502 //===----------------------------------------------------------------------===// 503 504 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 505 arith::XOrIOp op, OpAdaptor adaptor, 506 ConversionPatternRewriter &rewriter) const { 507 assert(adaptor.getOperands().size() == 2); 508 509 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 510 return failure(); 511 512 auto dstType = getTypeConverter()->convertType(op.getType()); 513 if (!dstType) 514 return failure(); 515 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 516 adaptor.getOperands()); 517 518 return success(); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // XOrIOpBooleanPattern 523 //===----------------------------------------------------------------------===// 524 525 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 526 arith::XOrIOp op, OpAdaptor adaptor, 527 ConversionPatternRewriter &rewriter) const { 528 assert(adaptor.getOperands().size() == 2); 529 530 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 531 return failure(); 532 533 auto dstType = getTypeConverter()->convertType(op.getType()); 534 if (!dstType) 535 return failure(); 536 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 537 adaptor.getOperands()); 538 return success(); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // UIToFPI1Pattern 543 //===----------------------------------------------------------------------===// 544 545 LogicalResult 546 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 547 ConversionPatternRewriter &rewriter) const { 548 auto srcType = adaptor.getOperands().front().getType(); 549 if (!isBoolScalarOrVector(srcType)) 550 return failure(); 551 552 auto dstType = 553 this->getTypeConverter()->convertType(op.getResult().getType()); 554 Location loc = op.getLoc(); 555 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 556 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 557 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 558 op, dstType, adaptor.getOperands().front(), one, zero); 559 return success(); 560 } 561 562 //===----------------------------------------------------------------------===// 563 // ExtUII1Pattern 564 //===----------------------------------------------------------------------===// 565 566 LogicalResult 567 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 568 ConversionPatternRewriter &rewriter) const { 569 auto srcType = adaptor.getOperands().front().getType(); 570 if (!isBoolScalarOrVector(srcType)) 571 return failure(); 572 573 auto dstType = 574 this->getTypeConverter()->convertType(op.getResult().getType()); 575 Location loc = op.getLoc(); 576 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 577 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 578 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 579 op, dstType, adaptor.getOperands().front(), one, zero); 580 return success(); 581 } 582 583 //===----------------------------------------------------------------------===// 584 // TruncII1Pattern 585 //===----------------------------------------------------------------------===// 586 587 LogicalResult 588 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 589 ConversionPatternRewriter &rewriter) const { 590 auto dstType = 591 this->getTypeConverter()->convertType(op.getResult().getType()); 592 if (!isBoolScalarOrVector(dstType)) 593 return failure(); 594 595 Location loc = op.getLoc(); 596 auto srcType = adaptor.getOperands().front().getType(); 597 // Check if (x & 1) == 1. 598 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 599 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 600 loc, srcType, adaptor.getOperands()[0], mask); 601 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 602 603 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 604 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 605 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 606 return success(); 607 } 608 609 //===----------------------------------------------------------------------===// 610 // TypeCastingOpPattern 611 //===----------------------------------------------------------------------===// 612 613 template <typename Op, typename SPIRVOp> 614 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 615 Op op, typename Op::Adaptor adaptor, 616 ConversionPatternRewriter &rewriter) const { 617 assert(adaptor.getOperands().size() == 1); 618 auto srcType = adaptor.getOperands().front().getType(); 619 auto dstType = 620 this->getTypeConverter()->convertType(op.getResult().getType()); 621 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 622 return failure(); 623 if (dstType == srcType) { 624 // Due to type conversion, we are seeing the same source and target type. 625 // Then we can just erase this operation by forwarding its operand. 626 rewriter.replaceOp(op, adaptor.getOperands().front()); 627 } else { 628 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 629 adaptor.getOperands()); 630 } 631 return success(); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // CmpIOpBooleanPattern 636 //===----------------------------------------------------------------------===// 637 638 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 639 arith::CmpIOp op, OpAdaptor adaptor, 640 ConversionPatternRewriter &rewriter) const { 641 Type operandType = op.getLhs().getType(); 642 if (!isBoolScalarOrVector(operandType)) 643 return failure(); 644 645 switch (op.getPredicate()) { 646 #define DISPATCH(cmpPredicate, spirvOp) \ 647 case cmpPredicate: \ 648 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 649 adaptor.getLhs(), adaptor.getRhs()); \ 650 return success(); 651 652 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 653 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 654 655 #undef DISPATCH 656 default:; 657 } 658 return failure(); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // CmpIOpPattern 663 //===----------------------------------------------------------------------===// 664 665 LogicalResult 666 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 667 ConversionPatternRewriter &rewriter) const { 668 Type operandType = op.getLhs().getType(); 669 if (isBoolScalarOrVector(operandType)) 670 return failure(); 671 672 switch (op.getPredicate()) { 673 #define DISPATCH(cmpPredicate, spirvOp) \ 674 case cmpPredicate: \ 675 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 676 operandType != this->getTypeConverter()->convertType(operandType)) { \ 677 return op.emitError( \ 678 "bitwidth emulation is not implemented yet on unsigned op"); \ 679 } \ 680 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 681 adaptor.getLhs(), adaptor.getRhs()); \ 682 return success(); 683 684 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 685 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 686 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 687 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 688 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 689 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 690 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 691 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 692 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 693 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 694 695 #undef DISPATCH 696 } 697 return failure(); 698 } 699 700 //===----------------------------------------------------------------------===// 701 // CmpFOpPattern 702 //===----------------------------------------------------------------------===// 703 704 LogicalResult 705 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 706 ConversionPatternRewriter &rewriter) const { 707 switch (op.getPredicate()) { 708 #define DISPATCH(cmpPredicate, spirvOp) \ 709 case cmpPredicate: \ 710 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 711 adaptor.getLhs(), adaptor.getRhs()); \ 712 return success(); 713 714 // Ordered. 715 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 716 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 717 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 718 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 719 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 720 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 721 // Unordered. 722 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 723 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 724 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 725 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 726 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 727 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 728 729 #undef DISPATCH 730 731 default: 732 break; 733 } 734 return failure(); 735 } 736 737 //===----------------------------------------------------------------------===// 738 // CmpFOpNanKernelPattern 739 //===----------------------------------------------------------------------===// 740 741 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 742 arith::CmpFOp op, OpAdaptor adaptor, 743 ConversionPatternRewriter &rewriter) const { 744 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 745 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 746 adaptor.getRhs()); 747 return success(); 748 } 749 750 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 751 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 752 adaptor.getRhs()); 753 return success(); 754 } 755 756 return failure(); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // CmpFOpNanNonePattern 761 //===----------------------------------------------------------------------===// 762 763 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 764 arith::CmpFOp op, OpAdaptor adaptor, 765 ConversionPatternRewriter &rewriter) const { 766 if (op.getPredicate() != arith::CmpFPredicate::ORD && 767 op.getPredicate() != arith::CmpFPredicate::UNO) 768 return failure(); 769 770 Location loc = op.getLoc(); 771 772 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 773 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 774 775 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 776 if (op.getPredicate() == arith::CmpFPredicate::ORD) 777 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 778 779 rewriter.replaceOp(op, replace); 780 return success(); 781 } 782 783 //===----------------------------------------------------------------------===// 784 // Pattern Population 785 //===----------------------------------------------------------------------===// 786 787 void mlir::arith::populateArithmeticToSPIRVPatterns( 788 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 789 // clang-format off 790 patterns.add< 791 ConstantCompositeOpPattern, 792 ConstantScalarOpPattern, 793 spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>, 794 spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>, 795 spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>, 796 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, 797 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, 798 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, 799 RemSIOpGLSLPattern, RemSIOpOCLPattern, 800 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 801 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 802 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 803 spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 804 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 805 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 806 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, 807 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, 808 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, 809 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, 810 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, 811 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, 812 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 813 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 814 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 815 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 816 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 817 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 818 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 819 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 820 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 821 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 822 CmpIOpBooleanPattern, CmpIOpPattern, 823 CmpFOpNanNonePattern, CmpFOpPattern 824 >(typeConverter, patterns.getContext()); 825 // clang-format on 826 827 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 828 // capability is available. 829 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 830 /*benefit=*/2); 831 } 832 833 //===----------------------------------------------------------------------===// 834 // Pass Definition 835 //===----------------------------------------------------------------------===// 836 837 namespace { 838 struct ConvertArithmeticToSPIRVPass 839 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 840 void runOnOperation() override { 841 auto module = getOperation()->getParentOfType<ModuleOp>(); 842 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 843 auto target = SPIRVConversionTarget::get(targetAttr); 844 845 SPIRVTypeConverter::Options options; 846 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 847 SPIRVTypeConverter typeConverter(targetAttr, options); 848 849 RewritePatternSet patterns(&getContext()); 850 mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 851 852 if (failed(applyPartialConversion(getOperation(), *target, 853 std::move(patterns)))) 854 signalPassFailure(); 855 } 856 }; 857 } // namespace 858 859 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 860 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 861 } 862