1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "arith-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Operation Conversion 25 //===----------------------------------------------------------------------===// 26 27 namespace { 28 29 /// Converts composite arith.constant operation to spv.Constant. 30 struct ConstantCompositeOpPattern final 31 : public OpConversionPattern<arith::ConstantOp> { 32 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 33 34 LogicalResult 35 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 36 ConversionPatternRewriter &rewriter) const override; 37 }; 38 39 /// Converts scalar arith.constant operation to spv.Constant. 40 struct ConstantScalarOpPattern final 41 : public OpConversionPattern<arith::ConstantOp> { 42 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 43 44 LogicalResult 45 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 46 ConversionPatternRewriter &rewriter) const override; 47 }; 48 49 /// Converts arith.remsi to GLSL SPIR-V ops. 50 /// 51 /// This cannot be merged into the template unary/binary pattern due to Vulkan 52 /// restrictions over spv.SRem and spv.SMod. 53 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 54 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 58 ConversionPatternRewriter &rewriter) const override; 59 }; 60 61 /// Converts arith.remsi to OpenCL SPIR-V ops. 62 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 63 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 64 65 LogicalResult 66 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override; 68 }; 69 70 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 71 /// other than the BinaryOpPatternPattern because if the operands are boolean 72 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 73 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 74 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 75 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 76 using OpConversionPattern<Op>::OpConversionPattern; 77 78 LogicalResult 79 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override; 81 }; 82 83 /// Converts arith.xori to SPIR-V operations. 84 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 85 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 86 87 LogicalResult 88 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 89 ConversionPatternRewriter &rewriter) const override; 90 }; 91 92 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 93 /// vector of i1. 94 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 95 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 96 97 LogicalResult 98 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 103 /// i1. 104 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 105 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override; 110 }; 111 112 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 113 /// i1. 114 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 115 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 116 117 LogicalResult 118 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const override; 120 }; 121 122 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 123 /// i1. 124 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 125 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override; 130 }; 131 132 /// Converts type-casting standard operations to SPIR-V operations. 133 template <typename Op, typename SPIRVOp> 134 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 135 using OpConversionPattern<Op>::OpConversionPattern; 136 137 LogicalResult 138 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 139 ConversionPatternRewriter &rewriter) const override; 140 }; 141 142 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 143 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 144 public: 145 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 146 147 LogicalResult 148 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 149 ConversionPatternRewriter &rewriter) const override; 150 }; 151 152 /// Converts integer compare operation to SPIR-V ops. 153 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 154 public: 155 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override; 160 }; 161 162 /// Converts floating-point comparison operations to SPIR-V ops. 163 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 164 public: 165 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 166 167 LogicalResult 168 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 169 ConversionPatternRewriter &rewriter) const override; 170 }; 171 172 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 173 /// Kernel capability. 174 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 175 public: 176 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 177 178 LogicalResult 179 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 180 ConversionPatternRewriter &rewriter) const override; 181 }; 182 183 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 184 /// require additional capability. 185 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 186 public: 187 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 188 189 LogicalResult 190 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 191 ConversionPatternRewriter &rewriter) const override; 192 }; 193 194 /// Converts arith.select to spv.Select. 195 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 196 public: 197 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 198 LogicalResult 199 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const override; 201 }; 202 203 } // namespace 204 205 //===----------------------------------------------------------------------===// 206 // Conversion Helpers 207 //===----------------------------------------------------------------------===// 208 209 /// Converts the given `srcAttr` into a boolean attribute if it holds an 210 /// integral value. Returns null attribute if conversion fails. 211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 212 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 213 return boolAttr; 214 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 215 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 216 return BoolAttr(); 217 } 218 219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 220 /// Returns null attribute if conversion fails. 221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 222 Builder builder) { 223 // If the source number uses less active bits than the target bitwidth, then 224 // it should be safe to convert. 225 if (srcAttr.getValue().isIntN(dstType.getWidth())) 226 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 227 228 // XXX: Try again by interpreting the source number as a signed value. 229 // Although integers in the standard dialect are signless, they can represent 230 // a signed number. It's the operation decides how to interpret. This is 231 // dangerous, but it seems there is no good way of handling this if we still 232 // want to change the bitwidth. Emit a message at least. 233 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 234 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 235 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 236 << dstAttr << "' for type '" << dstType << "'\n"); 237 return dstAttr; 238 } 239 240 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 241 << "' illegal: cannot fit into target type '" 242 << dstType << "'\n"); 243 return IntegerAttr(); 244 } 245 246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 249 Builder builder) { 250 // Only support converting to float for now. 251 if (!dstType.isF32()) 252 return FloatAttr(); 253 254 // Try to convert the source floating-point number to single precision. 255 APFloat dstVal = srcAttr.getValue(); 256 bool losesInfo = false; 257 APFloat::opStatus status = 258 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 259 if (status != APFloat::opOK || losesInfo) { 260 LLVM_DEBUG(llvm::dbgs() 261 << srcAttr << " illegal: cannot fit into converted type '" 262 << dstType << "'\n"); 263 return FloatAttr(); 264 } 265 266 return builder.getF32FloatAttr(dstVal.convertToFloat()); 267 } 268 269 /// Returns true if the given `type` is a boolean scalar or vector type. 270 static bool isBoolScalarOrVector(Type type) { 271 if (type.isInteger(1)) 272 return true; 273 if (auto vecType = type.dyn_cast<VectorType>()) 274 return vecType.getElementType().isInteger(1); 275 return false; 276 } 277 278 //===----------------------------------------------------------------------===// 279 // ConstantOp with composite type 280 //===----------------------------------------------------------------------===// 281 282 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 283 arith::ConstantOp constOp, OpAdaptor adaptor, 284 ConversionPatternRewriter &rewriter) const { 285 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 286 if (!srcType || srcType.getNumElements() == 1) 287 return failure(); 288 289 // arith.constant should only have vector or tenor types. 290 assert((srcType.isa<VectorType, RankedTensorType>())); 291 292 auto dstType = getTypeConverter()->convertType(srcType); 293 if (!dstType) 294 return failure(); 295 296 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 297 if (!dstElementsAttr) 298 return failure(); 299 300 ShapedType dstAttrType = dstElementsAttr.getType(); 301 302 // If the composite type has more than one dimensions, perform linearization. 303 if (srcType.getRank() > 1) { 304 if (srcType.isa<RankedTensorType>()) { 305 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 306 srcType.getElementType()); 307 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 308 } else { 309 // TODO: add support for large vectors. 310 return failure(); 311 } 312 } 313 314 Type srcElemType = srcType.getElementType(); 315 Type dstElemType; 316 // Tensor types are converted to SPIR-V array types; vector types are 317 // converted to SPIR-V vector/array types. 318 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 319 dstElemType = arrayType.getElementType(); 320 else 321 dstElemType = dstType.cast<VectorType>().getElementType(); 322 323 // If the source and destination element types are different, perform 324 // attribute conversion. 325 if (srcElemType != dstElemType) { 326 SmallVector<Attribute, 8> elements; 327 if (srcElemType.isa<FloatType>()) { 328 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 329 FloatAttr dstAttr = 330 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 331 if (!dstAttr) 332 return failure(); 333 elements.push_back(dstAttr); 334 } 335 } else if (srcElemType.isInteger(1)) { 336 return failure(); 337 } else { 338 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 339 IntegerAttr dstAttr = convertIntegerAttr( 340 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 341 if (!dstAttr) 342 return failure(); 343 elements.push_back(dstAttr); 344 } 345 } 346 347 // Unfortunately, we cannot use dialect-specific types for element 348 // attributes; element attributes only works with builtin types. So we need 349 // to prepare another converted builtin types for the destination elements 350 // attribute. 351 if (dstAttrType.isa<RankedTensorType>()) 352 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 353 else 354 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 355 356 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 357 } 358 359 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 360 dstElementsAttr); 361 return success(); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // ConstantOp with scalar type 366 //===----------------------------------------------------------------------===// 367 368 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 369 arith::ConstantOp constOp, OpAdaptor adaptor, 370 ConversionPatternRewriter &rewriter) const { 371 Type srcType = constOp.getType(); 372 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 373 if (shapedType.getNumElements() != 1) 374 return failure(); 375 srcType = shapedType.getElementType(); 376 } 377 if (!srcType.isIntOrIndexOrFloat()) 378 return failure(); 379 380 Attribute cstAttr = constOp.getValue(); 381 if (cstAttr.getType().isa<ShapedType>()) 382 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 383 384 Type dstType = getTypeConverter()->convertType(srcType); 385 if (!dstType) 386 return failure(); 387 388 // Floating-point types. 389 if (srcType.isa<FloatType>()) { 390 auto srcAttr = cstAttr.cast<FloatAttr>(); 391 auto dstAttr = srcAttr; 392 393 // Floating-point types not supported in the target environment are all 394 // converted to float type. 395 if (srcType != dstType) { 396 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 397 if (!dstAttr) 398 return failure(); 399 } 400 401 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 402 return success(); 403 } 404 405 // Bool type. 406 if (srcType.isInteger(1)) { 407 // arith.constant can use 0/1 instead of true/false for i1 values. We need 408 // to handle that here. 409 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 410 if (!dstAttr) 411 return failure(); 412 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 413 return success(); 414 } 415 416 // IndexType or IntegerType. Index values are converted to 32-bit integer 417 // values when converting to SPIR-V. 418 auto srcAttr = cstAttr.cast<IntegerAttr>(); 419 auto dstAttr = 420 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 421 if (!dstAttr) 422 return failure(); 423 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 424 return success(); 425 } 426 427 //===----------------------------------------------------------------------===// 428 // RemSIOpGLSLPattern 429 //===----------------------------------------------------------------------===// 430 431 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 432 /// the sign of `signOperand`. 433 /// 434 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 435 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 436 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 437 /// if either operand can be negative. Emulate it via spv.UMod. 438 template <typename SignedAbsOp> 439 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 440 Value signOperand, OpBuilder &builder) { 441 assert(lhs.getType() == rhs.getType()); 442 assert(lhs == signOperand || rhs == signOperand); 443 444 Type type = lhs.getType(); 445 446 // Calculate the remainder with spv.UMod. 447 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 448 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 449 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 450 451 // Fix the sign. 452 Value isPositive; 453 if (lhs == signOperand) 454 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 455 else 456 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 457 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 458 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 459 } 460 461 LogicalResult 462 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 463 ConversionPatternRewriter &rewriter) const { 464 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 465 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 466 adaptor.getOperands()[0], rewriter); 467 rewriter.replaceOp(op, result); 468 469 return success(); 470 } 471 472 //===----------------------------------------------------------------------===// 473 // RemSIOpOCLPattern 474 //===----------------------------------------------------------------------===// 475 476 LogicalResult 477 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 478 ConversionPatternRewriter &rewriter) const { 479 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 480 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 481 adaptor.getOperands()[0], rewriter); 482 rewriter.replaceOp(op, result); 483 484 return success(); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // BitwiseOpPattern 489 //===----------------------------------------------------------------------===// 490 491 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 492 LogicalResult 493 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 494 Op op, typename Op::Adaptor adaptor, 495 ConversionPatternRewriter &rewriter) const { 496 assert(adaptor.getOperands().size() == 2); 497 auto dstType = 498 this->getTypeConverter()->convertType(op.getResult().getType()); 499 if (!dstType) 500 return failure(); 501 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 502 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 503 adaptor.getOperands()); 504 } else { 505 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 506 adaptor.getOperands()); 507 } 508 return success(); 509 } 510 511 //===----------------------------------------------------------------------===// 512 // XOrIOpLogicalPattern 513 //===----------------------------------------------------------------------===// 514 515 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 516 arith::XOrIOp op, OpAdaptor adaptor, 517 ConversionPatternRewriter &rewriter) const { 518 assert(adaptor.getOperands().size() == 2); 519 520 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 521 return failure(); 522 523 auto dstType = getTypeConverter()->convertType(op.getType()); 524 if (!dstType) 525 return failure(); 526 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 527 adaptor.getOperands()); 528 529 return success(); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // XOrIOpBooleanPattern 534 //===----------------------------------------------------------------------===// 535 536 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 537 arith::XOrIOp op, OpAdaptor adaptor, 538 ConversionPatternRewriter &rewriter) const { 539 assert(adaptor.getOperands().size() == 2); 540 541 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 542 return failure(); 543 544 auto dstType = getTypeConverter()->convertType(op.getType()); 545 if (!dstType) 546 return failure(); 547 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 548 adaptor.getOperands()); 549 return success(); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // UIToFPI1Pattern 554 //===----------------------------------------------------------------------===// 555 556 LogicalResult 557 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 558 ConversionPatternRewriter &rewriter) const { 559 auto srcType = adaptor.getOperands().front().getType(); 560 if (!isBoolScalarOrVector(srcType)) 561 return failure(); 562 563 auto dstType = 564 this->getTypeConverter()->convertType(op.getResult().getType()); 565 Location loc = op.getLoc(); 566 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 567 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 568 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 569 op, dstType, adaptor.getOperands().front(), one, zero); 570 return success(); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // ExtUII1Pattern 575 //===----------------------------------------------------------------------===// 576 577 LogicalResult 578 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 579 ConversionPatternRewriter &rewriter) const { 580 auto srcType = adaptor.getOperands().front().getType(); 581 if (!isBoolScalarOrVector(srcType)) 582 return failure(); 583 584 auto dstType = 585 this->getTypeConverter()->convertType(op.getResult().getType()); 586 Location loc = op.getLoc(); 587 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 588 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 589 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 590 op, dstType, adaptor.getOperands().front(), one, zero); 591 return success(); 592 } 593 594 //===----------------------------------------------------------------------===// 595 // TruncII1Pattern 596 //===----------------------------------------------------------------------===// 597 598 LogicalResult 599 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 600 ConversionPatternRewriter &rewriter) const { 601 auto dstType = 602 this->getTypeConverter()->convertType(op.getResult().getType()); 603 if (!isBoolScalarOrVector(dstType)) 604 return failure(); 605 606 Location loc = op.getLoc(); 607 auto srcType = adaptor.getOperands().front().getType(); 608 // Check if (x & 1) == 1. 609 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 610 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 611 loc, srcType, adaptor.getOperands()[0], mask); 612 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 613 614 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 615 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 616 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 617 return success(); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // TypeCastingOpPattern 622 //===----------------------------------------------------------------------===// 623 624 template <typename Op, typename SPIRVOp> 625 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 626 Op op, typename Op::Adaptor adaptor, 627 ConversionPatternRewriter &rewriter) const { 628 assert(adaptor.getOperands().size() == 1); 629 auto srcType = adaptor.getOperands().front().getType(); 630 auto dstType = 631 this->getTypeConverter()->convertType(op.getResult().getType()); 632 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 633 return failure(); 634 if (dstType == srcType) { 635 // Due to type conversion, we are seeing the same source and target type. 636 // Then we can just erase this operation by forwarding its operand. 637 rewriter.replaceOp(op, adaptor.getOperands().front()); 638 } else { 639 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 640 adaptor.getOperands()); 641 } 642 return success(); 643 } 644 645 //===----------------------------------------------------------------------===// 646 // CmpIOpBooleanPattern 647 //===----------------------------------------------------------------------===// 648 649 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 650 arith::CmpIOp op, OpAdaptor adaptor, 651 ConversionPatternRewriter &rewriter) const { 652 Type operandType = op.getLhs().getType(); 653 if (!isBoolScalarOrVector(operandType)) 654 return failure(); 655 656 switch (op.getPredicate()) { 657 #define DISPATCH(cmpPredicate, spirvOp) \ 658 case cmpPredicate: \ 659 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 660 adaptor.getLhs(), adaptor.getRhs()); \ 661 return success(); 662 663 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 664 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 665 666 #undef DISPATCH 667 default:; 668 } 669 return failure(); 670 } 671 672 //===----------------------------------------------------------------------===// 673 // CmpIOpPattern 674 //===----------------------------------------------------------------------===// 675 676 LogicalResult 677 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 678 ConversionPatternRewriter &rewriter) const { 679 Type operandType = op.getLhs().getType(); 680 if (isBoolScalarOrVector(operandType)) 681 return failure(); 682 683 switch (op.getPredicate()) { 684 #define DISPATCH(cmpPredicate, spirvOp) \ 685 case cmpPredicate: \ 686 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 687 operandType != this->getTypeConverter()->convertType(operandType)) { \ 688 return op.emitError( \ 689 "bitwidth emulation is not implemented yet on unsigned op"); \ 690 } \ 691 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 692 adaptor.getLhs(), adaptor.getRhs()); \ 693 return success(); 694 695 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 696 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 697 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 698 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 699 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 700 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 701 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 702 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 703 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 704 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 705 706 #undef DISPATCH 707 } 708 return failure(); 709 } 710 711 //===----------------------------------------------------------------------===// 712 // CmpFOpPattern 713 //===----------------------------------------------------------------------===// 714 715 LogicalResult 716 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 717 ConversionPatternRewriter &rewriter) const { 718 switch (op.getPredicate()) { 719 #define DISPATCH(cmpPredicate, spirvOp) \ 720 case cmpPredicate: \ 721 rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ 722 adaptor.getLhs(), adaptor.getRhs()); \ 723 return success(); 724 725 // Ordered. 726 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 727 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 728 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 729 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 730 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 731 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 732 // Unordered. 733 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); 734 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); 735 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); 736 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); 737 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); 738 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); 739 740 #undef DISPATCH 741 742 default: 743 break; 744 } 745 return failure(); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // CmpFOpNanKernelPattern 750 //===----------------------------------------------------------------------===// 751 752 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( 753 arith::CmpFOp op, OpAdaptor adaptor, 754 ConversionPatternRewriter &rewriter) const { 755 if (op.getPredicate() == arith::CmpFPredicate::ORD) { 756 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), 757 adaptor.getRhs()); 758 return success(); 759 } 760 761 if (op.getPredicate() == arith::CmpFPredicate::UNO) { 762 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), 763 adaptor.getRhs()); 764 return success(); 765 } 766 767 return failure(); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // CmpFOpNanNonePattern 772 //===----------------------------------------------------------------------===// 773 774 LogicalResult CmpFOpNanNonePattern::matchAndRewrite( 775 arith::CmpFOp op, OpAdaptor adaptor, 776 ConversionPatternRewriter &rewriter) const { 777 if (op.getPredicate() != arith::CmpFPredicate::ORD && 778 op.getPredicate() != arith::CmpFPredicate::UNO) 779 return failure(); 780 781 Location loc = op.getLoc(); 782 783 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); 784 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); 785 786 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); 787 if (op.getPredicate() == arith::CmpFPredicate::ORD) 788 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); 789 790 rewriter.replaceOp(op, replace); 791 return success(); 792 } 793 794 //===----------------------------------------------------------------------===// 795 // SelectOpPattern 796 //===----------------------------------------------------------------------===// 797 798 LogicalResult 799 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 800 ConversionPatternRewriter &rewriter) const { 801 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), 802 adaptor.getTrueValue(), 803 adaptor.getFalseValue()); 804 return success(); 805 } 806 807 //===----------------------------------------------------------------------===// 808 // Pattern Population 809 //===----------------------------------------------------------------------===// 810 811 void mlir::arith::populateArithmeticToSPIRVPatterns( 812 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 813 // clang-format off 814 patterns.add< 815 ConstantCompositeOpPattern, 816 ConstantScalarOpPattern, 817 spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>, 818 spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>, 819 spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>, 820 spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, 821 spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, 822 spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, 823 RemSIOpGLSLPattern, RemSIOpOCLPattern, 824 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, 825 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, 826 XOrIOpLogicalPattern, XOrIOpBooleanPattern, 827 spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, 828 spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, 829 spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, 830 spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, 831 spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, 832 spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, 833 spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, 834 spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, 835 spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, 836 TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, 837 TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, 838 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, 839 TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, 840 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, 841 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, 842 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, 843 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, 844 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, 845 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, 846 CmpIOpBooleanPattern, CmpIOpPattern, 847 CmpFOpNanNonePattern, CmpFOpPattern, 848 SelectOpPattern, 849 850 spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>, 851 spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>, 852 spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>, 853 spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>, 854 spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>, 855 spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp> 856 >(typeConverter, patterns.getContext()); 857 // clang-format on 858 859 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel 860 // capability is available. 861 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), 862 /*benefit=*/2); 863 } 864 865 //===----------------------------------------------------------------------===// 866 // Pass Definition 867 //===----------------------------------------------------------------------===// 868 869 namespace { 870 struct ConvertArithmeticToSPIRVPass 871 : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { 872 void runOnOperation() override { 873 auto module = getOperation(); 874 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 875 auto target = SPIRVConversionTarget::get(targetAttr); 876 877 SPIRVTypeConverter::Options options; 878 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 879 SPIRVTypeConverter typeConverter(targetAttr, options); 880 881 RewritePatternSet patterns(&getContext()); 882 arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 883 populateFuncToSPIRVPatterns(typeConverter, patterns); 884 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 885 886 if (failed(applyPartialConversion(module, *target, std::move(patterns)))) 887 signalPassFailure(); 888 } 889 }; 890 } // namespace 891 892 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { 893 return std::make_unique<ConvertArithmeticToSPIRVPass>(); 894 } 895