1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "arith-to-spirv-pattern" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Operation Conversion 24 //===----------------------------------------------------------------------===// 25 26 namespace { 27 28 /// Converts composite arith.constant operation to spv.Constant. 29 struct ConstantCompositeOpPattern final 30 : public OpConversionPattern<arith::ConstantOp> { 31 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 32 33 LogicalResult 34 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override; 36 }; 37 38 /// Converts scalar arith.constant operation to spv.Constant. 39 struct ConstantScalarOpPattern final 40 : public OpConversionPattern<arith::ConstantOp> { 41 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 42 43 LogicalResult 44 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override; 46 }; 47 48 /// Converts arith.remsi to GLSL SPIR-V ops. 49 /// 50 /// This cannot be merged into the template unary/binary pattern due to Vulkan 51 /// restrictions over spv.SRem and spv.SMod. 52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 53 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 57 ConversionPatternRewriter &rewriter) const override; 58 }; 59 60 /// Converts arith.remsi to OpenCL SPIR-V ops. 61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 62 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override; 67 }; 68 69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 70 /// other than the BinaryOpPatternPattern because if the operands are boolean 71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 74 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 75 using OpConversionPattern<Op>::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override; 80 }; 81 82 /// Converts arith.xori to SPIR-V operations. 83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 84 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override; 89 }; 90 91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 92 /// vector of i1. 93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 94 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override; 99 }; 100 101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 102 /// i1. 103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 104 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 105 106 LogicalResult 107 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 112 /// i1. 113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 114 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 115 116 LogicalResult 117 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 118 ConversionPatternRewriter &rewriter) const override; 119 }; 120 121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 122 /// i1. 123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 124 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 125 126 LogicalResult 127 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 128 ConversionPatternRewriter &rewriter) const override; 129 }; 130 131 /// Converts type-casting standard operations to SPIR-V operations. 132 template <typename Op, typename SPIRVOp> 133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 134 using OpConversionPattern<Op>::OpConversionPattern; 135 136 LogicalResult 137 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 138 ConversionPatternRewriter &rewriter) const override; 139 }; 140 141 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 143 public: 144 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 145 146 LogicalResult 147 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter) const override; 149 }; 150 151 /// Converts integer compare operation to SPIR-V ops. 152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 153 public: 154 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 155 156 LogicalResult 157 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override; 159 }; 160 161 /// Converts floating-point comparison operations to SPIR-V ops. 162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 163 public: 164 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 165 166 LogicalResult 167 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const override; 169 }; 170 171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 172 /// Kernel capability. 173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 174 public: 175 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 176 177 LogicalResult 178 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override; 180 }; 181 182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 183 /// require additional capability. 184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 185 public: 186 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 187 188 LogicalResult 189 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 190 ConversionPatternRewriter &rewriter) const override; 191 }; 192 193 /// Converts arith.select to spv.Select. 194 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 195 public: 196 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 197 LogicalResult 198 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override; 200 }; 201 202 } // namespace 203 204 //===----------------------------------------------------------------------===// 205 // Conversion Helpers 206 //===----------------------------------------------------------------------===// 207 208 /// Converts the given `srcAttr` into a boolean attribute if it holds an 209 /// integral value. Returns null attribute if conversion fails. 210 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 211 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 212 return boolAttr; 213 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 214 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 215 return BoolAttr(); 216 } 217 218 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 219 /// Returns null attribute if conversion fails. 220 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 221 Builder builder) { 222 // If the source number uses less active bits than the target bitwidth, then 223 // it should be safe to convert. 224 if (srcAttr.getValue().isIntN(dstType.getWidth())) 225 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 226 227 // XXX: Try again by interpreting the source number as a signed value. 228 // Although integers in the standard dialect are signless, they can represent 229 // a signed number. It's the operation decides how to interpret. This is 230 // dangerous, but it seems there is no good way of handling this if we still 231 // want to change the bitwidth. Emit a message at least. 232 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 233 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 234 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 235 << dstAttr << "' for type '" << dstType << "'\n"); 236 return dstAttr; 237 } 238 239 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 240 << "' illegal: cannot fit into target type '" 241 << dstType << "'\n"); 242 return IntegerAttr(); 243 } 244 245 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 246 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 247 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 248 Builder builder) { 249 // Only support converting to float for now. 250 if (!dstType.isF32()) 251 return FloatAttr(); 252 253 // Try to convert the source floating-point number to single precision. 254 APFloat dstVal = srcAttr.getValue(); 255 bool losesInfo = false; 256 APFloat::opStatus status = 257 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 258 if (status != APFloat::opOK || losesInfo) { 259 LLVM_DEBUG(llvm::dbgs() 260 << srcAttr << " illegal: cannot fit into converted type '" 261 << dstType << "'\n"); 262 return FloatAttr(); 263 } 264 265 return builder.getF32FloatAttr(dstVal.convertToFloat()); 266 } 267 268 /// Returns true if the given `type` is a boolean scalar or vector type. 269 static bool isBoolScalarOrVector(Type type) { 270 if (type.isInteger(1)) 271 return true; 272 if (auto vecType = type.dyn_cast<VectorType>()) 273 return vecType.getElementType().isInteger(1); 274 return false; 275 } 276 277 //===----------------------------------------------------------------------===// 278 // ConstantOp with composite type 279 //===----------------------------------------------------------------------===// 280 281 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 282 arith::ConstantOp constOp, OpAdaptor adaptor, 283 ConversionPatternRewriter &rewriter) const { 284 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 285 if (!srcType || srcType.getNumElements() == 1) 286 return failure(); 287 288 // arith.constant should only have vector or tenor types. 289 assert((srcType.isa<VectorType, RankedTensorType>())); 290 291 auto dstType = getTypeConverter()->convertType(srcType); 292 if (!dstType) 293 return failure(); 294 295 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 296 ShapedType dstAttrType = dstElementsAttr.getType(); 297 if (!dstElementsAttr) 298 return failure(); 299 300 // If the composite type has more than one dimensions, perform linearization. 301 if (srcType.getRank() > 1) { 302 if (srcType.isa<RankedTensorType>()) { 303 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 304 srcType.getElementType()); 305 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 306 } else { 307 // TODO: add support for large vectors. 308 return failure(); 309 } 310 } 311 312 Type srcElemType = srcType.getElementType(); 313 Type dstElemType; 314 // Tensor types are converted to SPIR-V array types; vector types are 315 // converted to SPIR-V vector/array types. 316 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 317 dstElemType = arrayType.getElementType(); 318 else 319 dstElemType = dstType.cast<VectorType>().getElementType(); 320 321 // If the source and destination element types are different, perform 322 // attribute conversion. 323 if (srcElemType != dstElemType) { 324 SmallVector<Attribute, 8> elements; 325 if (srcElemType.isa<FloatType>()) { 326 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 327 FloatAttr dstAttr = 328 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 329 if (!dstAttr) 330 return failure(); 331 elements.push_back(dstAttr); 332 } 333 } else if (srcElemType.isInteger(1)) { 334 return failure(); 335 } else { 336 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 337 IntegerAttr dstAttr = convertIntegerAttr( 338 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 339 if (!dstAttr) 340 return failure(); 341 elements.push_back(dstAttr); 342 } 343 } 344 345 // Unfortunately, we cannot use dialect-specific types for element 346 // attributes; element attributes only works with builtin types. So we need 347 // to prepare another converted builtin types for the destination elements 348 // attribute. 349 if (dstAttrType.isa<RankedTensorType>()) 350 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 351 else 352 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 353 354 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 355 } 356 357 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 358 dstElementsAttr); 359 return success(); 360 } 361 362 //===----------------------------------------------------------------------===// 363 // ConstantOp with scalar type 364 //===----------------------------------------------------------------------===// 365 366 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 367 arith::ConstantOp constOp, OpAdaptor adaptor, 368 ConversionPatternRewriter &rewriter) const { 369 Type srcType = constOp.getType(); 370 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 371 if (shapedType.getNumElements() != 1) 372 return failure(); 373 srcType = shapedType.getElementType(); 374 } 375 if (!srcType.isIntOrIndexOrFloat()) 376 return failure(); 377 378 Attribute cstAttr = constOp.getValue(); 379 if (cstAttr.getType().isa<ShapedType>()) 380 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 381 382 Type dstType = getTypeConverter()->convertType(srcType); 383 if (!dstType) 384 return failure(); 385 386 // Floating-point types. 387 if (srcType.isa<FloatType>()) { 388 auto srcAttr = cstAttr.cast<FloatAttr>(); 389 auto dstAttr = srcAttr; 390 391 // Floating-point types not supported in the target environment are all 392 // converted to float type. 393 if (srcType != dstType) { 394 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 395 if (!dstAttr) 396 return failure(); 397 } 398 399 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 400 return success(); 401 } 402 403 // Bool type. 404 if (srcType.isInteger(1)) { 405 // arith.constant can use 0/1 instead of true/false for i1 values. We need 406 // to handle that here. 407 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 408 if (!dstAttr) 409 return failure(); 410 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 411 return success(); 412 } 413 414 // IndexType or IntegerType. Index values are converted to 32-bit integer 415 // values when converting to SPIR-V. 416 auto srcAttr = cstAttr.cast<IntegerAttr>(); 417 auto dstAttr = 418 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 419 if (!dstAttr) 420 return failure(); 421 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 422 return success(); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // RemSIOpGLSLPattern 427 //===----------------------------------------------------------------------===// 428 429 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 430 /// the sign of `signOperand`. 431 /// 432 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 433 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 434 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 435 /// if either operand can be negative. Emulate it via spv.UMod. 436 template <typename SignedAbsOp> 437 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 438 Value signOperand, OpBuilder &builder) { 439 assert(lhs.getType() == rhs.getType()); 440 assert(lhs == signOperand || rhs == signOperand); 441 442 Type type = lhs.getType(); 443 444 // Calculate the remainder with spv.UMod. 445 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 446 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 447 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 448 449 // Fix the sign. 450 Value isPositive; 451 if (lhs == signOperand) 452 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 453 else 454 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 455 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 456 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 457 } 458 459 LogicalResult 460 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 461 ConversionPatternRewriter &rewriter) const { 462 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 463 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 464 adaptor.getOperands()[0], rewriter); 465 rewriter.replaceOp(op, result); 466 467 return success(); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // RemSIOpOCLPattern 472 //===----------------------------------------------------------------------===// 473 474 LogicalResult 475 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 476 ConversionPatternRewriter &rewriter) const { 477 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 478 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 479 adaptor.getOperands()[0], rewriter); 480 rewriter.replaceOp(op, result); 481 482 return success(); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // BitwiseOpPattern 487 //===----------------------------------------------------------------------===// 488 489 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 490 LogicalResult 491 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 492 Op op, typename Op::Adaptor adaptor, 493 ConversionPatternRewriter &rewriter) const { 494 assert(adaptor.getOperands().size() == 2); 495 auto dstType = 496 this->getTypeConverter()->convertType(op.getResult().getType()); 497 if (!dstType) 498 return failure(); 499 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 500 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 501 adaptor.getOperands()); 502 } else { 503 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 504 adaptor.getOperands()); 505 } 506 return success(); 507 } 508 509 //===----------------------------------------------------------------------===// 510 // XOrIOpLogicalPattern 511 //===----------------------------------------------------------------------===// 512 513 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 514 arith::XOrIOp op, OpAdaptor adaptor, 515 ConversionPatternRewriter &rewriter) const { 516 assert(adaptor.getOperands().size() == 2); 517 518 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 519 return failure(); 520 521 auto dstType = getTypeConverter()->convertType(op.getType()); 522 if (!dstType) 523 return failure(); 524 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 525 adaptor.getOperands()); 526 527 return success(); 528 } 529 530 //===----------------------------------------------------------------------===// 531 // XOrIOpBooleanPattern 532 //===----------------------------------------------------------------------===// 533 534 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 535 arith::XOrIOp op, OpAdaptor adaptor, 536 ConversionPatternRewriter &rewriter) const { 537 assert(adaptor.getOperands().size() == 2); 538 539 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 540 return failure(); 541 542 auto dstType = getTypeConverter()->convertType(op.getType()); 543 if (!dstType) 544 return failure(); 545 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 546 adaptor.getOperands()); 547 return success(); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // UIToFPI1Pattern 552 //===----------------------------------------------------------------------===// 553 554 LogicalResult 555 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 556 ConversionPatternRewriter &rewriter) const { 557 auto srcType = adaptor.getOperands().front().getType(); 558 if (!isBoolScalarOrVector(srcType)) 559 return failure(); 560 561 auto dstType = 562 this->getTypeConverter()->convertType(op.getResult().getType()); 563 Location loc = op.getLoc(); 564 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 565 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 566 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 567 op, dstType, adaptor.getOperands().front(), one, zero); 568 return success(); 569 } 570 571 //===----------------------------------------------------------------------===// 572 // ExtUII1Pattern 573 //===----------------------------------------------------------------------===// 574 575 LogicalResult 576 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 577 ConversionPatternRewriter &rewriter) const { 578 auto srcType = adaptor.getOperands().front().getType(); 579 if (!isBoolScalarOrVector(srcType)) 580 return failure(); 581 582 auto dstType = 583 this->getTypeConverter()->convertType(op.getResult().getType()); 584 Location loc = op.getLoc(); 585 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 586 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 587 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 588 op, dstType, adaptor.getOperands().front(), one, zero); 589 return success(); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // TruncII1Pattern 594 //===----------------------------------------------------------------------===// 595 596 LogicalResult 597 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 598 ConversionPatternRewriter &rewriter) const { 599 auto dstType = 600 this->getTypeConverter()->convertType(op.getResult().getType()); 601 if (!isBoolScalarOrVector(dstType)) 602 return failure(); 603 604 Location loc = op.getLoc(); 605 auto srcType = adaptor.getOperands().front().getType(); 606 // Check if (x & 1) == 1. 607 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 608 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 609 loc, srcType, adaptor.getOperands()[0], mask); 610 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 611 612 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 613 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 614 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 615 return success(); 616 } 617 618 //===----------------------------------------------------------------------===// 619 // TypeCastingOpPattern 620 //===----------------------------------------------------------------------===// 621 622 template <typename Op, typename SPIRVOp> 623 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 624 Op op, typename Op::Adaptor adaptor, 625 ConversionPatternRewriter &rewriter) const { 626 assert(adaptor.getOperands().size() == 1); 627 auto srcType = adaptor.getOperands().front().getType(); 628 auto dstType = 629 this->getTypeConverter()->convertType(op.getResult().getType()); 630 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 631 return failure(); 632 if (dstType == srcType) { 633 // Due to type conversion, we are seeing the same source and target type. 634 // Then we can just erase this operation by forwarding its operand. 635 rewriter.replaceOp(op, adaptor.getOperands().front()); 636 } else { 637 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 638 adaptor.getOperands()); 639 } 640 return success(); 641 } 642 643 //===----------------------------------------------------------------------===// 644 // CmpIOpBooleanPattern 645 //===----------------------------------------------------------------------===// 646 647 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 648 arith::CmpIOp op, OpAdaptor adaptor, 649 ConversionPatternRewriter &rewriter) const { 650 Type operandType = op.getLhs().getType(); 651 if (!isBoolScalarOrVector(operandType)) 652 return failure(); 653 654 switch (op.getPredicate()) { 655 #define DISPATCH(cmpPredicate, spirvOp) \ 656 case cmpPredicate: \ 657 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 658 adaptor.getLhs(), adaptor.getRhs()); \ 659 return success(); 660 661 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 662 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 663 664 #undef DISPATCH 665 default:; 666 } 667 return failure(); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // CmpIOpPattern 672 //===----------------------------------------------------------------------===// 673 674 LogicalResult 675 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 676 ConversionPatternRewriter &rewriter) const { 677 Type operandType = op.getLhs().getType(); 678 if (isBoolScalarOrVector(operandType)) 679 return failure(); 680 681 switch (op.getPredicate()) { 682 #define DISPATCH(cmpPredicate, spirvOp) \ 683 case cmpPredicate: \ 684 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 685 operandType != this->getTypeConverter()->convertType(operandType)) { \ 686 return op.emitError( \ 687 "bitwidth emulation is not implemented yet on unsigned op"); \ 688 } \ 689 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 690 adaptor.getLhs(), adaptor.getRhs()); \ 691 return success(); 692 693 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 694 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 695 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 696 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 697 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 698 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 699 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 700 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 701 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 702 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 703 704 #undef DISPATCH 705 } 706 return failure(); 707 } 708 709 //===----------------------------------------------------------------------===// 710 // CmpFOpPattern 711 //===----------------------------------------------------------------------===// 712 713 LogicalResult 714 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 715 ConversionPatternRewriter &rewriter) const { 716 switch (op.getPredicate()) { 717 #define DISPATCH(cmpPredicate, spirvOp) \ 718 case cmpPredicate: \ 719 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 720 adaptor.getLhs(), adaptor.getRhs()); \ 721 return success(); 722 723 // Ordered. 724 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 725 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 726 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 727 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 728 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 729 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 730 // Unordered. 731 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 732 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 733 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 734 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 735 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 736 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 737 738 #undef DISPATCH 739 740 default: 741 break; 742 } 743 return failure(); 744 } 745 746 //===----------------------------------------------------------------------===// 747 // CmpFOpNanKernelPattern 748 //===----------------------------------------------------------------------===// 749 750 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 751 arith::CmpFOp op, OpAdaptor adaptor, 752 ConversionPatternRewriter &rewriter) const { 753 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 754 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 755 adaptor.getRhs()); 756 return success(); 757 } 758 759 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 760 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 761 adaptor.getRhs()); 762 return success(); 763 } 764 765 return failure(); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // CmpFOpNanNonePattern 770 //===----------------------------------------------------------------------===// 771 772 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 773 arith::CmpFOp op, OpAdaptor adaptor, 774 ConversionPatternRewriter &rewriter) const { 775 if (op.getPredicate() != arith::CmpFPredicate::ORD && 776 op.getPredicate() != arith::CmpFPredicate::UNO) 777 return failure(); 778 779 Location loc = op.getLoc(); 780 781 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 782 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 783 784 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 785 if (op.getPredicate() == arith::CmpFPredicate::ORD) 786 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 787 788 rewriter.replaceOp(op, replace); 789 return success(); 790 } 791 792 //===----------------------------------------------------------------------===// 793 // SelectOpPattern 794 //===----------------------------------------------------------------------===// 795 796 LogicalResult 797 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 798 ConversionPatternRewriter &rewriter) const { 799 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), 800 adaptor.getTrueValue(), 801 adaptor.getFalseValue()); 802 return success(); 803 } 804 805 //===----------------------------------------------------------------------===// 806 // Pattern Population 807 //===----------------------------------------------------------------------===// 808 809 void mlir::arith::populateArithmeticToSPIRVPatterns( 810 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 811 // clang-format off 812 patterns.add< 813 ConstantCompositeOpPattern, 814 ConstantScalarOpPattern, 815 spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>, 816 spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>, 817 spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>, 818 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, 819 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, 820 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, 821 RemSIOpGLSLPattern, RemSIOpOCLPattern, 822 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 823 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 824 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 825 spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 826 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 827 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 828 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, 829 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, 830 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, 831 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, 832 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, 833 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, 834 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 835 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 836 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 837 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 838 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 839 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 840 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 841 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 842 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 843 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 844 CmpIOpBooleanPattern, CmpIOpPattern, 845 CmpFOpNanNonePattern, CmpFOpPattern, 846 SelectOpPattern 847 >(typeConverter, patterns.getContext()); 848 // clang-format on 849 850 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 851 // capability is available. 852 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 853 /*benefit=*/2); 854 } 855 856 //===----------------------------------------------------------------------===// 857 // Pass Definition 858 //===----------------------------------------------------------------------===// 859 860 namespace { 861 struct ConvertArithmeticToSPIRVPass 862 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 863 void runOnOperation() override { 864 auto module = getOperation()->getParentOfType<ModuleOp>(); 865 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 866 auto target = SPIRVConversionTarget::get(targetAttr); 867 868 SPIRVTypeConverter::Options options; 869 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 870 SPIRVTypeConverter typeConverter(targetAttr, options); 871 872 RewritePatternSet patterns(&getContext()); 873 mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 874 875 if (failed(applyPartialConversion(getOperation(), *target, 876 std::move(patterns)))) 877 signalPassFailure(); 878 } 879 }; 880 } // namespace 881 882 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 883 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 884 } 885