1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "arith-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Operation Conversion 25 //===----------------------------------------------------------------------===// 26 27 namespace { 28 29 /// Converts composite arith.constant operation to spv.Constant. 30 struct ConstantCompositeOpPattern final 31 : public OpConversionPattern<arith::ConstantOp> { 32 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 33 34 LogicalResult 35 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 36 ConversionPatternRewriter &rewriter) const override; 37 }; 38 39 /// Converts scalar arith.constant operation to spv.Constant. 40 struct ConstantScalarOpPattern final 41 : public OpConversionPattern<arith::ConstantOp> { 42 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 43 44 LogicalResult 45 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 46 ConversionPatternRewriter &rewriter) const override; 47 }; 48 49 /// Converts arith.remsi to GLSL SPIR-V ops. 50 /// 51 /// This cannot be merged into the template unary/binary pattern due to Vulkan 52 /// restrictions over spv.SRem and spv.SMod. 53 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 54 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 58 ConversionPatternRewriter &rewriter) const override; 59 }; 60 61 /// Converts arith.remsi to OpenCL SPIR-V ops. 62 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 63 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 64 65 LogicalResult 66 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override; 68 }; 69 70 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 71 /// other than the BinaryOpPatternPattern because if the operands are boolean 72 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 73 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 74 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 75 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 76 using OpConversionPattern<Op>::OpConversionPattern; 77 78 LogicalResult 79 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override; 81 }; 82 83 /// Converts arith.xori to SPIR-V operations. 84 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 85 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 86 87 LogicalResult 88 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 89 ConversionPatternRewriter &rewriter) const override; 90 }; 91 92 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 93 /// vector of i1. 94 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 95 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 96 97 LogicalResult 98 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 103 /// i1. 104 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 105 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override; 110 }; 111 112 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 113 /// i1. 114 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 115 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 116 117 LogicalResult 118 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const override; 120 }; 121 122 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 123 /// i1. 124 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 125 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override; 130 }; 131 132 /// Converts type-casting standard operations to SPIR-V operations. 133 template <typename Op, typename SPIRVOp> 134 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 135 using OpConversionPattern<Op>::OpConversionPattern; 136 137 LogicalResult 138 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 139 ConversionPatternRewriter &rewriter) const override; 140 }; 141 142 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 143 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 144 public: 145 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 146 147 LogicalResult 148 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 149 ConversionPatternRewriter &rewriter) const override; 150 }; 151 152 /// Converts integer compare operation to SPIR-V ops. 153 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 154 public: 155 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override; 160 }; 161 162 /// Converts floating-point comparison operations to SPIR-V ops. 163 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 164 public: 165 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 166 167 LogicalResult 168 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 169 ConversionPatternRewriter &rewriter) const override; 170 }; 171 172 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 173 /// Kernel capability. 174 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 175 public: 176 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 177 178 LogicalResult 179 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 180 ConversionPatternRewriter &rewriter) const override; 181 }; 182 183 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 184 /// require additional capability. 185 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 186 public: 187 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 188 189 LogicalResult 190 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 191 ConversionPatternRewriter &rewriter) const override; 192 }; 193 194 /// Converts arith.select to spv.Select. 195 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 196 public: 197 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 198 LogicalResult 199 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const override; 201 }; 202 203 } // namespace 204 205 //===----------------------------------------------------------------------===// 206 // Conversion Helpers 207 //===----------------------------------------------------------------------===// 208 209 /// Converts the given `srcAttr` into a boolean attribute if it holds an 210 /// integral value. Returns null attribute if conversion fails. 211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 212 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 213 return boolAttr; 214 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 215 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 216 return BoolAttr(); 217 } 218 219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 220 /// Returns null attribute if conversion fails. 221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 222 Builder builder) { 223 // If the source number uses less active bits than the target bitwidth, then 224 // it should be safe to convert. 225 if (srcAttr.getValue().isIntN(dstType.getWidth())) 226 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 227 228 // XXX: Try again by interpreting the source number as a signed value. 229 // Although integers in the standard dialect are signless, they can represent 230 // a signed number. It's the operation decides how to interpret. This is 231 // dangerous, but it seems there is no good way of handling this if we still 232 // want to change the bitwidth. Emit a message at least. 233 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 234 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 235 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 236 << dstAttr << "' for type '" << dstType << "'\n"); 237 return dstAttr; 238 } 239 240 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 241 << "' illegal: cannot fit into target type '" 242 << dstType << "'\n"); 243 return IntegerAttr(); 244 } 245 246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 249 Builder builder) { 250 // Only support converting to float for now. 251 if (!dstType.isF32()) 252 return FloatAttr(); 253 254 // Try to convert the source floating-point number to single precision. 255 APFloat dstVal = srcAttr.getValue(); 256 bool losesInfo = false; 257 APFloat::opStatus status = 258 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 259 if (status != APFloat::opOK || losesInfo) { 260 LLVM_DEBUG(llvm::dbgs() 261 << srcAttr << " illegal: cannot fit into converted type '" 262 << dstType << "'\n"); 263 return FloatAttr(); 264 } 265 266 return builder.getF32FloatAttr(dstVal.convertToFloat()); 267 } 268 269 /// Returns true if the given `type` is a boolean scalar or vector type. 270 static bool isBoolScalarOrVector(Type type) { 271 if (type.isInteger(1)) 272 return true; 273 if (auto vecType = type.dyn_cast<VectorType>()) 274 return vecType.getElementType().isInteger(1); 275 return false; 276 } 277 278 //===----------------------------------------------------------------------===// 279 // ConstantOp with composite type 280 //===----------------------------------------------------------------------===// 281 282 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 283 arith::ConstantOp constOp, OpAdaptor adaptor, 284 ConversionPatternRewriter &rewriter) const { 285 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 286 if (!srcType || srcType.getNumElements() == 1) 287 return failure(); 288 289 // arith.constant should only have vector or tenor types. 290 assert((srcType.isa<VectorType, RankedTensorType>())); 291 292 auto dstType = getTypeConverter()->convertType(srcType); 293 if (!dstType) 294 return failure(); 295 296 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 297 ShapedType dstAttrType = dstElementsAttr.getType(); 298 if (!dstElementsAttr) 299 return failure(); 300 301 // If the composite type has more than one dimensions, perform linearization. 302 if (srcType.getRank() > 1) { 303 if (srcType.isa<RankedTensorType>()) { 304 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 305 srcType.getElementType()); 306 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 307 } else { 308 // TODO: add support for large vectors. 309 return failure(); 310 } 311 } 312 313 Type srcElemType = srcType.getElementType(); 314 Type dstElemType; 315 // Tensor types are converted to SPIR-V array types; vector types are 316 // converted to SPIR-V vector/array types. 317 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 318 dstElemType = arrayType.getElementType(); 319 else 320 dstElemType = dstType.cast<VectorType>().getElementType(); 321 322 // If the source and destination element types are different, perform 323 // attribute conversion. 324 if (srcElemType != dstElemType) { 325 SmallVector<Attribute, 8> elements; 326 if (srcElemType.isa<FloatType>()) { 327 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 328 FloatAttr dstAttr = 329 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 330 if (!dstAttr) 331 return failure(); 332 elements.push_back(dstAttr); 333 } 334 } else if (srcElemType.isInteger(1)) { 335 return failure(); 336 } else { 337 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 338 IntegerAttr dstAttr = convertIntegerAttr( 339 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 340 if (!dstAttr) 341 return failure(); 342 elements.push_back(dstAttr); 343 } 344 } 345 346 // Unfortunately, we cannot use dialect-specific types for element 347 // attributes; element attributes only works with builtin types. So we need 348 // to prepare another converted builtin types for the destination elements 349 // attribute. 350 if (dstAttrType.isa<RankedTensorType>()) 351 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 352 else 353 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 354 355 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 356 } 357 358 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 359 dstElementsAttr); 360 return success(); 361 } 362 363 //===----------------------------------------------------------------------===// 364 // ConstantOp with scalar type 365 //===----------------------------------------------------------------------===// 366 367 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 368 arith::ConstantOp constOp, OpAdaptor adaptor, 369 ConversionPatternRewriter &rewriter) const { 370 Type srcType = constOp.getType(); 371 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 372 if (shapedType.getNumElements() != 1) 373 return failure(); 374 srcType = shapedType.getElementType(); 375 } 376 if (!srcType.isIntOrIndexOrFloat()) 377 return failure(); 378 379 Attribute cstAttr = constOp.getValue(); 380 if (cstAttr.getType().isa<ShapedType>()) 381 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 382 383 Type dstType = getTypeConverter()->convertType(srcType); 384 if (!dstType) 385 return failure(); 386 387 // Floating-point types. 388 if (srcType.isa<FloatType>()) { 389 auto srcAttr = cstAttr.cast<FloatAttr>(); 390 auto dstAttr = srcAttr; 391 392 // Floating-point types not supported in the target environment are all 393 // converted to float type. 394 if (srcType != dstType) { 395 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 396 if (!dstAttr) 397 return failure(); 398 } 399 400 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 401 return success(); 402 } 403 404 // Bool type. 405 if (srcType.isInteger(1)) { 406 // arith.constant can use 0/1 instead of true/false for i1 values. We need 407 // to handle that here. 408 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 409 if (!dstAttr) 410 return failure(); 411 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 412 return success(); 413 } 414 415 // IndexType or IntegerType. Index values are converted to 32-bit integer 416 // values when converting to SPIR-V. 417 auto srcAttr = cstAttr.cast<IntegerAttr>(); 418 auto dstAttr = 419 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 420 if (!dstAttr) 421 return failure(); 422 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 423 return success(); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // RemSIOpGLSLPattern 428 //===----------------------------------------------------------------------===// 429 430 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 431 /// the sign of `signOperand`. 432 /// 433 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 434 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 435 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 436 /// if either operand can be negative. Emulate it via spv.UMod. 437 template <typename SignedAbsOp> 438 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 439 Value signOperand, OpBuilder &builder) { 440 assert(lhs.getType() == rhs.getType()); 441 assert(lhs == signOperand || rhs == signOperand); 442 443 Type type = lhs.getType(); 444 445 // Calculate the remainder with spv.UMod. 446 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 447 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 448 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 449 450 // Fix the sign. 451 Value isPositive; 452 if (lhs == signOperand) 453 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 454 else 455 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 456 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 457 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 458 } 459 460 LogicalResult 461 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 462 ConversionPatternRewriter &rewriter) const { 463 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 464 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 465 adaptor.getOperands()[0], rewriter); 466 rewriter.replaceOp(op, result); 467 468 return success(); 469 } 470 471 //===----------------------------------------------------------------------===// 472 // RemSIOpOCLPattern 473 //===----------------------------------------------------------------------===// 474 475 LogicalResult 476 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 477 ConversionPatternRewriter &rewriter) const { 478 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 479 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 480 adaptor.getOperands()[0], rewriter); 481 rewriter.replaceOp(op, result); 482 483 return success(); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // BitwiseOpPattern 488 //===----------------------------------------------------------------------===// 489 490 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 491 LogicalResult 492 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 493 Op op, typename Op::Adaptor adaptor, 494 ConversionPatternRewriter &rewriter) const { 495 assert(adaptor.getOperands().size() == 2); 496 auto dstType = 497 this->getTypeConverter()->convertType(op.getResult().getType()); 498 if (!dstType) 499 return failure(); 500 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 501 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 502 adaptor.getOperands()); 503 } else { 504 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 505 adaptor.getOperands()); 506 } 507 return success(); 508 } 509 510 //===----------------------------------------------------------------------===// 511 // XOrIOpLogicalPattern 512 //===----------------------------------------------------------------------===// 513 514 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 515 arith::XOrIOp op, OpAdaptor adaptor, 516 ConversionPatternRewriter &rewriter) const { 517 assert(adaptor.getOperands().size() == 2); 518 519 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 520 return failure(); 521 522 auto dstType = getTypeConverter()->convertType(op.getType()); 523 if (!dstType) 524 return failure(); 525 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 526 adaptor.getOperands()); 527 528 return success(); 529 } 530 531 //===----------------------------------------------------------------------===// 532 // XOrIOpBooleanPattern 533 //===----------------------------------------------------------------------===// 534 535 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 536 arith::XOrIOp op, OpAdaptor adaptor, 537 ConversionPatternRewriter &rewriter) const { 538 assert(adaptor.getOperands().size() == 2); 539 540 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 541 return failure(); 542 543 auto dstType = getTypeConverter()->convertType(op.getType()); 544 if (!dstType) 545 return failure(); 546 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 547 adaptor.getOperands()); 548 return success(); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // UIToFPI1Pattern 553 //===----------------------------------------------------------------------===// 554 555 LogicalResult 556 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 557 ConversionPatternRewriter &rewriter) const { 558 auto srcType = adaptor.getOperands().front().getType(); 559 if (!isBoolScalarOrVector(srcType)) 560 return failure(); 561 562 auto dstType = 563 this->getTypeConverter()->convertType(op.getResult().getType()); 564 Location loc = op.getLoc(); 565 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 566 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 567 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 568 op, dstType, adaptor.getOperands().front(), one, zero); 569 return success(); 570 } 571 572 //===----------------------------------------------------------------------===// 573 // ExtUII1Pattern 574 //===----------------------------------------------------------------------===// 575 576 LogicalResult 577 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 578 ConversionPatternRewriter &rewriter) const { 579 auto srcType = adaptor.getOperands().front().getType(); 580 if (!isBoolScalarOrVector(srcType)) 581 return failure(); 582 583 auto dstType = 584 this->getTypeConverter()->convertType(op.getResult().getType()); 585 Location loc = op.getLoc(); 586 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 587 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 588 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 589 op, dstType, adaptor.getOperands().front(), one, zero); 590 return success(); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // TruncII1Pattern 595 //===----------------------------------------------------------------------===// 596 597 LogicalResult 598 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 599 ConversionPatternRewriter &rewriter) const { 600 auto dstType = 601 this->getTypeConverter()->convertType(op.getResult().getType()); 602 if (!isBoolScalarOrVector(dstType)) 603 return failure(); 604 605 Location loc = op.getLoc(); 606 auto srcType = adaptor.getOperands().front().getType(); 607 // Check if (x & 1) == 1. 608 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 609 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 610 loc, srcType, adaptor.getOperands()[0], mask); 611 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 612 613 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 614 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 615 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 616 return success(); 617 } 618 619 //===----------------------------------------------------------------------===// 620 // TypeCastingOpPattern 621 //===----------------------------------------------------------------------===// 622 623 template <typename Op, typename SPIRVOp> 624 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 625 Op op, typename Op::Adaptor adaptor, 626 ConversionPatternRewriter &rewriter) const { 627 assert(adaptor.getOperands().size() == 1); 628 auto srcType = adaptor.getOperands().front().getType(); 629 auto dstType = 630 this->getTypeConverter()->convertType(op.getResult().getType()); 631 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 632 return failure(); 633 if (dstType == srcType) { 634 // Due to type conversion, we are seeing the same source and target type. 635 // Then we can just erase this operation by forwarding its operand. 636 rewriter.replaceOp(op, adaptor.getOperands().front()); 637 } else { 638 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 639 adaptor.getOperands()); 640 } 641 return success(); 642 } 643 644 //===----------------------------------------------------------------------===// 645 // CmpIOpBooleanPattern 646 //===----------------------------------------------------------------------===// 647 648 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 649 arith::CmpIOp op, OpAdaptor adaptor, 650 ConversionPatternRewriter &rewriter) const { 651 Type operandType = op.getLhs().getType(); 652 if (!isBoolScalarOrVector(operandType)) 653 return failure(); 654 655 switch (op.getPredicate()) { 656 #define DISPATCH(cmpPredicate, spirvOp) \ 657 case cmpPredicate: \ 658 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 659 adaptor.getLhs(), adaptor.getRhs()); \ 660 return success(); 661 662 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 663 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 664 665 #undef DISPATCH 666 default:; 667 } 668 return failure(); 669 } 670 671 //===----------------------------------------------------------------------===// 672 // CmpIOpPattern 673 //===----------------------------------------------------------------------===// 674 675 LogicalResult 676 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 677 ConversionPatternRewriter &rewriter) const { 678 Type operandType = op.getLhs().getType(); 679 if (isBoolScalarOrVector(operandType)) 680 return failure(); 681 682 switch (op.getPredicate()) { 683 #define DISPATCH(cmpPredicate, spirvOp) \ 684 case cmpPredicate: \ 685 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 686 operandType != this->getTypeConverter()->convertType(operandType)) { \ 687 return op.emitError( \ 688 "bitwidth emulation is not implemented yet on unsigned op"); \ 689 } \ 690 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 691 adaptor.getLhs(), adaptor.getRhs()); \ 692 return success(); 693 694 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 695 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 696 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 697 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 698 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 699 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 700 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 701 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 702 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 703 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 704 705 #undef DISPATCH 706 } 707 return failure(); 708 } 709 710 //===----------------------------------------------------------------------===// 711 // CmpFOpPattern 712 //===----------------------------------------------------------------------===// 713 714 LogicalResult 715 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 716 ConversionPatternRewriter &rewriter) const { 717 switch (op.getPredicate()) { 718 #define DISPATCH(cmpPredicate, spirvOp) \ 719 case cmpPredicate: \ 720 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 721 adaptor.getLhs(), adaptor.getRhs()); \ 722 return success(); 723 724 // Ordered. 725 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 726 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 727 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 728 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 729 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 730 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 731 // Unordered. 732 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 733 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 734 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 735 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 736 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 737 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 738 739 #undef DISPATCH 740 741 default: 742 break; 743 } 744 return failure(); 745 } 746 747 //===----------------------------------------------------------------------===// 748 // CmpFOpNanKernelPattern 749 //===----------------------------------------------------------------------===// 750 751 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 752 arith::CmpFOp op, OpAdaptor adaptor, 753 ConversionPatternRewriter &rewriter) const { 754 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 755 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 756 adaptor.getRhs()); 757 return success(); 758 } 759 760 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 761 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 762 adaptor.getRhs()); 763 return success(); 764 } 765 766 return failure(); 767 } 768 769 //===----------------------------------------------------------------------===// 770 // CmpFOpNanNonePattern 771 //===----------------------------------------------------------------------===// 772 773 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 774 arith::CmpFOp op, OpAdaptor adaptor, 775 ConversionPatternRewriter &rewriter) const { 776 if (op.getPredicate() != arith::CmpFPredicate::ORD && 777 op.getPredicate() != arith::CmpFPredicate::UNO) 778 return failure(); 779 780 Location loc = op.getLoc(); 781 782 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 783 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 784 785 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 786 if (op.getPredicate() == arith::CmpFPredicate::ORD) 787 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 788 789 rewriter.replaceOp(op, replace); 790 return success(); 791 } 792 793 //===----------------------------------------------------------------------===// 794 // SelectOpPattern 795 //===----------------------------------------------------------------------===// 796 797 LogicalResult 798 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 799 ConversionPatternRewriter &rewriter) const { 800 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), 801 adaptor.getTrueValue(), 802 adaptor.getFalseValue()); 803 return success(); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // Pattern Population 808 //===----------------------------------------------------------------------===// 809 810 void mlir::arith::populateArithmeticToSPIRVPatterns( 811 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 812 // clang-format off 813 patterns.add< 814 ConstantCompositeOpPattern, 815 ConstantScalarOpPattern, 816 spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>, 817 spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>, 818 spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>, 819 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, 820 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, 821 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, 822 RemSIOpGLSLPattern, RemSIOpOCLPattern, 823 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 824 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 825 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 826 spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 827 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 828 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 829 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, 830 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, 831 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, 832 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, 833 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, 834 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, 835 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 836 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 837 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 838 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 839 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 840 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 841 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 842 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 843 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 844 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 845 CmpIOpBooleanPattern, CmpIOpPattern, 846 CmpFOpNanNonePattern, CmpFOpPattern, 847 SelectOpPattern, 848 849 spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>, 850 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>, 851 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>, 852 spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>, 853 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>, 854 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp> 855 >(typeConverter, patterns.getContext()); 856 // clang-format on 857 858 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 859 // capability is available. 860 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 861 /*benefit=*/2); 862 } 863 864 //===----------------------------------------------------------------------===// 865 // Pass Definition 866 //===----------------------------------------------------------------------===// 867 868 namespace { 869 struct ConvertArithmeticToSPIRVPass 870 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 871 void runOnOperation() override { 872 auto module = getOperation(); 873 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 874 auto target = SPIRVConversionTarget::get(targetAttr); 875 876 SPIRVTypeConverter::Options options; 877 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 878 SPIRVTypeConverter typeConverter(targetAttr, options); 879 880 RewritePatternSet patterns(&getContext()); 881 arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 882 populateFuncToSPIRVPatterns(typeConverter, patterns); 883 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 884 885 if (failed(applyPartialConversion(module, *target, std::move(patterns)))) 886 signalPassFailure(); 887 } 888 }; 889 } // namespace 890 891 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 892 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 893 } 894