1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "arith-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Operation Conversion 25 //===----------------------------------------------------------------------===// 26 27 namespace { 28 29 /// Converts composite arith.constant operation to spv.Constant. 30 struct ConstantCompositeOpPattern final 31 : public OpConversionPattern<arith::ConstantOp> { 32 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 33 34 LogicalResult 35 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 36 ConversionPatternRewriter &rewriter) const override; 37 }; 38 39 /// Converts scalar arith.constant operation to spv.Constant. 40 struct ConstantScalarOpPattern final 41 : public OpConversionPattern<arith::ConstantOp> { 42 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 43 44 LogicalResult 45 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 46 ConversionPatternRewriter &rewriter) const override; 47 }; 48 49 /// Converts arith.remsi to GLSL SPIR-V ops. 
50 /// 51 /// This cannot be merged into the template unary/binary pattern due to Vulkan 52 /// restrictions over spv.SRem and spv.SMod. 53 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 54 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 58 ConversionPatternRewriter &rewriter) const override; 59 }; 60 61 /// Converts arith.remsi to OpenCL SPIR-V ops. 62 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 63 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 64 65 LogicalResult 66 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override; 68 }; 69 70 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 71 /// other than the BinaryOpPatternPattern because if the operands are boolean 72 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 73 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 74 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 75 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 76 using OpConversionPattern<Op>::OpConversionPattern; 77 78 LogicalResult 79 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override; 81 }; 82 83 /// Converts arith.xori to SPIR-V operations. 84 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 85 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 86 87 LogicalResult 88 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 89 ConversionPatternRewriter &rewriter) const override; 90 }; 91 92 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 93 /// vector of i1. 
94 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 95 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 96 97 LogicalResult 98 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 103 /// i1. 104 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 105 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 109 ConversionPatternRewriter &rewriter) const override; 110 }; 111 112 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 113 /// i1. 114 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 115 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 116 117 LogicalResult 118 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const override; 120 }; 121 122 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 123 /// i1. 124 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 125 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override; 130 }; 131 132 /// Converts type-casting standard operations to SPIR-V operations. 133 template <typename Op, typename SPIRVOp> 134 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 135 using OpConversionPattern<Op>::OpConversionPattern; 136 137 LogicalResult 138 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 139 ConversionPatternRewriter &rewriter) const override; 140 }; 141 142 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 
143 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 144 public: 145 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 146 147 LogicalResult 148 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 149 ConversionPatternRewriter &rewriter) const override; 150 }; 151 152 /// Converts integer compare operation to SPIR-V ops. 153 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 154 public: 155 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override; 160 }; 161 162 /// Converts floating-point comparison operations to SPIR-V ops. 163 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 164 public: 165 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 166 167 LogicalResult 168 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 169 ConversionPatternRewriter &rewriter) const override; 170 }; 171 172 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 173 /// Kernel capability. 174 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 175 public: 176 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 177 178 LogicalResult 179 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 180 ConversionPatternRewriter &rewriter) const override; 181 }; 182 183 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 184 /// require additional capability. 185 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 186 public: 187 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 188 189 LogicalResult 190 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 191 ConversionPatternRewriter &rewriter) const override; 192 }; 193 194 /// Converts arith.select to spv.Select. 
195 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 196 public: 197 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 198 LogicalResult 199 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const override; 201 }; 202 203 } // namespace 204 205 //===----------------------------------------------------------------------===// 206 // Conversion Helpers 207 //===----------------------------------------------------------------------===// 208 209 /// Converts the given `srcAttr` into a boolean attribute if it holds an 210 /// integral value. Returns null attribute if conversion fails. 211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 212 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 213 return boolAttr; 214 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 215 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 216 return BoolAttr(); 217 } 218 219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 220 /// Returns null attribute if conversion fails. 221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 222 Builder builder) { 223 // If the source number uses less active bits than the target bitwidth, then 224 // it should be safe to convert. 225 if (srcAttr.getValue().isIntN(dstType.getWidth())) 226 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 227 228 // XXX: Try again by interpreting the source number as a signed value. 229 // Although integers in the standard dialect are signless, they can represent 230 // a signed number. It's the operation decides how to interpret. This is 231 // dangerous, but it seems there is no good way of handling this if we still 232 // want to change the bitwidth. Emit a message at least. 
233 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 234 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 235 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 236 << dstAttr << "' for type '" << dstType << "'\n"); 237 return dstAttr; 238 } 239 240 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 241 << "' illegal: cannot fit into target type '" 242 << dstType << "'\n"); 243 return IntegerAttr(); 244 } 245 246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 249 Builder builder) { 250 // Only support converting to float for now. 251 if (!dstType.isF32()) 252 return FloatAttr(); 253 254 // Try to convert the source floating-point number to single precision. 255 APFloat dstVal = srcAttr.getValue(); 256 bool losesInfo = false; 257 APFloat::opStatus status = 258 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 259 if (status != APFloat::opOK || losesInfo) { 260 LLVM_DEBUG(llvm::dbgs() 261 << srcAttr << " illegal: cannot fit into converted type '" 262 << dstType << "'\n"); 263 return FloatAttr(); 264 } 265 266 return builder.getF32FloatAttr(dstVal.convertToFloat()); 267 } 268 269 /// Returns true if the given `type` is a boolean scalar or vector type. 270 static bool isBoolScalarOrVector(Type type) { 271 if (type.isInteger(1)) 272 return true; 273 if (auto vecType = type.dyn_cast<VectorType>()) 274 return vecType.getElementType().isInteger(1); 275 return false; 276 } 277 278 /// Returns true if scalar/vector type `a` and `b` have the same number of 279 /// bitwidth. 
280 static bool hasSameBitwidth(Type a, Type b) { 281 auto getNumBitwidth = [](Type type) { 282 unsigned bw = 0; 283 if (type.isIntOrFloat()) 284 bw = type.getIntOrFloatBitWidth(); 285 else if (auto vecType = type.dyn_cast<VectorType>()) 286 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); 287 return bw; 288 }; 289 unsigned aBW = getNumBitwidth(a); 290 unsigned bBW = getNumBitwidth(b); 291 return aBW != 0 && bBW != 0 && aBW == bBW; 292 } 293 294 //===----------------------------------------------------------------------===// 295 // ConstantOp with composite type 296 //===----------------------------------------------------------------------===// 297 298 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 299 arith::ConstantOp constOp, OpAdaptor adaptor, 300 ConversionPatternRewriter &rewriter) const { 301 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 302 if (!srcType || srcType.getNumElements() == 1) 303 return failure(); 304 305 // arith.constant should only have vector or tenor types. 306 assert((srcType.isa<VectorType, RankedTensorType>())); 307 308 auto dstType = getTypeConverter()->convertType(srcType); 309 if (!dstType) 310 return failure(); 311 312 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 313 if (!dstElementsAttr) 314 return failure(); 315 316 ShapedType dstAttrType = dstElementsAttr.getType(); 317 318 // If the composite type has more than one dimensions, perform linearization. 319 if (srcType.getRank() > 1) { 320 if (srcType.isa<RankedTensorType>()) { 321 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 322 srcType.getElementType()); 323 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 324 } else { 325 // TODO: add support for large vectors. 326 return failure(); 327 } 328 } 329 330 Type srcElemType = srcType.getElementType(); 331 Type dstElemType; 332 // Tensor types are converted to SPIR-V array types; vector types are 333 // converted to SPIR-V vector/array types. 
334 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 335 dstElemType = arrayType.getElementType(); 336 else 337 dstElemType = dstType.cast<VectorType>().getElementType(); 338 339 // If the source and destination element types are different, perform 340 // attribute conversion. 341 if (srcElemType != dstElemType) { 342 SmallVector<Attribute, 8> elements; 343 if (srcElemType.isa<FloatType>()) { 344 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 345 FloatAttr dstAttr = 346 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 347 if (!dstAttr) 348 return failure(); 349 elements.push_back(dstAttr); 350 } 351 } else if (srcElemType.isInteger(1)) { 352 return failure(); 353 } else { 354 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 355 IntegerAttr dstAttr = convertIntegerAttr( 356 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 357 if (!dstAttr) 358 return failure(); 359 elements.push_back(dstAttr); 360 } 361 } 362 363 // Unfortunately, we cannot use dialect-specific types for element 364 // attributes; element attributes only works with builtin types. So we need 365 // to prepare another converted builtin types for the destination elements 366 // attribute. 
367 if (dstAttrType.isa<RankedTensorType>()) 368 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 369 else 370 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 371 372 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 373 } 374 375 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 376 dstElementsAttr); 377 return success(); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // ConstantOp with scalar type 382 //===----------------------------------------------------------------------===// 383 384 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 385 arith::ConstantOp constOp, OpAdaptor adaptor, 386 ConversionPatternRewriter &rewriter) const { 387 Type srcType = constOp.getType(); 388 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 389 if (shapedType.getNumElements() != 1) 390 return failure(); 391 srcType = shapedType.getElementType(); 392 } 393 if (!srcType.isIntOrIndexOrFloat()) 394 return failure(); 395 396 Attribute cstAttr = constOp.getValue(); 397 if (cstAttr.getType().isa<ShapedType>()) 398 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 399 400 Type dstType = getTypeConverter()->convertType(srcType); 401 if (!dstType) 402 return failure(); 403 404 // Floating-point types. 405 if (srcType.isa<FloatType>()) { 406 auto srcAttr = cstAttr.cast<FloatAttr>(); 407 auto dstAttr = srcAttr; 408 409 // Floating-point types not supported in the target environment are all 410 // converted to float type. 411 if (srcType != dstType) { 412 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 413 if (!dstAttr) 414 return failure(); 415 } 416 417 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 418 return success(); 419 } 420 421 // Bool type. 422 if (srcType.isInteger(1)) { 423 // arith.constant can use 0/1 instead of true/false for i1 values. 
We need 424 // to handle that here. 425 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 426 if (!dstAttr) 427 return failure(); 428 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 429 return success(); 430 } 431 432 // IndexType or IntegerType. Index values are converted to 32-bit integer 433 // values when converting to SPIR-V. 434 auto srcAttr = cstAttr.cast<IntegerAttr>(); 435 auto dstAttr = 436 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 437 if (!dstAttr) 438 return failure(); 439 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 440 return success(); 441 } 442 443 //===----------------------------------------------------------------------===// 444 // RemSIOpGLSLPattern 445 //===----------------------------------------------------------------------===// 446 447 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 448 /// the sign of `signOperand`. 449 /// 450 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 451 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 452 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 453 /// if either operand can be negative. Emulate it via spv.UMod. 454 template <typename SignedAbsOp> 455 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 456 Value signOperand, OpBuilder &builder) { 457 assert(lhs.getType() == rhs.getType()); 458 assert(lhs == signOperand || rhs == signOperand); 459 460 Type type = lhs.getType(); 461 462 // Calculate the remainder with spv.UMod. 463 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 464 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 465 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 466 467 // Fix the sign. 
468 Value isPositive; 469 if (lhs == signOperand) 470 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 471 else 472 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 473 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 474 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 475 } 476 477 LogicalResult 478 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 479 ConversionPatternRewriter &rewriter) const { 480 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 481 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 482 adaptor.getOperands()[0], rewriter); 483 rewriter.replaceOp(op, result); 484 485 return success(); 486 } 487 488 //===----------------------------------------------------------------------===// 489 // RemSIOpOCLPattern 490 //===----------------------------------------------------------------------===// 491 492 LogicalResult 493 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 494 ConversionPatternRewriter &rewriter) const { 495 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 496 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 497 adaptor.getOperands()[0], rewriter); 498 rewriter.replaceOp(op, result); 499 500 return success(); 501 } 502 503 //===----------------------------------------------------------------------===// 504 // BitwiseOpPattern 505 //===----------------------------------------------------------------------===// 506 507 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 508 LogicalResult 509 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 510 Op op, typename Op::Adaptor adaptor, 511 ConversionPatternRewriter &rewriter) const { 512 assert(adaptor.getOperands().size() == 2); 513 auto dstType = 514 this->getTypeConverter()->convertType(op.getResult().getType()); 515 if (!dstType) 516 return failure(); 517 if 
(isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 518 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 519 adaptor.getOperands()); 520 } else { 521 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 522 adaptor.getOperands()); 523 } 524 return success(); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // XOrIOpLogicalPattern 529 //===----------------------------------------------------------------------===// 530 531 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 532 arith::XOrIOp op, OpAdaptor adaptor, 533 ConversionPatternRewriter &rewriter) const { 534 assert(adaptor.getOperands().size() == 2); 535 536 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 537 return failure(); 538 539 auto dstType = getTypeConverter()->convertType(op.getType()); 540 if (!dstType) 541 return failure(); 542 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 543 adaptor.getOperands()); 544 545 return success(); 546 } 547 548 //===----------------------------------------------------------------------===// 549 // XOrIOpBooleanPattern 550 //===----------------------------------------------------------------------===// 551 552 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 553 arith::XOrIOp op, OpAdaptor adaptor, 554 ConversionPatternRewriter &rewriter) const { 555 assert(adaptor.getOperands().size() == 2); 556 557 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 558 return failure(); 559 560 auto dstType = getTypeConverter()->convertType(op.getType()); 561 if (!dstType) 562 return failure(); 563 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 564 adaptor.getOperands()); 565 return success(); 566 } 567 568 //===----------------------------------------------------------------------===// 569 // UIToFPI1Pattern 570 //===----------------------------------------------------------------------===// 571 572 LogicalResult 573 
UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 574 ConversionPatternRewriter &rewriter) const { 575 auto srcType = adaptor.getOperands().front().getType(); 576 if (!isBoolScalarOrVector(srcType)) 577 return failure(); 578 579 auto dstType = 580 this->getTypeConverter()->convertType(op.getResult().getType()); 581 Location loc = op.getLoc(); 582 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 583 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 584 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 585 op, dstType, adaptor.getOperands().front(), one, zero); 586 return success(); 587 } 588 589 //===----------------------------------------------------------------------===// 590 // ExtUII1Pattern 591 //===----------------------------------------------------------------------===// 592 593 LogicalResult 594 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 595 ConversionPatternRewriter &rewriter) const { 596 auto srcType = adaptor.getOperands().front().getType(); 597 if (!isBoolScalarOrVector(srcType)) 598 return failure(); 599 600 auto dstType = 601 this->getTypeConverter()->convertType(op.getResult().getType()); 602 Location loc = op.getLoc(); 603 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 604 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 605 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 606 op, dstType, adaptor.getOperands().front(), one, zero); 607 return success(); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // TruncII1Pattern 612 //===----------------------------------------------------------------------===// 613 614 LogicalResult 615 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 616 ConversionPatternRewriter &rewriter) const { 617 auto dstType = 618 this->getTypeConverter()->convertType(op.getResult().getType()); 619 if (!isBoolScalarOrVector(dstType)) 620 
return failure(); 621 622 Location loc = op.getLoc(); 623 auto srcType = adaptor.getOperands().front().getType(); 624 // Check if (x & 1) == 1. 625 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 626 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 627 loc, srcType, adaptor.getOperands()[0], mask); 628 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 629 630 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 631 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 632 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 633 return success(); 634 } 635 636 //===----------------------------------------------------------------------===// 637 // TypeCastingOpPattern 638 //===----------------------------------------------------------------------===// 639 640 template <typename Op, typename SPIRVOp> 641 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 642 Op op, typename Op::Adaptor adaptor, 643 ConversionPatternRewriter &rewriter) const { 644 assert(adaptor.getOperands().size() == 1); 645 auto srcType = adaptor.getOperands().front().getType(); 646 auto dstType = 647 this->getTypeConverter()->convertType(op.getResult().getType()); 648 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 649 return failure(); 650 if (dstType == srcType) { 651 // Due to type conversion, we are seeing the same source and target type. 652 // Then we can just erase this operation by forwarding its operand. 
653 rewriter.replaceOp(op, adaptor.getOperands().front()); 654 } else { 655 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 656 adaptor.getOperands()); 657 } 658 return success(); 659 } 660 661 //===----------------------------------------------------------------------===// 662 // CmpIOpBooleanPattern 663 //===----------------------------------------------------------------------===// 664 665 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 666 arith::CmpIOp op, OpAdaptor adaptor, 667 ConversionPatternRewriter &rewriter) const { 668 Type operandType = op.getLhs().getType(); 669 if (!isBoolScalarOrVector(operandType)) 670 return failure(); 671 672 switch (op.getPredicate()) { 673 #define DISPATCH(cmpPredicate, spirvOp) \ 674 case cmpPredicate: { \ 675 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 676 adaptor.getRhs()); \ 677 return success(); \ 678 } 679 680 DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); 681 DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); 682 683 #undef DISPATCH 684 default:; 685 } 686 return failure(); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // CmpIOpPattern 691 //===----------------------------------------------------------------------===// 692 693 LogicalResult 694 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 695 ConversionPatternRewriter &rewriter) const { 696 Type srcType = op.getLhs().getType(); 697 if (isBoolScalarOrVector(srcType)) 698 return failure(); 699 Type dstType = getTypeConverter()->convertType(srcType); 700 if (!dstType) 701 return failure(); 702 703 switch (op.getPredicate()) { 704 #define DISPATCH(cmpPredicate, spirvOp) \ 705 case cmpPredicate: \ 706 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 707 srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \ 708 return op.emitError( \ 709 "bitwidth emulation is not implemented yet on unsigned op"); \ 710 } \ 711 
rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 712 adaptor.getRhs()); \ 713 return success(); 714 715 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 716 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 717 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 718 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 719 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 720 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 721 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 722 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); 723 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); 724 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); 725 726 #undef DISPATCH 727 } 728 return failure(); 729 } 730 731 //===----------------------------------------------------------------------===// 732 // CmpFOpPattern 733 //===----------------------------------------------------------------------===// 734 735 LogicalResult 736 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 737 ConversionPatternRewriter &rewriter) const { 738 switch (op.getPredicate()) { 739 #define DISPATCH(cmpPredicate, spirvOp) \ 740 case cmpPredicate: \ 741 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 742 adaptor.getRhs()); \ 743 return success(); 744 745 // Ordered. 746 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); 747 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); 748 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); 749 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); 750 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); 751 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); 752 // Unordered. 
// Unordered predicates (U*) dispatch to the SPIR-V FUnord* comparison ops.
    DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    // Predicates not dispatched above are left to other patterns.
    break;
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// CmpFOpNanKernelPattern
//===----------------------------------------------------------------------===//

/// Lowers the NaN-testing predicates of arith.cmpf directly to dedicated
/// SPIR-V ops: `ord` becomes spv.Ordered and `uno` becomes spv.Unordered.
/// Every other predicate fails to match.
///
/// This pattern is registered with a higher benefit (2) in
/// populateArithmeticToSPIRVPatterns so it is preferred when the Kernel
/// capability is available; CmpFOpNanNonePattern covers the same predicates
/// without these ops.
LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
    arith::CmpFOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
                                                  adaptor.getRhs());
    return success();
  }

  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
                                                    adaptor.getRhs());
    return success();
  }

  // Not a NaN-testing predicate; leave it for the other CmpF patterns.
  return failure();
}

//===----------------------------------------------------------------------===//
// CmpFOpNanNonePattern
//===----------------------------------------------------------------------===//

/// Lowers the `ord` / `uno` predicates of arith.cmpf using only spv.IsNan
/// and logical ops, i.e. without spv.Ordered / spv.Unordered:
///   uno(lhs, rhs) = isnan(lhs) || isnan(rhs)
///   ord(lhs, rhs) = !(isnan(lhs) || isnan(rhs))
/// Every other predicate fails to match.
LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
    arith::CmpFOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
      op.getPredicate() != arith::CmpFPredicate::UNO)
    return failure();

  Location loc = op.getLoc();

  // Test each (type-converted) operand for NaN separately.
  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());

  // "Unordered" means at least one operand is NaN; `ord` is its negation.
  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
  if (op.getPredicate() == arith::CmpFPredicate::ORD)
    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);

  rewriter.replaceOp(op, replace);
  return success();
}

//===----------------------------------------------------------------------===//
// SelectOpPattern
//===----------------------------------------------------------------------===//

/// Converts arith.select to spv.Select, forwarding the adaptor's condition,
/// true-value, and false-value operands in the same order.
LogicalResult
SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
                                               adaptor.getTrueValue(),
                                               adaptor.getFalseValue());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

/// Adds the arith-to-SPIR-V conversion patterns to `patterns`, all sharing
/// `typeConverter`. Everything in the first `patterns.add` call is registered
/// without an explicit benefit; CmpFOpNanKernelPattern is added separately
/// with benefit 2.
void mlir::arith::populateArithmeticToSPIRVPatterns(
    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
    // Constants (composite and scalar forms).
    ConstantCompositeOpPattern,
    ConstantScalarOpPattern,
    // Integer arithmetic. arith.remsi has separate GLSL and OpenCL lowerings.
    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
    spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
    spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
    spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
    RemSIOpGLSLPattern, RemSIOpOCLPattern,
    // Bitwise ops: logical SPIR-V ops for boolean operands, bitwise otherwise.
    BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
    BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
    XOrIOpLogicalPattern, XOrIOpBooleanPattern,
    // Shifts.
    spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
    // Floating-point arithmetic.
    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
    // Casts; the *I1Pattern entries handle i1 special cases of the same ops.
    TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
    TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
    TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
    // Comparisons and select.
    CmpIOpBooleanPattern, CmpIOpPattern,
    CmpFOpNanNonePattern, CmpFOpPattern,
    SelectOpPattern,

    // Min/max map onto the GLSL* extended-instruction ops.
    spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
    spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
    spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>
  >(typeConverter, patterns.getContext());
  // clang-format on

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available.
  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
                                       /*benefit=*/2);
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// Pass that runs the arith-to-SPIR-V patterns over the operation it is
/// scheduled on, bridging type mismatches with unrealized conversion casts.
struct ConvertArithmeticToSPIRVPass
    : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
  void runOnOperation() override {
    auto module = getOperation();
    // Build the conversion target from the SPIR-V target environment returned
    // by spirv::lookupTargetEnvOrDefault for this op.
    auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
    auto target = SPIRVConversionTarget::get(targetAttr);

    // Forward the pass's emulateNon32BitScalarTypes option to the converter.
    SPIRVTypeConverter::Options options;
    options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
    SPIRVTypeConverter typeConverter(targetAttr, options);

    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
    // in patterns for other dialects.
    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) {
      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
      return Optional<Value>(cast.getResult(0));
    };
    typeConverter.addSourceMaterialization(addUnrealizedCast);
    typeConverter.addTargetMaterialization(addUnrealizedCast);
    // The bridging casts must themselves be legal, or conversion would fail.
    target->addLegalOp<UnrealizedConversionCastOp>();

    RewritePatternSet patterns(&getContext());
    arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);

    // Partial conversion: ops not marked illegal in `target` may remain.
    if (failed(applyPartialConversion(module, *target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

/// Creates a new instance of the arith-to-SPIR-V conversion pass.
std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
  return std::make_unique<ConvertArithmeticToSPIRVPass>();
}