1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "arith-to-spirv-pattern" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation Conversion 24 //===----------------------------------------------------------------------===// 25 26 namespace { 27 28 /// Converts composite arith.constant operation to spv.Constant. 29 struct ConstantCompositeOpPattern final 30 : public OpConversionPattern<arith::ConstantOp> { 31 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 32 33 LogicalResult 34 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override; 36 }; 37 38 /// Converts scalar arith.constant operation to spv.Constant. 39 struct ConstantScalarOpPattern final 40 : public OpConversionPattern<arith::ConstantOp> { 41 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override; 46 }; 47 48 /// Converts arith.remsi to SPIR-V ops. 49 /// 50 /// This cannot be merged into the template unary/binary pattern due to Vulkan 51 /// restrictions over spv.SRem and spv.SMod. 52 struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> { 53 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 57 ConversionPatternRewriter &rewriter) const override; 58 }; 59 60 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 61 /// other than the BinaryOpPatternPattern because if the operands are boolean 62 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 63 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 64 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 65 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 66 using OpConversionPattern<Op>::OpConversionPattern; 67 68 LogicalResult 69 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 70 ConversionPatternRewriter &rewriter) const override; 71 }; 72 73 /// Converts arith.xori to SPIR-V operations. 74 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 75 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override; 80 }; 81 82 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 83 /// vector of i1. 84 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 85 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 86 87 LogicalResult 88 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 89 ConversionPatternRewriter &rewriter) const override; 90 }; 91 92 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 93 /// i1. 94 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 95 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 96 97 LogicalResult 98 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 103 /// i1. 104 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 105 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override; 110 }; 111 112 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 113 /// i1. 114 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 115 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 116 117 LogicalResult 118 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const override; 120 }; 121 122 /// Converts type-casting standard operations to SPIR-V operations. 123 template <typename Op, typename SPIRVOp> 124 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 125 using OpConversionPattern<Op>::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override; 130 }; 131 132 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 133 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 134 public: 135 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 136 137 LogicalResult 138 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 139 ConversionPatternRewriter &rewriter) const override; 140 }; 141 142 /// Converts integer compare operation to SPIR-V ops. 143 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 144 public: 145 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 146 147 LogicalResult 148 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 149 ConversionPatternRewriter &rewriter) const override; 150 }; 151 152 /// Converts floating-point comparison operations to SPIR-V ops. 153 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 154 public: 155 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override; 160 }; 161 162 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 163 /// Kernel capability. 164 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 165 public: 166 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 167 168 LogicalResult 169 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 170 ConversionPatternRewriter &rewriter) const override; 171 }; 172 173 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 174 /// require additional capability. 175 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 176 public: 177 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 178 179 LogicalResult 180 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 181 ConversionPatternRewriter &rewriter) const override; 182 }; 183 184 } // end anonymous namespace 185 186 //===----------------------------------------------------------------------===// 187 // Conversion Helpers 188 //===----------------------------------------------------------------------===// 189 190 /// Converts the given `srcAttr` into a boolean attribute if it holds an 191 /// integral value. Returns null attribute if conversion fails. 192 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 193 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 194 return boolAttr; 195 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 196 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 197 return BoolAttr(); 198 } 199 200 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 201 /// Returns null attribute if conversion fails. 202 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 203 Builder builder) { 204 // If the source number uses less active bits than the target bitwidth, then 205 // it should be safe to convert. 206 if (srcAttr.getValue().isIntN(dstType.getWidth())) 207 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 208 209 // XXX: Try again by interpreting the source number as a signed value. 210 // Although integers in the standard dialect are signless, they can represent 211 // a signed number. It's the operation decides how to interpret. This is 212 // dangerous, but it seems there is no good way of handling this if we still 213 // want to change the bitwidth. Emit a message at least. 214 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 215 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 216 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 217 << dstAttr << "' for type '" << dstType << "'\n"); 218 return dstAttr; 219 } 220 221 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 222 << "' illegal: cannot fit into target type '" 223 << dstType << "'\n"); 224 return IntegerAttr(); 225 } 226 227 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 228 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 229 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 230 Builder builder) { 231 // Only support converting to float for now. 232 if (!dstType.isF32()) 233 return FloatAttr(); 234 235 // Try to convert the source floating-point number to single precision. 236 APFloat dstVal = srcAttr.getValue(); 237 bool losesInfo = false; 238 APFloat::opStatus status = 239 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 240 if (status != APFloat::opOK || losesInfo) { 241 LLVM_DEBUG(llvm::dbgs() 242 << srcAttr << " illegal: cannot fit into converted type '" 243 << dstType << "'\n"); 244 return FloatAttr(); 245 } 246 247 return builder.getF32FloatAttr(dstVal.convertToFloat()); 248 } 249 250 /// Returns true if the given `type` is a boolean scalar or vector type. 251 static bool isBoolScalarOrVector(Type type) { 252 if (type.isInteger(1)) 253 return true; 254 if (auto vecType = type.dyn_cast<VectorType>()) 255 return vecType.getElementType().isInteger(1); 256 return false; 257 } 258 259 //===----------------------------------------------------------------------===// 260 // ConstantOp with composite type 261 //===----------------------------------------------------------------------===// 262 263 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 264 arith::ConstantOp constOp, OpAdaptor adaptor, 265 ConversionPatternRewriter &rewriter) const { 266 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 267 if (!srcType) 268 return failure(); 269 270 // arith.constant should only have vector or tenor types. 271 assert((srcType.isa<VectorType, RankedTensorType>())); 272 273 auto dstType = getTypeConverter()->convertType(srcType); 274 if (!dstType) 275 return failure(); 276 277 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 278 ShapedType dstAttrType = dstElementsAttr.getType(); 279 if (!dstElementsAttr) 280 return failure(); 281 282 // If the composite type has more than one dimensions, perform linearization. 283 if (srcType.getRank() > 1) { 284 if (srcType.isa<RankedTensorType>()) { 285 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 286 srcType.getElementType()); 287 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 288 } else { 289 // TODO: add support for large vectors. 290 return failure(); 291 } 292 } 293 294 Type srcElemType = srcType.getElementType(); 295 Type dstElemType; 296 // Tensor types are converted to SPIR-V array types; vector types are 297 // converted to SPIR-V vector/array types. 298 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 299 dstElemType = arrayType.getElementType(); 300 else 301 dstElemType = dstType.cast<VectorType>().getElementType(); 302 303 // If the source and destination element types are different, perform 304 // attribute conversion. 305 if (srcElemType != dstElemType) { 306 SmallVector<Attribute, 8> elements; 307 if (srcElemType.isa<FloatType>()) { 308 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 309 FloatAttr dstAttr = 310 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 311 if (!dstAttr) 312 return failure(); 313 elements.push_back(dstAttr); 314 } 315 } else if (srcElemType.isInteger(1)) { 316 return failure(); 317 } else { 318 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 319 IntegerAttr dstAttr = convertIntegerAttr( 320 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 321 if (!dstAttr) 322 return failure(); 323 elements.push_back(dstAttr); 324 } 325 } 326 327 // Unfortunately, we cannot use dialect-specific types for element 328 // attributes; element attributes only works with builtin types. So we need 329 // to prepare another converted builtin types for the destination elements 330 // attribute. 331 if (dstAttrType.isa<RankedTensorType>()) 332 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 333 else 334 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 335 336 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 337 } 338 339 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 340 dstElementsAttr); 341 return success(); 342 } 343 344 //===----------------------------------------------------------------------===// 345 // ConstantOp with scalar type 346 //===----------------------------------------------------------------------===// 347 348 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 349 arith::ConstantOp constOp, OpAdaptor adaptor, 350 ConversionPatternRewriter &rewriter) const { 351 Type srcType = constOp.getType(); 352 if (!srcType.isIntOrIndexOrFloat()) 353 return failure(); 354 355 Type dstType = getTypeConverter()->convertType(srcType); 356 if (!dstType) 357 return failure(); 358 359 // Floating-point types. 360 if (srcType.isa<FloatType>()) { 361 auto srcAttr = constOp.getValue().cast<FloatAttr>(); 362 auto dstAttr = srcAttr; 363 364 // Floating-point types not supported in the target environment are all 365 // converted to float type. 366 if (srcType != dstType) { 367 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 368 if (!dstAttr) 369 return failure(); 370 } 371 372 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 373 return success(); 374 } 375 376 // Bool type. 377 if (srcType.isInteger(1)) { 378 // arith.constant can use 0/1 instead of true/false for i1 values. We need 379 // to handle that here. 380 auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter); 381 if (!dstAttr) 382 return failure(); 383 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 384 return success(); 385 } 386 387 // IndexType or IntegerType. Index values are converted to 32-bit integer 388 // values when converting to SPIR-V. 389 auto srcAttr = constOp.getValue().cast<IntegerAttr>(); 390 auto dstAttr = 391 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 392 if (!dstAttr) 393 return failure(); 394 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 395 return success(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // RemSIOpPattern 400 //===----------------------------------------------------------------------===// 401 402 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 403 /// the sign of `signOperand`. 404 /// 405 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 406 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 407 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 408 /// if either operand can be negative. Emulate it via spv.UMod. 409 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 410 Value signOperand, OpBuilder &builder) { 411 assert(lhs.getType() == rhs.getType()); 412 assert(lhs == signOperand || rhs == signOperand); 413 414 Type type = lhs.getType(); 415 416 // Calculate the remainder with spv.UMod. 417 Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs); 418 Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs); 419 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 420 421 // Fix the sign. 422 Value isPositive; 423 if (lhs == signOperand) 424 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 425 else 426 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 427 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 428 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 429 } 430 431 LogicalResult 432 RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 433 ConversionPatternRewriter &rewriter) const { 434 Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0], 435 adaptor.getOperands()[1], 436 adaptor.getOperands()[0], rewriter); 437 rewriter.replaceOp(op, result); 438 439 return success(); 440 } 441 442 //===----------------------------------------------------------------------===// 443 // BitwiseOpPattern 444 //===----------------------------------------------------------------------===// 445 446 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 447 LogicalResult 448 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 449 Op op, typename Op::Adaptor adaptor, 450 ConversionPatternRewriter &rewriter) const { 451 assert(adaptor.getOperands().size() == 2); 452 auto dstType = 453 this->getTypeConverter()->convertType(op.getResult().getType()); 454 if (!dstType) 455 return failure(); 456 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 457 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 458 adaptor.getOperands()); 459 } else { 460 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 461 adaptor.getOperands()); 462 } 463 return success(); 464 } 465 466 //===----------------------------------------------------------------------===// 467 // XOrIOpLogicalPattern 468 //===----------------------------------------------------------------------===// 469 470 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 471 arith::XOrIOp op, OpAdaptor adaptor, 472 ConversionPatternRewriter &rewriter) const { 473 assert(adaptor.getOperands().size() == 2); 474 475 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 476 return failure(); 477 478 auto dstType = getTypeConverter()->convertType(op.getType()); 479 if (!dstType) 480 return failure(); 481 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 482 adaptor.getOperands()); 483 484 return success(); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // XOrIOpBooleanPattern 489 //===----------------------------------------------------------------------===// 490 491 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 492 arith::XOrIOp op, OpAdaptor adaptor, 493 ConversionPatternRewriter &rewriter) const { 494 assert(adaptor.getOperands().size() == 2); 495 496 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 497 return failure(); 498 499 auto dstType = getTypeConverter()->convertType(op.getType()); 500 if (!dstType) 501 return failure(); 502 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 503 adaptor.getOperands()); 504 return success(); 505 } 506 507 //===----------------------------------------------------------------------===// 508 // UIToFPI1Pattern 509 //===----------------------------------------------------------------------===// 510 511 LogicalResult 512 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 513 ConversionPatternRewriter &rewriter) const { 514 auto srcType = adaptor.getOperands().front().getType(); 515 if (!isBoolScalarOrVector(srcType)) 516 return failure(); 517 518 auto dstType = 519 this->getTypeConverter()->convertType(op.getResult().getType()); 520 Location loc = op.getLoc(); 521 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 522 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 523 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 524 op, dstType, adaptor.getOperands().front(), one, zero); 525 return success(); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // ExtUII1Pattern 530 //===----------------------------------------------------------------------===// 531 532 LogicalResult 533 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 534 ConversionPatternRewriter &rewriter) const { 535 auto srcType = adaptor.getOperands().front().getType(); 536 if (!isBoolScalarOrVector(srcType)) 537 return failure(); 538 539 auto dstType = 540 this->getTypeConverter()->convertType(op.getResult().getType()); 541 Location loc = op.getLoc(); 542 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 543 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 544 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 545 op, dstType, adaptor.getOperands().front(), one, zero); 546 return success(); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // TruncII1Pattern 551 //===----------------------------------------------------------------------===// 552 553 LogicalResult 554 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 555 ConversionPatternRewriter &rewriter) const { 556 auto dstType = 557 this->getTypeConverter()->convertType(op.getResult().getType()); 558 if (!isBoolScalarOrVector(dstType)) 559 return failure(); 560 561 Location loc = op.getLoc(); 562 auto srcType = adaptor.getOperands().front().getType(); 563 // Check if (x & 1) == 1. 564 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 565 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 566 loc, srcType, adaptor.getOperands()[0], mask); 567 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 568 569 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 570 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 571 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 572 return success(); 573 } 574 575 //===----------------------------------------------------------------------===// 576 // TypeCastingOpPattern 577 //===----------------------------------------------------------------------===// 578 579 template <typename Op, typename SPIRVOp> 580 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 581 Op op, typename Op::Adaptor adaptor, 582 ConversionPatternRewriter &rewriter) const { 583 assert(adaptor.getOperands().size() == 1); 584 auto srcType = adaptor.getOperands().front().getType(); 585 auto dstType = 586 this->getTypeConverter()->convertType(op.getResult().getType()); 587 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 588 return failure(); 589 if (dstType == srcType) { 590 // Due to type conversion, we are seeing the same source and target type. 591 // Then we can just erase this operation by forwarding its operand. 592 rewriter.replaceOp(op, adaptor.getOperands().front()); 593 } else { 594 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 595 adaptor.getOperands()); 596 } 597 return success(); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // CmpIOpBooleanPattern 602 //===----------------------------------------------------------------------===// 603 604 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 605 arith::CmpIOp op, OpAdaptor adaptor, 606 ConversionPatternRewriter &rewriter) const { 607 Type operandType = op.getLhs().getType(); 608 if (!isBoolScalarOrVector(operandType)) 609 return failure(); 610 611 switch (op.getPredicate()) { 612 #define DISPATCH(cmpPredicate, spirvOp) \ 613 case cmpPredicate: \ 614 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 615 adaptor.lhs(), adaptor.rhs()); \ 616 return success(); 617 618 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 619 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 620 621 #undef DISPATCH 622 default:; 623 } 624 return failure(); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // CmpIOpPattern 629 //===----------------------------------------------------------------------===// 630 631 LogicalResult 632 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 633 ConversionPatternRewriter &rewriter) const { 634 Type operandType = op.getLhs().getType(); 635 if (isBoolScalarOrVector(operandType)) 636 return failure(); 637 638 switch (op.getPredicate()) { 639 #define DISPATCH(cmpPredicate, spirvOp) \ 640 case cmpPredicate: \ 641 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 642 operandType != this->getTypeConverter()->convertType(operandType)) { \ 643 return op.emitError( \ 644 "bitwidth emulation is not implemented yet on unsigned op"); \ 645 } \ 646 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 647 adaptor.lhs(), adaptor.rhs()); \ 648 return success(); 649 650 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 651 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 652 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 653 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 654 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 655 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 656 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 657 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 658 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 659 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 660 661 #undef DISPATCH 662 } 663 return failure(); 664 } 665 666 //===----------------------------------------------------------------------===// 667 // CmpFOpPattern 668 //===----------------------------------------------------------------------===// 669 670 LogicalResult 671 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 672 ConversionPatternRewriter &rewriter) const { 673 switch (op.getPredicate()) { 674 #define DISPATCH(cmpPredicate, spirvOp) \ 675 case cmpPredicate: \ 676 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 677 adaptor.lhs(), adaptor.rhs()); \ 678 return success(); 679 680 // Ordered. 681 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 682 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 683 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 684 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 685 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 686 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 687 // Unordered. 688 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 689 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 690 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 691 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 692 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 693 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 694 695 #undef DISPATCH 696 697 default: 698 break; 699 } 700 return failure(); 701 } 702 703 //===----------------------------------------------------------------------===// 704 // CmpFOpNanKernelPattern 705 //===----------------------------------------------------------------------===// 706 707 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 708 arith::CmpFOp op, OpAdaptor adaptor, 709 ConversionPatternRewriter &rewriter) const { 710 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 711 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 712 adaptor.getRhs()); 713 return success(); 714 } 715 716 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 717 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 718 adaptor.getRhs()); 719 return success(); 720 } 721 722 return failure(); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // CmpFOpNanNonePattern 727 //===----------------------------------------------------------------------===// 728 729 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 730 arith::CmpFOp op, OpAdaptor adaptor, 731 ConversionPatternRewriter &rewriter) const { 732 if (op.getPredicate() != arith::CmpFPredicate::ORD && 733 op.getPredicate() != arith::CmpFPredicate::UNO) 734 return failure(); 735 736 Location loc = op.getLoc(); 737 738 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 739 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 740 741 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 742 if (op.getPredicate() == arith::CmpFPredicate::ORD) 743 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 744 745 rewriter.replaceOp(op, replace); 746 return success(); 747 } 748 749 //===----------------------------------------------------------------------===// 750 // Pattern Population 751 //===----------------------------------------------------------------------===// 752 753 void mlir::arith::populateArithmeticToSPIRVPatterns( 754 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 755 // clang-format off 756 patterns.add< 757 ConstantCompositeOpPattern, 758 ConstantScalarOpPattern, 759 spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>, 760 spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>, 761 spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>, 762 spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>, 763 spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>, 764 spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>, 765 RemSIOpPattern, 766 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 767 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 768 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 769 spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 770 spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 771 spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 772 spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>, 773 spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>, 774 spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>, 775 spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>, 776 spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>, 777 spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>, 778 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 779 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 780 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 781 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 782 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 783 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 784 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 785 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 786 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 787 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 788 CmpIOpBooleanPattern, CmpIOpPattern, 789 CmpFOpNanNonePattern, CmpFOpPattern 790 >(typeConverter, patterns.getContext()); 791 // clang-format on 792 793 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 794 // capability is available. 795 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 796 /*benefit=*/2); 797 } 798 799 //===----------------------------------------------------------------------===// 800 // Pass Definition 801 //===----------------------------------------------------------------------===// 802 803 namespace { 804 struct ConvertArithmeticToSPIRVPass 805 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 806 void runOnFunction() override { 807 auto module = getOperation()->getParentOfType<ModuleOp>(); 808 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 809 auto target = SPIRVConversionTarget::get(targetAttr); 810 811 SPIRVTypeConverter::Options options; 812 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 813 SPIRVTypeConverter typeConverter(targetAttr, options); 814 815 RewritePatternSet patterns(&getContext()); 816 mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 817 818 if (failed(applyPartialConversion(getOperation(), *target, 819 std::move(patterns)))) 820 signalPassFailure(); 821 } 822 }; 823 } // end anonymous namespace 824 825 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 826 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 827 } 828