1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 10 #include "../PassDetail.h" 11 #include "../SPIRVCommon/Pattern.h" 12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "llvm/Support/Debug.h" 19 20 #define DEBUG_TYPE "arith-to-spirv-pattern" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Operation Conversion 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 30 /// Converts composite arith.constant operation to spv.Constant. 31 struct ConstantCompositeOpPattern final 32 : public OpConversionPattern<arith::ConstantOp> { 33 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 34 35 LogicalResult 36 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 37 ConversionPatternRewriter &rewriter) const override; 38 }; 39 40 /// Converts scalar arith.constant operation to spv.Constant. 41 struct ConstantScalarOpPattern final 42 : public OpConversionPattern<arith::ConstantOp> { 43 using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; 44 45 LogicalResult 46 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 47 ConversionPatternRewriter &rewriter) const override; 48 }; 49 50 /// Converts arith.remsi to GLSL SPIR-V ops. 
51 /// 52 /// This cannot be merged into the template unary/binary pattern due to Vulkan 53 /// restrictions over spv.SRem and spv.SMod. 54 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { 55 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 56 57 LogicalResult 58 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 59 ConversionPatternRewriter &rewriter) const override; 60 }; 61 62 /// Converts arith.remsi to OpenCL SPIR-V ops. 63 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { 64 using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; 65 66 LogicalResult 67 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 68 ConversionPatternRewriter &rewriter) const override; 69 }; 70 71 /// Converts bitwise operations to SPIR-V operations. This is a special pattern 72 /// other than the BinaryOpPatternPattern because if the operands are boolean 73 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For 74 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 75 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 76 struct BitwiseOpPattern final : public OpConversionPattern<Op> { 77 using OpConversionPattern<Op>::OpConversionPattern; 78 79 LogicalResult 80 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 81 ConversionPatternRewriter &rewriter) const override; 82 }; 83 84 /// Converts arith.xori to SPIR-V operations. 85 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { 86 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 87 88 LogicalResult 89 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 90 ConversionPatternRewriter &rewriter) const override; 91 }; 92 93 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or 94 /// vector of i1. 
95 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { 96 using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; 97 98 LogicalResult 99 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, 100 ConversionPatternRewriter &rewriter) const override; 101 }; 102 103 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of 104 /// i1. 105 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { 106 using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; 107 108 LogicalResult 109 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 110 ConversionPatternRewriter &rewriter) const override; 111 }; 112 113 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of 114 /// i1. 115 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { 116 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; 117 118 LogicalResult 119 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 120 ConversionPatternRewriter &rewriter) const override; 121 }; 122 123 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of 124 /// i1. 125 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { 126 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; 127 128 LogicalResult 129 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 130 ConversionPatternRewriter &rewriter) const override; 131 }; 132 133 /// Converts type-casting standard operations to SPIR-V operations. 134 template <typename Op, typename SPIRVOp> 135 struct TypeCastingOpPattern final : public OpConversionPattern<Op> { 136 using OpConversionPattern<Op>::OpConversionPattern; 137 138 LogicalResult 139 matchAndRewrite(Op op, typename Op::Adaptor adaptor, 140 ConversionPatternRewriter &rewriter) const override; 141 }; 142 143 /// Converts integer compare operation on i1 type operands to SPIR-V ops. 
144 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { 145 public: 146 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 147 148 LogicalResult 149 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 150 ConversionPatternRewriter &rewriter) const override; 151 }; 152 153 /// Converts integer compare operation to SPIR-V ops. 154 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { 155 public: 156 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; 157 158 LogicalResult 159 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 160 ConversionPatternRewriter &rewriter) const override; 161 }; 162 163 /// Converts floating-point comparison operations to SPIR-V ops. 164 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { 165 public: 166 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 167 168 LogicalResult 169 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 170 ConversionPatternRewriter &rewriter) const override; 171 }; 172 173 /// Converts floating point NaN check to SPIR-V ops. This pattern requires 174 /// Kernel capability. 175 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { 176 public: 177 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 178 179 LogicalResult 180 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 181 ConversionPatternRewriter &rewriter) const override; 182 }; 183 184 /// Converts floating point NaN check to SPIR-V ops. This pattern does not 185 /// require additional capability. 186 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { 187 public: 188 using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; 189 190 LogicalResult 191 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 192 ConversionPatternRewriter &rewriter) const override; 193 }; 194 195 /// Converts arith.select to spv.Select. 
196 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { 197 public: 198 using OpConversionPattern<arith::SelectOp>::OpConversionPattern; 199 LogicalResult 200 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, 201 ConversionPatternRewriter &rewriter) const override; 202 }; 203 204 } // namespace 205 206 //===----------------------------------------------------------------------===// 207 // Conversion Helpers 208 //===----------------------------------------------------------------------===// 209 210 /// Converts the given `srcAttr` into a boolean attribute if it holds an 211 /// integral value. Returns null attribute if conversion fails. 212 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { 213 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) 214 return boolAttr; 215 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) 216 return builder.getBoolAttr(intAttr.getValue().getBoolValue()); 217 return BoolAttr(); 218 } 219 220 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 221 /// Returns null attribute if conversion fails. 222 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, 223 Builder builder) { 224 // If the source number uses less active bits than the target bitwidth, then 225 // it should be safe to convert. 226 if (srcAttr.getValue().isIntN(dstType.getWidth())) 227 return builder.getIntegerAttr(dstType, srcAttr.getInt()); 228 229 // XXX: Try again by interpreting the source number as a signed value. 230 // Although integers in the standard dialect are signless, they can represent 231 // a signed number. It's the operation decides how to interpret. This is 232 // dangerous, but it seems there is no good way of handling this if we still 233 // want to change the bitwidth. Emit a message at least. 
234 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { 235 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); 236 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" 237 << dstAttr << "' for type '" << dstType << "'\n"); 238 return dstAttr; 239 } 240 241 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr 242 << "' illegal: cannot fit into target type '" 243 << dstType << "'\n"); 244 return IntegerAttr(); 245 } 246 247 /// Converts the given `srcAttr` to a new attribute of the given `dstType`. 248 /// Returns null attribute if `dstType` is not 32-bit or conversion fails. 249 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, 250 Builder builder) { 251 // Only support converting to float for now. 252 if (!dstType.isF32()) 253 return FloatAttr(); 254 255 // Try to convert the source floating-point number to single precision. 256 APFloat dstVal = srcAttr.getValue(); 257 bool losesInfo = false; 258 APFloat::opStatus status = 259 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); 260 if (status != APFloat::opOK || losesInfo) { 261 LLVM_DEBUG(llvm::dbgs() 262 << srcAttr << " illegal: cannot fit into converted type '" 263 << dstType << "'\n"); 264 return FloatAttr(); 265 } 266 267 return builder.getF32FloatAttr(dstVal.convertToFloat()); 268 } 269 270 /// Returns true if the given `type` is a boolean scalar or vector type. 271 static bool isBoolScalarOrVector(Type type) { 272 if (type.isInteger(1)) 273 return true; 274 if (auto vecType = type.dyn_cast<VectorType>()) 275 return vecType.getElementType().isInteger(1); 276 return false; 277 } 278 279 /// Returns true if scalar/vector type `a` and `b` have the same number of 280 /// bitwidth. 
281 static bool hasSameBitwidth(Type a, Type b) { 282 auto getNumBitwidth = [](Type type) { 283 unsigned bw = 0; 284 if (type.isIntOrFloat()) 285 bw = type.getIntOrFloatBitWidth(); 286 else if (auto vecType = type.dyn_cast<VectorType>()) 287 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); 288 return bw; 289 }; 290 unsigned aBW = getNumBitwidth(a); 291 unsigned bBW = getNumBitwidth(b); 292 return aBW != 0 && bBW != 0 && aBW == bBW; 293 } 294 295 //===----------------------------------------------------------------------===// 296 // ConstantOp with composite type 297 //===----------------------------------------------------------------------===// 298 299 LogicalResult ConstantCompositeOpPattern::matchAndRewrite( 300 arith::ConstantOp constOp, OpAdaptor adaptor, 301 ConversionPatternRewriter &rewriter) const { 302 auto srcType = constOp.getType().dyn_cast<ShapedType>(); 303 if (!srcType || srcType.getNumElements() == 1) 304 return failure(); 305 306 // arith.constant should only have vector or tenor types. 307 assert((srcType.isa<VectorType, RankedTensorType>())); 308 309 auto dstType = getTypeConverter()->convertType(srcType); 310 if (!dstType) 311 return failure(); 312 313 auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); 314 if (!dstElementsAttr) 315 return failure(); 316 317 ShapedType dstAttrType = dstElementsAttr.getType(); 318 319 // If the composite type has more than one dimensions, perform linearization. 320 if (srcType.getRank() > 1) { 321 if (srcType.isa<RankedTensorType>()) { 322 dstAttrType = RankedTensorType::get(srcType.getNumElements(), 323 srcType.getElementType()); 324 dstElementsAttr = dstElementsAttr.reshape(dstAttrType); 325 } else { 326 // TODO: add support for large vectors. 327 return failure(); 328 } 329 } 330 331 Type srcElemType = srcType.getElementType(); 332 Type dstElemType; 333 // Tensor types are converted to SPIR-V array types; vector types are 334 // converted to SPIR-V vector/array types. 
335 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) 336 dstElemType = arrayType.getElementType(); 337 else 338 dstElemType = dstType.cast<VectorType>().getElementType(); 339 340 // If the source and destination element types are different, perform 341 // attribute conversion. 342 if (srcElemType != dstElemType) { 343 SmallVector<Attribute, 8> elements; 344 if (srcElemType.isa<FloatType>()) { 345 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { 346 FloatAttr dstAttr = 347 convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); 348 if (!dstAttr) 349 return failure(); 350 elements.push_back(dstAttr); 351 } 352 } else if (srcElemType.isInteger(1)) { 353 return failure(); 354 } else { 355 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { 356 IntegerAttr dstAttr = convertIntegerAttr( 357 srcAttr, dstElemType.cast<IntegerType>(), rewriter); 358 if (!dstAttr) 359 return failure(); 360 elements.push_back(dstAttr); 361 } 362 } 363 364 // Unfortunately, we cannot use dialect-specific types for element 365 // attributes; element attributes only works with builtin types. So we need 366 // to prepare another converted builtin types for the destination elements 367 // attribute. 
368 if (dstAttrType.isa<RankedTensorType>()) 369 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); 370 else 371 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); 372 373 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); 374 } 375 376 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, 377 dstElementsAttr); 378 return success(); 379 } 380 381 //===----------------------------------------------------------------------===// 382 // ConstantOp with scalar type 383 //===----------------------------------------------------------------------===// 384 385 LogicalResult ConstantScalarOpPattern::matchAndRewrite( 386 arith::ConstantOp constOp, OpAdaptor adaptor, 387 ConversionPatternRewriter &rewriter) const { 388 Type srcType = constOp.getType(); 389 if (auto shapedType = srcType.dyn_cast<ShapedType>()) { 390 if (shapedType.getNumElements() != 1) 391 return failure(); 392 srcType = shapedType.getElementType(); 393 } 394 if (!srcType.isIntOrIndexOrFloat()) 395 return failure(); 396 397 Attribute cstAttr = constOp.getValue(); 398 if (cstAttr.getType().isa<ShapedType>()) 399 cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>(); 400 401 Type dstType = getTypeConverter()->convertType(srcType); 402 if (!dstType) 403 return failure(); 404 405 // Floating-point types. 406 if (srcType.isa<FloatType>()) { 407 auto srcAttr = cstAttr.cast<FloatAttr>(); 408 auto dstAttr = srcAttr; 409 410 // Floating-point types not supported in the target environment are all 411 // converted to float type. 412 if (srcType != dstType) { 413 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); 414 if (!dstAttr) 415 return failure(); 416 } 417 418 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 419 return success(); 420 } 421 422 // Bool type. 423 if (srcType.isInteger(1)) { 424 // arith.constant can use 0/1 instead of true/false for i1 values. 
We need 425 // to handle that here. 426 auto dstAttr = convertBoolAttr(cstAttr, rewriter); 427 if (!dstAttr) 428 return failure(); 429 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 430 return success(); 431 } 432 433 // IndexType or IntegerType. Index values are converted to 32-bit integer 434 // values when converting to SPIR-V. 435 auto srcAttr = cstAttr.cast<IntegerAttr>(); 436 auto dstAttr = 437 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); 438 if (!dstAttr) 439 return failure(); 440 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); 441 return success(); 442 } 443 444 //===----------------------------------------------------------------------===// 445 // RemSIOpGLSLPattern 446 //===----------------------------------------------------------------------===// 447 448 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow 449 /// the sign of `signOperand`. 450 /// 451 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment 452 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative 453 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod 454 /// if either operand can be negative. Emulate it via spv.UMod. 455 template <typename SignedAbsOp> 456 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, 457 Value signOperand, OpBuilder &builder) { 458 assert(lhs.getType() == rhs.getType()); 459 assert(lhs == signOperand || rhs == signOperand); 460 461 Type type = lhs.getType(); 462 463 // Calculate the remainder with spv.UMod. 464 Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); 465 Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); 466 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); 467 468 // Fix the sign. 
469 Value isPositive; 470 if (lhs == signOperand) 471 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); 472 else 473 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); 474 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); 475 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); 476 } 477 478 LogicalResult 479 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 480 ConversionPatternRewriter &rewriter) const { 481 Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( 482 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 483 adaptor.getOperands()[0], rewriter); 484 rewriter.replaceOp(op, result); 485 486 return success(); 487 } 488 489 //===----------------------------------------------------------------------===// 490 // RemSIOpOCLPattern 491 //===----------------------------------------------------------------------===// 492 493 LogicalResult 494 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, 495 ConversionPatternRewriter &rewriter) const { 496 Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( 497 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], 498 adaptor.getOperands()[0], rewriter); 499 rewriter.replaceOp(op, result); 500 501 return success(); 502 } 503 504 //===----------------------------------------------------------------------===// 505 // BitwiseOpPattern 506 //===----------------------------------------------------------------------===// 507 508 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> 509 LogicalResult 510 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( 511 Op op, typename Op::Adaptor adaptor, 512 ConversionPatternRewriter &rewriter) const { 513 assert(adaptor.getOperands().size() == 2); 514 auto dstType = 515 this->getTypeConverter()->convertType(op.getResult().getType()); 516 if (!dstType) 517 return failure(); 518 if 
(isBoolScalarOrVector(adaptor.getOperands().front().getType())) { 519 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, 520 adaptor.getOperands()); 521 } else { 522 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, 523 adaptor.getOperands()); 524 } 525 return success(); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // XOrIOpLogicalPattern 530 //===----------------------------------------------------------------------===// 531 532 LogicalResult XOrIOpLogicalPattern::matchAndRewrite( 533 arith::XOrIOp op, OpAdaptor adaptor, 534 ConversionPatternRewriter &rewriter) const { 535 assert(adaptor.getOperands().size() == 2); 536 537 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) 538 return failure(); 539 540 auto dstType = getTypeConverter()->convertType(op.getType()); 541 if (!dstType) 542 return failure(); 543 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, 544 adaptor.getOperands()); 545 546 return success(); 547 } 548 549 //===----------------------------------------------------------------------===// 550 // XOrIOpBooleanPattern 551 //===----------------------------------------------------------------------===// 552 553 LogicalResult XOrIOpBooleanPattern::matchAndRewrite( 554 arith::XOrIOp op, OpAdaptor adaptor, 555 ConversionPatternRewriter &rewriter) const { 556 assert(adaptor.getOperands().size() == 2); 557 558 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) 559 return failure(); 560 561 auto dstType = getTypeConverter()->convertType(op.getType()); 562 if (!dstType) 563 return failure(); 564 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, 565 adaptor.getOperands()); 566 return success(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // UIToFPI1Pattern 571 //===----------------------------------------------------------------------===// 572 573 LogicalResult 574 
UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, 575 ConversionPatternRewriter &rewriter) const { 576 auto srcType = adaptor.getOperands().front().getType(); 577 if (!isBoolScalarOrVector(srcType)) 578 return failure(); 579 580 auto dstType = 581 this->getTypeConverter()->convertType(op.getResult().getType()); 582 Location loc = op.getLoc(); 583 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 584 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 585 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 586 op, dstType, adaptor.getOperands().front(), one, zero); 587 return success(); 588 } 589 590 //===----------------------------------------------------------------------===// 591 // ExtUII1Pattern 592 //===----------------------------------------------------------------------===// 593 594 LogicalResult 595 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, 596 ConversionPatternRewriter &rewriter) const { 597 auto srcType = adaptor.getOperands().front().getType(); 598 if (!isBoolScalarOrVector(srcType)) 599 return failure(); 600 601 auto dstType = 602 this->getTypeConverter()->convertType(op.getResult().getType()); 603 Location loc = op.getLoc(); 604 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 605 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 606 rewriter.template replaceOpWithNewOp<spirv::SelectOp>( 607 op, dstType, adaptor.getOperands().front(), one, zero); 608 return success(); 609 } 610 611 //===----------------------------------------------------------------------===// 612 // TruncII1Pattern 613 //===----------------------------------------------------------------------===// 614 615 LogicalResult 616 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, 617 ConversionPatternRewriter &rewriter) const { 618 auto dstType = 619 this->getTypeConverter()->convertType(op.getResult().getType()); 620 if (!isBoolScalarOrVector(dstType)) 621 
return failure(); 622 623 Location loc = op.getLoc(); 624 auto srcType = adaptor.getOperands().front().getType(); 625 // Check if (x & 1) == 1. 626 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); 627 Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( 628 loc, srcType, adaptor.getOperands()[0], mask); 629 Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); 630 631 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); 632 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); 633 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); 634 return success(); 635 } 636 637 //===----------------------------------------------------------------------===// 638 // TypeCastingOpPattern 639 //===----------------------------------------------------------------------===// 640 641 template <typename Op, typename SPIRVOp> 642 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( 643 Op op, typename Op::Adaptor adaptor, 644 ConversionPatternRewriter &rewriter) const { 645 assert(adaptor.getOperands().size() == 1); 646 auto srcType = adaptor.getOperands().front().getType(); 647 auto dstType = 648 this->getTypeConverter()->convertType(op.getResult().getType()); 649 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) 650 return failure(); 651 if (dstType == srcType) { 652 // Due to type conversion, we are seeing the same source and target type. 653 // Then we can just erase this operation by forwarding its operand. 
654 rewriter.replaceOp(op, adaptor.getOperands().front()); 655 } else { 656 rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, 657 adaptor.getOperands()); 658 } 659 return success(); 660 } 661 662 //===----------------------------------------------------------------------===// 663 // CmpIOpBooleanPattern 664 //===----------------------------------------------------------------------===// 665 666 LogicalResult CmpIOpBooleanPattern::matchAndRewrite( 667 arith::CmpIOp op, OpAdaptor adaptor, 668 ConversionPatternRewriter &rewriter) const { 669 Type srcType = op.getLhs().getType(); 670 if (!isBoolScalarOrVector(srcType)) 671 return failure(); 672 Type dstType = getTypeConverter()->convertType(srcType); 673 if (!dstType) 674 return failure(); 675 676 switch (op.getPredicate()) { 677 case arith::CmpIPredicate::eq: { 678 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(), 679 adaptor.getRhs()); 680 return success(); 681 } 682 case arith::CmpIPredicate::ne: { 683 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(), 684 adaptor.getRhs()); 685 return success(); 686 } 687 case arith::CmpIPredicate::uge: 688 case arith::CmpIPredicate::ugt: 689 case arith::CmpIPredicate::ule: 690 case arith::CmpIPredicate::ult: { 691 // There are no direct corresponding instructions in SPIR-V for such cases. 692 // Extend them to 32-bit and do comparision then. 
693 Type type = rewriter.getI32Type(); 694 if (auto vectorType = dstType.dyn_cast<VectorType>()) 695 type = VectorType::get(vectorType.getShape(), type); 696 auto extLhs = 697 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); 698 auto extRhs = 699 rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); 700 701 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, 702 extRhs); 703 return success(); 704 } 705 default: 706 break; 707 } 708 return failure(); 709 } 710 711 //===----------------------------------------------------------------------===// 712 // CmpIOpPattern 713 //===----------------------------------------------------------------------===// 714 715 LogicalResult 716 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 717 ConversionPatternRewriter &rewriter) const { 718 Type srcType = op.getLhs().getType(); 719 if (isBoolScalarOrVector(srcType)) 720 return failure(); 721 Type dstType = getTypeConverter()->convertType(srcType); 722 if (!dstType) 723 return failure(); 724 725 switch (op.getPredicate()) { 726 #define DISPATCH(cmpPredicate, spirvOp) \ 727 case cmpPredicate: \ 728 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ 729 srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \ 730 return op.emitError( \ 731 "bitwidth emulation is not implemented yet on unsigned op"); \ 732 } \ 733 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ 734 adaptor.getRhs()); \ 735 return success(); 736 737 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); 738 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); 739 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); 740 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); 741 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); 742 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); 743 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); 744 
// Unsigned comparisons.
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// CmpFOpPattern
//===----------------------------------------------------------------------===//

/// Lowers the arith.cmpf predicates that have a direct SPIR-V counterpart.
/// Each ordered predicate (OEQ, OGT, OGE, OLT, OLE, ONE) becomes the matching
/// spv.FOrd* op and each unordered predicate (UEQ, UGT, UGE, ULT, ULE, UNE)
/// becomes the matching spv.FUnord* op, applied to the type-converted
/// operands. Any other predicate (e.g. ORD or UNO) reaches `default` and the
/// pattern returns failure(), leaving the op for another pattern.
LogicalResult
CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  switch (op.getPredicate()) {
  // DISPATCH(pred, op) expands to a `case pred:` that replaces the cmpf with
  // `op` built from the converted lhs/rhs and returns success().
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
                                         adaptor.getRhs());                    \
    return success();

    // Ordered.
    DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    // Not a one-to-one predicate; let another pattern try.
    break;
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// CmpFOpNanKernelPattern
//===----------------------------------------------------------------------===//

/// Lowers arith.cmpf ORD/UNO directly to spv.Ordered/spv.Unordered on the
/// converted operands and fails for every other predicate.
/// populateArithmeticToSPIRVPatterns registers it with benefit 2, which is
/// higher than that of the other cmpf patterns (added without an explicit
/// benefit). The comment there says this lets it prevail when the Kernel
/// capability is available.
LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
    arith::CmpFOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
                                                  adaptor.getRhs());
    return success();
  }

  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
                                                    adaptor.getRhs());
    return success();
  }

  return failure();
}

//===----------------------------------------------------------------------===//
// CmpFOpNanNonePattern
//===----------------------------------------------------------------------===//

/// Fallback lowering of arith.cmpf ORD/UNO that avoids spv.Ordered and
/// spv.Unordered and uses only spv.IsNan plus logical ops:
///   UNO(lhs, rhs) =   isnan(lhs) || isnan(rhs)
///   ORD(lhs, rhs) = !(isnan(lhs) || isnan(rhs))
/// Fails for any other predicate.
LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
    arith::CmpFOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
      op.getPredicate() != arith::CmpFPredicate::UNO)
    return failure();

  Location loc = op.getLoc();

  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());

  // UNO holds when either operand is NaN; ORD is its negation.
  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
  if (op.getPredicate() == arith::CmpFPredicate::ORD)
    replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);

  rewriter.replaceOp(op, replace);
  return success();
}

//===----------------------------------------------------------------------===//
// SelectOpPattern
//===----------------------------------------------------------------------===//

/// Lowers arith.select one-to-one to spv.Select on the converted condition,
/// true value and false value. This pattern always succeeds.
LogicalResult
SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
                                               adaptor.getTrueValue(),
                                               adaptor.getFalseValue());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

/// Adds the arith-to-SPIR-V conversion patterns to `patterns`. Each pattern is
/// constructed with `typeConverter` and the context of `patterns`. The main
/// list is added without an explicit benefit; CmpFOpNanKernelPattern is added
/// separately with benefit 2.
void mlir::arith::populateArithmeticToSPIRVPatterns(
    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
    // Constants.
    ConstantCompositeOpPattern,
    ConstantScalarOpPattern,
    // Integer arithmetic.
    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
    spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
    spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
    spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
    RemSIOpGLSLPattern, RemSIOpOCLPattern,
    // Bitwise ops; boolean (i1) operands take separate code paths.
    BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
    BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
    XOrIOpLogicalPattern, XOrIOpBooleanPattern,
    // Shifts.
    spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRSIOp,
                                spirv::ShiftRightArithmeticOp>,
    // Floating-point arithmetic.
    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
    // Extensions, truncations, int<->float conversions, index and bit casts.
    TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
    TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
    TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
    // Comparisons and select. The cmpf ORD/UNO predicates are handled by
    // CmpFOpNanNonePattern here and by CmpFOpNanKernelPattern (added below).
    CmpIOpBooleanPattern, CmpIOpPattern,
    CmpFOpNanNonePattern, CmpFOpPattern,
    SelectOpPattern,

    // Min/max via GLSL extended instructions.
    // NOTE(review): arith.maxf/minf are documented to return NaN when either
    // operand is NaN, whereas GLSL.std.450 FMax/FMin leave the result
    // undefined for NaN operands -- confirm this mismatch is acceptable.
    spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
    spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
    spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
    spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
    spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>
  >(typeConverter, patterns.getContext());
  // clang-format on

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available.
  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
                                       /*benefit=*/2);
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
/// Converts arith ops under the pass's root operation to SPIR-V with a partial
/// dialect conversion. The target is derived from the SPIR-V target
/// environment returned by spirv::lookupTargetEnvOrDefault for the root op.
struct ConvertArithmeticToSPIRVPass
    : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
  void runOnOperation() override {
    auto module = getOperation();
    auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
    auto target = SPIRVConversionTarget::get(targetAttr);

    // Forward the pass option controlling non-32-bit scalar type emulation
    // to the type converter.
    SPIRVTypeConverter::Options options;
    options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
    SPIRVTypeConverter typeConverter(targetAttr, options);

    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
    // in patterns for other dialects.
    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) {
      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
      return Optional<Value>(cast.getResult(0));
    };
    typeConverter.addSourceMaterialization(addUnrealizedCast);
    typeConverter.addTargetMaterialization(addUnrealizedCast);
    // The bridging casts themselves must be legal in the output.
    target->addLegalOp<UnrealizedConversionCastOp>();

    RewritePatternSet patterns(&getContext());
    arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);

    // A failed partial conversion fails the pass.
    if (failed(applyPartialConversion(module, *target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

/// Creates a new instance of the arith-to-SPIR-V conversion pass.
std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
  return std::make_unique<ConvertArithmeticToSPIRVPass>();
}