1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "arith-to-spirv-pattern" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation Conversion 24 //===----------------------------------------------------------------------===// 25 26 namespace { 27 28 /// Converts composite arith.constant operation to spv.Constant. 29 struct ConstantCompositeOpPattern final 30 : public OpConversionPattern<arith::ConstantOp> { 31 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 32 33 LogicalResult 34 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override; 36 }; 37 38 /// Converts scalar arith.constant operation to spv.Constant. 39 struct ConstantScalarOpPattern final 40 : public OpConversionPattern<arith::ConstantOp> { 41 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override; 46 }; 47 48 /// Converts arith.remsi to GLSL SPIR-V ops. 49 /// 50 /// This cannot be merged into the template unary/binary pattern due to Vulkan 51 /// restrictions over spv.SRem and spv.SMod. 52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 53 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 57 ConversionPatternRewriter &rewriter) const override; 58 }; 59 60 /// Converts arith.remsi to OpenCL SPIR-V ops. 61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 62 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override; 67 }; 68 69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 70 /// other than the BinaryOpPatternPattern because if the operands are boolean 71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 74 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 75 using OpConversionPattern<Op>::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override; 80 }; 81 82 /// Converts arith.xori to SPIR-V operations. 83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 84 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override; 89 }; 90 91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 92 /// vector of i1. 93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 94 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override; 99 }; 100 101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 102 /// i1. 103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 104 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 112 /// i1. 113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 114 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 115 116 LogicalResult 117 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 118 ConversionPatternRewriter &rewriter) const override; 119 }; 120 121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 122 /// i1. 123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 124 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 128 ConversionPatternRewriter &rewriter) const override; 129 }; 130 131 /// Converts type-casting standard operations to SPIR-V operations. 132 template <typename Op, typename SPIRVOp> 133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 134 using OpConversionPattern<Op>::OpConversionPattern; 135 136 LogicalResult 137 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 138 ConversionPatternRewriter &rewriter) const override; 139 }; 140 141 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 143 public: 144 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 145 146 LogicalResult 147 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter) const override; 149 }; 150 151 /// Converts integer compare operation to SPIR-V ops. 152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 153 public: 154 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 155 156 LogicalResult 157 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override; 159 }; 160 161 /// Converts floating-point comparison operations to SPIR-V ops. 162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 163 public: 164 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 165 166 LogicalResult 167 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override; 169 }; 170 171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 172 /// Kernel capability. 173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 174 public: 175 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 176 177 LogicalResult 178 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override; 180 }; 181 182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 183 /// require additional capability. 184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 185 public: 186 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 187 188 LogicalResult 189 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 190 ConversionPatternRewriter &rewriter) const override; 191 }; 192 193 } // namespace 194 195 //===----------------------------------------------------------------------===// 196 // Conversion Helpers 197 //===----------------------------------------------------------------------===// 198 199 /// Converts the given `srcAttr` into a boolean attribute if it holds an 200 /// integral value. Returns null attribute if conversion fails. 201 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 202 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 203 return boolAttr; 204 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 205 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 206 return BoolAttr(); 207 } 208 209 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 210 /// Returns null attribute if conversion fails. 211 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 212 Builder builder) { 213 // If the source number uses less active bits than the target bitwidth, then 214 // it should be safe to convert. 215 if (srcAttr.getValue().isIntN(dstType.getWidth())) 216 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 217 218 // XXX: Try again by interpreting the source number as a signed value. 219 // Although integers in the standard dialect are signless, they can represent 220 // a signed number. It's the operation decides how to interpret. This is 221 // dangerous, but it seems there is no good way of handling this if we still 222 // want to change the bitwidth. Emit a message at least. 223 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 224 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 225 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 226 << dstAttr << "' for type '" << dstType << "'\n"); 227 return dstAttr; 228 } 229 230 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 231 << "' illegal: cannot fit into target type '" 232 << dstType << "'\n"); 233 return IntegerAttr(); 234 } 235 236 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 237 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 238 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 239 Builder builder) { 240 // Only support converting to float for now. 241 if (!dstType.isF32()) 242 return FloatAttr(); 243 244 // Try to convert the source floating-point number to single precision. 245 APFloat dstVal = srcAttr.getValue(); 246 bool losesInfo = false; 247 APFloat::opStatus status = 248 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 249 if (status != APFloat::opOK || losesInfo) { 250 LLVM_DEBUG(llvm::dbgs() 251 << srcAttr << " illegal: cannot fit into converted type '" 252 << dstType << "'\n"); 253 return FloatAttr(); 254 } 255 256 return builder.getF32FloatAttr(dstVal.convertToFloat()); 257 } 258 259 /// Returns true if the given `type` is a boolean scalar or vector type. 260 static bool isBoolScalarOrVector(Type type) { 261 if (type.isInteger(1)) 262 return true; 263 if (auto vecType = type.dyn_cast<VectorType>()) 264 return vecType.getElementType().isInteger(1); 265 return false; 266 } 267 268 //===----------------------------------------------------------------------===// 269 // ConstantOp with composite type 270 //===----------------------------------------------------------------------===// 271 272 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 273 arith::ConstantOp constOp, OpAdaptor adaptor, 274 ConversionPatternRewriter &rewriter) const { 275 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 276 if (!srcType) 277 return failure(); 278 279 // arith.constant should only have vector or tenor types. 280 assert((srcType.isa<VectorType, RankedTensorType>())); 281 282 auto dstType = getTypeConverter()->convertType(srcType); 283 if (!dstType) 284 return failure(); 285 286 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 287 ShapedType dstAttrType = dstElementsAttr.getType(); 288 if (!dstElementsAttr) 289 return failure(); 290 291 // If the composite type has more than one dimensions, perform linearization. 292 if (srcType.getRank() > 1) { 293 if (srcType.isa<RankedTensorType>()) { 294 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 295 srcType.getElementType()); 296 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 297 } else { 298 // TODO: add support for large vectors. 299 return failure(); 300 } 301 } 302 303 Type srcElemType = srcType.getElementType(); 304 Type dstElemType; 305 // Tensor types are converted to SPIR-V array types; vector types are 306 // converted to SPIR-V vector/array types. 307 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 308 dstElemType = arrayType.getElementType(); 309 else 310 dstElemType = dstType.cast<VectorType>().getElementType(); 311 312 // If the source and destination element types are different, perform 313 // attribute conversion. 314 if (srcElemType != dstElemType) { 315 SmallVector<Attribute, 8> elements; 316 if (srcElemType.isa<FloatType>()) { 317 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 318 FloatAttr dstAttr = 319 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 320 if (!dstAttr) 321 return failure(); 322 elements.push_back(dstAttr); 323 } 324 } else if (srcElemType.isInteger(1)) { 325 return failure(); 326 } else { 327 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 328 IntegerAttr dstAttr = convertIntegerAttr( 329 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 330 if (!dstAttr) 331 return failure(); 332 elements.push_back(dstAttr); 333 } 334 } 335 336 // Unfortunately, we cannot use dialect-specific types for element 337 // attributes; element attributes only works with builtin types. So we need 338 // to prepare another converted builtin types for the destination elements 339 // attribute. 340 if (dstAttrType.isa<RankedTensorType>()) 341 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 342 else 343 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 344 345 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 346 } 347 348 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 349 dstElementsAttr); 350 return success(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // ConstantOp with scalar type 355 //===----------------------------------------------------------------------===// 356 357 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 358 arith::ConstantOp constOp, OpAdaptor adaptor, 359 ConversionPatternRewriter &rewriter) const { 360 Type srcType = constOp.getType(); 361 if (!srcType.isIntOrIndexOrFloat()) 362 return failure(); 363 364 Type dstType = getTypeConverter()->convertType(srcType); 365 if (!dstType) 366 return failure(); 367 368 // Floating-point types. 369 if (srcType.isa<FloatType>()) { 370 auto srcAttr = constOp.getValue().cast<FloatAttr>(); 371 auto dstAttr = srcAttr; 372 373 // Floating-point types not supported in the target environment are all 374 // converted to float type. 375 if (srcType != dstType) { 376 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 377 if (!dstAttr) 378 return failure(); 379 } 380 381 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 382 return success(); 383 } 384 385 // Bool type. 386 if (srcType.isInteger(1)) { 387 // arith.constant can use 0/1 instead of true/false for i1 values. We need 388 // to handle that here. 389 auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter); 390 if (!dstAttr) 391 return failure(); 392 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 393 return success(); 394 } 395 396 // IndexType or IntegerType. Index values are converted to 32-bit integer 397 // values when converting to SPIR-V. 398 auto srcAttr = constOp.getValue().cast<IntegerAttr>(); 399 auto dstAttr = 400 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 401 if (!dstAttr) 402 return failure(); 403 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 404 return success(); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // RemSIOpGLSLPattern 409 //===----------------------------------------------------------------------===// 410 411 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 412 /// the sign of `signOperand`. 413 /// 414 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 415 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 416 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 417 /// if either operand can be negative. Emulate it via spv.UMod. 418 template <typename SignedAbsOp> 419 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 420 Value signOperand, OpBuilder &builder) { 421 assert(lhs.getType() == rhs.getType()); 422 assert(lhs == signOperand || rhs == signOperand); 423 424 Type type = lhs.getType(); 425 426 // Calculate the remainder with spv.UMod. 427 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 428 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 429 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 430 431 // Fix the sign. 432 Value isPositive; 433 if (lhs == signOperand) 434 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 435 else 436 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 437 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 438 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 439 } 440 441 LogicalResult 442 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 443 ConversionPatternRewriter &rewriter) const { 444 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 445 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 446 adaptor.getOperands()[0], rewriter); 447 rewriter.replaceOp(op, result); 448 449 return success(); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // RemSIOpOCLPattern 454 //===----------------------------------------------------------------------===// 455 456 LogicalResult 457 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 458 ConversionPatternRewriter &rewriter) const { 459 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 460 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 461 adaptor.getOperands()[0], rewriter); 462 rewriter.replaceOp(op, result); 463 464 return success(); 465 } 466 467 //===----------------------------------------------------------------------===// 468 // BitwiseOpPattern 469 //===----------------------------------------------------------------------===// 470 471 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 472 LogicalResult 473 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 474 Op op, typename Op::Adaptor adaptor, 475 ConversionPatternRewriter &rewriter) const { 476 assert(adaptor.getOperands().size() == 2); 477 auto dstType = 478 this->getTypeConverter()->convertType(op.getResult().getType()); 479 if (!dstType) 480 return failure(); 481 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 482 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 483 adaptor.getOperands()); 484 } else { 485 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 486 adaptor.getOperands()); 487 } 488 return success(); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // XOrIOpLogicalPattern 493 //===----------------------------------------------------------------------===// 494 495 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 496 arith::XOrIOp op, OpAdaptor adaptor, 497 ConversionPatternRewriter &rewriter) const { 498 assert(adaptor.getOperands().size() == 2); 499 500 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 501 return failure(); 502 503 auto dstType = getTypeConverter()->convertType(op.getType()); 504 if (!dstType) 505 return failure(); 506 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 507 adaptor.getOperands()); 508 509 return success(); 510 } 511 512 //===----------------------------------------------------------------------===// 513 // XOrIOpBooleanPattern 514 //===----------------------------------------------------------------------===// 515 516 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 517 arith::XOrIOp op, OpAdaptor adaptor, 518 ConversionPatternRewriter &rewriter) const { 519 assert(adaptor.getOperands().size() == 2); 520 521 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 522 return failure(); 523 524 auto dstType = getTypeConverter()->convertType(op.getType()); 525 if (!dstType) 526 return failure(); 527 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 528 adaptor.getOperands()); 529 return success(); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // UIToFPI1Pattern 534 //===----------------------------------------------------------------------===// 535 536 LogicalResult 537 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 538 ConversionPatternRewriter &rewriter) const { 539 auto srcType = adaptor.getOperands().front().getType(); 540 if (!isBoolScalarOrVector(srcType)) 541 return failure(); 542 543 auto dstType = 544 this->getTypeConverter()->convertType(op.getResult().getType()); 545 Location loc = op.getLoc(); 546 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 547 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 548 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 549 op, dstType, adaptor.getOperands().front(), one, zero); 550 return success(); 551 } 552 553 //===----------------------------------------------------------------------===// 554 // ExtUII1Pattern 555 //===----------------------------------------------------------------------===// 556 557 LogicalResult 558 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 559 ConversionPatternRewriter &rewriter) const { 560 auto srcType = adaptor.getOperands().front().getType(); 561 if (!isBoolScalarOrVector(srcType)) 562 return failure(); 563 564 auto dstType = 565 this->getTypeConverter()->convertType(op.getResult().getType()); 566 Location loc = op.getLoc(); 567 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 568 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 569 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 570 op, dstType, adaptor.getOperands().front(), one, zero); 571 return success(); 572 } 573 574 //===----------------------------------------------------------------------===// 575 // TruncII1Pattern 576 //===----------------------------------------------------------------------===// 577 578 LogicalResult 579 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 580 ConversionPatternRewriter &rewriter) const { 581 auto dstType = 582 this->getTypeConverter()->convertType(op.getResult().getType()); 583 if (!isBoolScalarOrVector(dstType)) 584 return failure(); 585 586 Location loc = op.getLoc(); 587 auto srcType = adaptor.getOperands().front().getType(); 588 // Check if (x & 1) == 1. 589 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 590 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 591 loc, srcType, adaptor.getOperands()[0], mask); 592 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 593 594 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 595 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 596 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 597 return success(); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // TypeCastingOpPattern 602 //===----------------------------------------------------------------------===// 603 604 template <typename Op, typename SPIRVOp> 605 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 606 Op op, typename Op::Adaptor adaptor, 607 ConversionPatternRewriter &rewriter) const { 608 assert(adaptor.getOperands().size() == 1); 609 auto srcType = adaptor.getOperands().front().getType(); 610 auto dstType = 611 this->getTypeConverter()->convertType(op.getResult().getType()); 612 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 613 return failure(); 614 if (dstType == srcType) { 615 // Due to type conversion, we are seeing the same source and target type. 616 // Then we can just erase this operation by forwarding its operand. 617 rewriter.replaceOp(op, adaptor.getOperands().front()); 618 } else { 619 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 620 adaptor.getOperands()); 621 } 622 return success(); 623 } 624 625 //===----------------------------------------------------------------------===// 626 // CmpIOpBooleanPattern 627 //===----------------------------------------------------------------------===// 628 629 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 630 arith::CmpIOp op, OpAdaptor adaptor, 631 ConversionPatternRewriter &rewriter) const { 632 Type operandType = op.getLhs().getType(); 633 if (!isBoolScalarOrVector(operandType)) 634 return failure(); 635 636 switch (op.getPredicate()) { 637 #define DISPATCH(cmpPredicate, spirvOp) \ 638 case cmpPredicate: \ 639 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 640 adaptor.getLhs(), adaptor.getRhs()); \ 641 return success(); 642 643 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 644 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 645 646 #undef DISPATCH 647 default:; 648 } 649 return failure(); 650 } 651 652 //===----------------------------------------------------------------------===// 653 // CmpIOpPattern 654 //===----------------------------------------------------------------------===// 655 656 LogicalResult 657 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 658 ConversionPatternRewriter &rewriter) const { 659 Type operandType = op.getLhs().getType(); 660 if (isBoolScalarOrVector(operandType)) 661 return failure(); 662 663 switch (op.getPredicate()) { 664 #define DISPATCH(cmpPredicate, spirvOp) \ 665 case cmpPredicate: \ 666 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 667 operandType != this->getTypeConverter()->convertType(operandType)) { \ 668 return op.emitError( \ 669 "bitwidth emulation is not implemented yet on unsigned op"); \ 670 } \ 671 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 672 adaptor.getLhs(), adaptor.getRhs()); \ 673 return success(); 674 675 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 676 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 677 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 678 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 679 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 680 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 681 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 682 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 683 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 684 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 685 686 #undef DISPATCH 687 } 688 return failure(); 689 } 690 691 //===----------------------------------------------------------------------===// 692 // CmpFOpPattern 693 //===----------------------------------------------------------------------===// 694 695 LogicalResult 696 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 697 ConversionPatternRewriter &rewriter) const { 698 switch (op.getPredicate()) { 699 #define DISPATCH(cmpPredicate, spirvOp) \ 700 case cmpPredicate: \ 701 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 702 adaptor.getLhs(), adaptor.getRhs()); \ 703 return success(); 704 705 // Ordered. 706 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 707 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 708 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 709 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 710 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 711 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 712 // Unordered. 713 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 714 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 715 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 716 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 717 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 718 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 719 720 #undef DISPATCH 721 722 default: 723 break; 724 } 725 return failure(); 726 } 727 728 //===----------------------------------------------------------------------===// 729 // CmpFOpNanKernelPattern 730 //===----------------------------------------------------------------------===// 731 732 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 733 arith::CmpFOp op, OpAdaptor adaptor, 734 ConversionPatternRewriter &rewriter) const { 735 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 736 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 737 adaptor.getRhs()); 738 return success(); 739 } 740 741 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 742 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 743 adaptor.getRhs()); 744 return success(); 745 } 746 747 return failure(); 748 } 749 750 //===----------------------------------------------------------------------===// 751 // CmpFOpNanNonePattern 752 //===----------------------------------------------------------------------===// 753 754 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 755 arith::CmpFOp op, OpAdaptor adaptor, 756 ConversionPatternRewriter &rewriter) const { 757 if (op.getPredicate() != arith::CmpFPredicate::ORD && 758 op.getPredicate() != arith::CmpFPredicate::UNO) 759 return failure(); 760 761 Location loc = op.getLoc(); 762 763 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 764 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 765 766 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 767 if (op.getPredicate() == arith::CmpFPredicate::ORD) 768 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 769 770 rewriter.replaceOp(op, replace); 771 return success(); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // Pattern Population 776 //===----------------------------------------------------------------------===// 777 778 void mlir::arith::populateArithmeticToSPIRVPatterns( 779 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 780 // clang-format off 781 patterns.add< 782 ConstantCompositeOpPattern, 783 ConstantScalarOpPattern, 784 spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>, 785 spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>, 786 spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>, 787 spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>, 788 spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>, 789 spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>, 790 RemSIOpGLSLPattern, RemSIOpOCLPattern, 791 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 792 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 793 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 794 spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 795 spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 796 spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 797 spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>, 798 spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>, 799 spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>, 800 spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>, 801 spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>, 802 spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>, 803 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 804 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 805 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 806 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 807 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 808 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 809 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 810 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 811 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 812 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 813 CmpIOpBooleanPattern, CmpIOpPattern, 814 CmpFOpNanNonePattern, CmpFOpPattern 815 >(typeConverter, patterns.getContext()); 816 // clang-format on 817 818 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 819 // capability is available. 820 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 821 /*benefit=*/2); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Pass Definition 826 //===----------------------------------------------------------------------===// 827 828 namespace { 829 struct ConvertArithmeticToSPIRVPass 830 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 831 void runOnFunction() override { 832 auto module = getOperation()->getParentOfType<ModuleOp>(); 833 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 834 auto target = SPIRVConversionTarget::get(targetAttr); 835 836 SPIRVTypeConverter::Options options; 837 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 838 SPIRVTypeConverter typeConverter(targetAttr, options); 839 840 RewritePatternSet patterns(&getContext()); 841 mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 842 843 if (failed(applyPartialConversion(getOperation(), *target, 844 std::move(patterns)))) 845 signalPassFailure(); 846 } 847 }; 848 } // namespace 849 850 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 851 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 852 } 853