1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 #define DEBUG_TYPE "spirv-to-llvm-pattern" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Utility functions 34 //===----------------------------------------------------------------------===// 35 36 /// Returns true if the given type is a signed integer or vector type. 37 static bool isSignedIntegerOrVector(Type type) { 38 if (type.isSignedInteger()) 39 return true; 40 if (auto vecType = type.dyn_cast<VectorType>()) 41 return vecType.getElementType().isSignedInteger(); 42 return false; 43 } 44 45 /// Returns true if the given type is an unsigned integer or vector type 46 static bool isUnsignedIntegerOrVector(Type type) { 47 if (type.isUnsignedInteger()) 48 return true; 49 if (auto vecType = type.dyn_cast<VectorType>()) 50 return vecType.getElementType().isUnsignedInteger(); 51 return false; 52 } 53 54 /// Returns the bit width of integer, float or vector of float or integer values 55 static unsigned getBitWidth(Type type) { 56 assert((type.isIntOrFloat() || type.isa<VectorType>()) && 57 "bitwidth is not supported for this type"); 58 if (type.isIntOrFloat()) 59 return type.getIntOrFloatBitWidth(); 60 auto vecType = type.dyn_cast<VectorType>(); 61 auto elementType = vecType.getElementType(); 62 assert(elementType.isIntOrFloat() && 63 "only integers and floats have a bitwidth"); 64 return elementType.getIntOrFloatBitWidth(); 65 } 66 67 /// Returns the bit width of LLVMType integer or vector. 68 static unsigned getLLVMTypeBitWidth(Type type) { 69 return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type) 70 : type) 71 .cast<IntegerType>() 72 .getWidth(); 73 } 74 75 /// Creates `IntegerAttribute` with all bits set for given type 76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 77 if (auto vecType = type.dyn_cast<VectorType>()) { 78 auto integerType = vecType.getElementType().cast<IntegerType>(); 79 return builder.getIntegerAttr(integerType, -1); 80 } 81 auto integerType = type.cast<IntegerType>(); 82 return builder.getIntegerAttr(integerType, -1); 83 } 84 85 /// Creates `llvm.mlir.constant` with all bits set for the given type. 86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 87 PatternRewriter &rewriter) { 88 if (srcType.isa<VectorType>()) { 89 return rewriter.create<LLVM::ConstantOp>( 90 loc, dstType, 91 SplatElementsAttr::get(srcType.cast<ShapedType>(), 92 minusOneIntegerAttribute(srcType, rewriter))); 93 } 94 return rewriter.create<LLVM::ConstantOp>( 95 loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 96 } 97 98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 99 static Value createFPConstant(Location loc, Type srcType, Type dstType, 100 PatternRewriter &rewriter, double value) { 101 if (auto vecType = srcType.dyn_cast<VectorType>()) { 102 auto floatType = vecType.getElementType().cast<FloatType>(); 103 return rewriter.create<LLVM::ConstantOp>( 104 loc, dstType, 105 SplatElementsAttr::get(vecType, 106 rewriter.getFloatAttr(floatType, value))); 107 } 108 auto floatType = srcType.cast<FloatType>(); 109 return rewriter.create<LLVM::ConstantOp>( 110 loc, dstType, rewriter.getFloatAttr(floatType, value)); 111 } 112 113 /// Utility function for bitfield ops: 114 /// - `BitFieldInsert` 115 /// - `BitFieldSExtract` 116 /// - `BitFieldUExtract` 117 /// Truncates or extends the value. If the bitwidth of the value is the same as 118 /// `llvmType` bitwidth, the value remains unchanged. 119 static Value optionallyTruncateOrExtend(Location loc, Value value, 120 Type llvmType, 121 PatternRewriter &rewriter) { 122 auto srcType = value.getType(); 123 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 124 unsigned valueBitWidth = LLVM::isCompatibleType(srcType) 125 ? getLLVMTypeBitWidth(srcType) 126 : getBitWidth(srcType); 127 128 if (valueBitWidth < targetBitWidth) 129 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 130 // If the bit widths of `Count` and `Offset` are greater than the bit width 131 // of the target type, they are truncated. Truncation is safe since `Count` 132 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 133 // both values can be expressed in 8 bits. 134 if (valueBitWidth > targetBitWidth) 135 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 136 return value; 137 } 138 139 /// Broadcasts the value to vector with `numElements` number of elements. 140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 141 LLVMTypeConverter &typeConverter, 142 ConversionPatternRewriter &rewriter) { 143 auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 144 auto llvmVectorType = typeConverter.convertType(vectorType); 145 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 146 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 147 for (unsigned i = 0; i < numElements; ++i) { 148 auto index = rewriter.create<LLVM::ConstantOp>( 149 loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 150 broadcasted = rewriter.create<LLVM::InsertElementOp>( 151 loc, llvmVectorType, broadcasted, toBroadcast, index); 152 } 153 return broadcasted; 154 } 155 156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 157 static Value optionallyBroadcast(Location loc, Value value, Type srcType, 158 LLVMTypeConverter &typeConverter, 159 ConversionPatternRewriter &rewriter) { 160 if (auto vectorType = srcType.dyn_cast<VectorType>()) { 161 unsigned numElements = vectorType.getNumElements(); 162 return broadcast(loc, value, numElements, typeConverter, rewriter); 163 } 164 return value; 165 } 166 167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 168 /// `BitFieldUExtract`. 169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 170 /// a vector type, construct a vector that has: 171 /// - same number of elements as `Base` 172 /// - each element has the type that is the same as the type of `Offset` or 173 /// `Count` 174 /// - each element has the same value as `Offset` or `Count` 175 /// Then cast `Offset` and `Count` if their bit width is different 176 /// from `Base` bit width. 177 static Value processCountOrOffset(Location loc, Value value, Type srcType, 178 Type dstType, LLVMTypeConverter &converter, 179 ConversionPatternRewriter &rewriter) { 180 Value broadcasted = 181 optionallyBroadcast(loc, value, srcType, converter, rewriter); 182 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 183 } 184 185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 186 /// offset to LLVM struct. Otherwise, the conversion is not supported. 187 static Optional<Type> 188 convertStructTypeWithOffset(spirv::StructType type, 189 LLVMTypeConverter &converter) { 190 if (type != VulkanLayoutUtils::decorateType(type)) 191 return llvm::None; 192 193 auto elementsVector = llvm::to_vector<8>( 194 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 195 return converter.convertType(elementType); 196 })); 197 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 198 /*isPacked=*/false); 199 } 200 201 /// Converts SPIR-V struct with no offset to packed LLVM struct. 202 static Type convertStructTypePacked(spirv::StructType type, 203 LLVMTypeConverter &converter) { 204 auto elementsVector = llvm::to_vector<8>( 205 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 206 return converter.convertType(elementType); 207 })); 208 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 209 /*isPacked=*/true); 210 } 211 212 /// Creates LLVM dialect constant with the given value. 213 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 214 unsigned value) { 215 return rewriter.create<LLVM::ConstantOp>( 216 loc, IntegerType::get(rewriter.getContext(), 32), 217 rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 218 } 219 220 /// Utility for `spv.Load` and `spv.Store` conversion. 221 static LogicalResult replaceWithLoadOrStore(Operation *op, 222 ConversionPatternRewriter &rewriter, 223 LLVMTypeConverter &typeConverter, 224 unsigned alignment, bool isVolatile, 225 bool isNonTemporal) { 226 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 227 auto dstType = typeConverter.convertType(loadOp.getType()); 228 if (!dstType) 229 return failure(); 230 rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 231 loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal); 232 return success(); 233 } 234 auto storeOp = cast<spirv::StoreOp>(op); 235 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(), 236 storeOp.ptr(), alignment, 237 isVolatile, isNonTemporal); 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Type conversion 243 //===----------------------------------------------------------------------===// 244 245 /// Converts SPIR-V array type to LLVM array. Natural stride (according to 246 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 247 /// when converting ops that manipulate array types. 248 static Optional<Type> convertArrayType(spirv::ArrayType type, 249 TypeConverter &converter) { 250 unsigned stride = type.getArrayStride(); 251 Type elementType = type.getElementType(); 252 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes(); 253 if (stride != 0 && 254 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) 255 return llvm::None; 256 257 auto llvmElementType = converter.convertType(elementType); 258 unsigned numElements = type.getNumElements(); 259 return LLVM::LLVMArrayType::get(llvmElementType, numElements); 260 } 261 262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 263 /// modelled at the moment. 264 static Type convertPointerType(spirv::PointerType type, 265 TypeConverter &converter) { 266 auto pointeeType = converter.convertType(type.getPointeeType()); 267 return LLVM::LLVMPointerType::get(pointeeType); 268 } 269 270 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 271 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 272 /// no modelling of array stride at the moment. 273 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 274 TypeConverter &converter) { 275 if (type.getArrayStride() != 0) 276 return llvm::None; 277 auto elementType = converter.convertType(type.getElementType()); 278 return LLVM::LLVMArrayType::get(elementType, 0); 279 } 280 281 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 282 /// member decorations. Also, only natural offset is supported. 283 static Optional<Type> convertStructType(spirv::StructType type, 284 LLVMTypeConverter &converter) { 285 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 286 type.getMemberDecorations(memberDecorations); 287 if (!memberDecorations.empty()) 288 return llvm::None; 289 if (type.hasOffset()) 290 return convertStructTypeWithOffset(type, converter); 291 return convertStructTypePacked(type, converter); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // Operation conversion 296 //===----------------------------------------------------------------------===// 297 298 namespace { 299 300 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 301 public: 302 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 303 304 LogicalResult 305 matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands, 306 ConversionPatternRewriter &rewriter) const override { 307 auto dstType = typeConverter.convertType(op.component_ptr().getType()); 308 if (!dstType) 309 return failure(); 310 // To use GEP we need to add a first 0 index to go through the pointer. 311 auto indices = llvm::to_vector<4>(op.indices()); 312 Type indexType = op.indices().front().getType(); 313 auto llvmIndexType = typeConverter.convertType(indexType); 314 if (!llvmIndexType) 315 return failure(); 316 Value zero = rewriter.create<LLVM::ConstantOp>( 317 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 318 indices.insert(indices.begin(), zero); 319 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(), 320 indices); 321 return success(); 322 } 323 }; 324 325 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 326 public: 327 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 328 329 LogicalResult 330 matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands, 331 ConversionPatternRewriter &rewriter) const override { 332 auto dstType = typeConverter.convertType(op.pointer().getType()); 333 if (!dstType) 334 return failure(); 335 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable()); 336 return success(); 337 } 338 }; 339 340 class BitFieldInsertPattern 341 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 342 public: 343 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 344 345 LogicalResult 346 matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands, 347 ConversionPatternRewriter &rewriter) const override { 348 auto srcType = op.getType(); 349 auto dstType = typeConverter.convertType(srcType); 350 if (!dstType) 351 return failure(); 352 Location loc = op.getLoc(); 353 354 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 355 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 356 typeConverter, rewriter); 357 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 358 typeConverter, rewriter); 359 360 // Create a mask with bits set outside [Offset, Offset + Count - 1]. 361 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 362 Value maskShiftedByCount = 363 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 364 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 365 maskShiftedByCount, minusOne); 366 Value maskShiftedByCountAndOffset = 367 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 368 Value mask = rewriter.create<LLVM::XOrOp>( 369 loc, dstType, maskShiftedByCountAndOffset, minusOne); 370 371 // Extract unchanged bits from the `Base` that are outside of 372 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 373 Value baseAndMask = 374 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask); 375 Value insertShiftedByOffset = 376 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset); 377 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 378 insertShiftedByOffset); 379 return success(); 380 } 381 }; 382 383 /// Converts SPIR-V ConstantOp with scalar or vector type. 384 class ConstantScalarAndVectorPattern 385 : public SPIRVToLLVMConversion<spirv::ConstantOp> { 386 public: 387 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 388 389 LogicalResult 390 matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands, 391 ConversionPatternRewriter &rewriter) const override { 392 auto srcType = constOp.getType(); 393 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat()) 394 return failure(); 395 396 auto dstType = typeConverter.convertType(srcType); 397 if (!dstType) 398 return failure(); 399 400 // SPIR-V constant can be a signed/unsigned integer, which has to be 401 // casted to signless integer when converting to LLVM dialect. Removing the 402 // sign bit may have unexpected behaviour. However, it is better to handle 403 // it case-by-case, given that the purpose of the conversion is not to 404 // cover all possible corner cases. 405 if (isSignedIntegerOrVector(srcType) || 406 isUnsignedIntegerOrVector(srcType)) { 407 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 408 409 if (srcType.isa<VectorType>()) { 410 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); 411 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 412 constOp, dstType, 413 dstElementsAttr.mapValues( 414 signlessType, [&](const APInt &value) { return value; })); 415 return success(); 416 } 417 auto srcAttr = constOp.value().cast<IntegerAttr>(); 418 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 419 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 420 return success(); 421 } 422 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands, 423 constOp.getAttrs()); 424 return success(); 425 } 426 }; 427 428 class BitFieldSExtractPattern 429 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 430 public: 431 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 432 433 LogicalResult 434 matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands, 435 ConversionPatternRewriter &rewriter) const override { 436 auto srcType = op.getType(); 437 auto dstType = typeConverter.convertType(srcType); 438 if (!dstType) 439 return failure(); 440 Location loc = op.getLoc(); 441 442 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 443 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 444 typeConverter, rewriter); 445 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 446 typeConverter, rewriter); 447 448 // Create a constant that holds the size of the `Base`. 449 IntegerType integerType; 450 if (auto vecType = srcType.dyn_cast<VectorType>()) 451 integerType = vecType.getElementType().cast<IntegerType>(); 452 else 453 integerType = srcType.cast<IntegerType>(); 454 455 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 456 Value size = 457 srcType.isa<VectorType>() 458 ? rewriter.create<LLVM::ConstantOp>( 459 loc, dstType, 460 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize)) 461 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 462 463 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 464 // at Offset + Count - 1 is the most significant bit now. 465 Value countPlusOffset = 466 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 467 Value amountToShiftLeft = 468 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 469 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 470 loc, dstType, op.base(), amountToShiftLeft); 471 472 // Shift the result right, filling the bits with the sign bit. 473 Value amountToShiftRight = 474 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 475 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 476 amountToShiftRight); 477 return success(); 478 } 479 }; 480 481 class BitFieldUExtractPattern 482 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 483 public: 484 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 485 486 LogicalResult 487 matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands, 488 ConversionPatternRewriter &rewriter) const override { 489 auto srcType = op.getType(); 490 auto dstType = typeConverter.convertType(srcType); 491 if (!dstType) 492 return failure(); 493 Location loc = op.getLoc(); 494 495 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 496 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 497 typeConverter, rewriter); 498 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 499 typeConverter, rewriter); 500 501 // Create a mask with bits set at [0, Count - 1]. 502 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 503 Value maskShiftedByCount = 504 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 505 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 506 minusOne); 507 508 // Shift `Base` by `Offset` and apply the mask on it. 509 Value shiftedBase = 510 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset); 511 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 512 return success(); 513 } 514 }; 515 516 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 517 public: 518 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 519 520 LogicalResult 521 matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands, 522 ConversionPatternRewriter &rewriter) const override { 523 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands, 524 branchOp.getTarget()); 525 return success(); 526 } 527 }; 528 529 class BranchConditionalConversionPattern 530 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 531 public: 532 using SPIRVToLLVMConversion< 533 spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 534 535 LogicalResult 536 matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands, 537 ConversionPatternRewriter &rewriter) const override { 538 // If branch weights exist, map them to 32-bit integer vector. 539 ElementsAttr branchWeights = nullptr; 540 if (auto weights = op.branch_weights()) { 541 VectorType weightType = VectorType::get(2, rewriter.getI32Type()); 542 branchWeights = 543 DenseElementsAttr::get(weightType, weights.getValue().getValue()); 544 } 545 546 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 547 op, op.condition(), op.getTrueBlockArguments(), 548 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 549 op.getFalseBlock()); 550 return success(); 551 } 552 }; 553 554 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type 555 /// is an aggregate type (struct or array). Otherwise, converts to 556 /// `llvm.extractelement` that operates on vectors. 557 class CompositeExtractPattern 558 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 559 public: 560 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 561 562 LogicalResult 563 matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands, 564 ConversionPatternRewriter &rewriter) const override { 565 auto dstType = this->typeConverter.convertType(op.getType()); 566 if (!dstType) 567 return failure(); 568 569 Type containerType = op.composite().getType(); 570 if (containerType.isa<VectorType>()) { 571 Location loc = op.getLoc(); 572 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 573 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 574 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 575 op, dstType, op.composite(), index); 576 return success(); 577 } 578 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 579 op, dstType, op.composite(), op.indices()); 580 return success(); 581 } 582 }; 583 584 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type 585 /// is an aggregate type (struct or array). Otherwise, converts to 586 /// `llvm.insertelement` that operates on vectors. 587 class CompositeInsertPattern 588 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 589 public: 590 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 591 592 LogicalResult 593 matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands, 594 ConversionPatternRewriter &rewriter) const override { 595 auto dstType = this->typeConverter.convertType(op.getType()); 596 if (!dstType) 597 return failure(); 598 599 Type containerType = op.composite().getType(); 600 if (containerType.isa<VectorType>()) { 601 Location loc = op.getLoc(); 602 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 603 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 604 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 605 op, dstType, op.composite(), op.object(), index); 606 return success(); 607 } 608 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 609 op, dstType, op.composite(), op.object(), op.indices()); 610 return success(); 611 } 612 }; 613 614 /// Converts SPIR-V operations that have straightforward LLVM equivalent 615 /// into LLVM dialect operations. 616 template <typename SPIRVOp, typename LLVMOp> 617 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 618 public: 619 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 620 621 LogicalResult 622 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 623 ConversionPatternRewriter &rewriter) const override { 624 auto dstType = this->typeConverter.convertType(operation.getType()); 625 if (!dstType) 626 return failure(); 627 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands, 628 operation.getAttrs()); 629 return success(); 630 } 631 }; 632 633 /// Converts `spv.ExecutionMode` into a global struct constant that holds 634 /// execution mode information. 635 class ExecutionModePattern 636 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 637 public: 638 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 639 640 LogicalResult 641 matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands, 642 ConversionPatternRewriter &rewriter) const override { 643 // First, create the global struct's name that would be associated with 644 // this entry point's execution mode. We set it to be: 645 // __spv__{SPIR-V module name}_{function name}_execution_mode_info 646 ModuleOp module = op->getParentOfType<ModuleOp>(); 647 std::string moduleName; 648 if (module.getName().hasValue()) 649 moduleName = "_" + module.getName().getValue().str(); 650 else 651 moduleName = ""; 652 std::string executionModeInfoName = llvm::formatv( 653 "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str()); 654 655 MLIRContext *context = rewriter.getContext(); 656 OpBuilder::InsertionGuard guard(rewriter); 657 rewriter.setInsertionPointToStart(module.getBody()); 658 659 // Create a struct type, corresponding to the C struct below. 660 // struct { 661 // int32_t executionMode; 662 // int32_t values[]; // optional values 663 // }; 664 auto llvmI32Type = IntegerType::get(context, 32); 665 SmallVector<Type, 2> fields; 666 fields.push_back(llvmI32Type); 667 ArrayAttr values = op.values(); 668 if (!values.empty()) { 669 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 670 fields.push_back(arrayType); 671 } 672 auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 673 674 // Create `llvm.mlir.global` with initializer region containing one block. 675 auto global = rewriter.create<LLVM::GlobalOp>( 676 UnknownLoc::get(context), structType, /*isConstant=*/true, 677 LLVM::Linkage::External, executionModeInfoName, Attribute()); 678 Location loc = global.getLoc(); 679 Region ®ion = global.getInitializerRegion(); 680 Block *block = rewriter.createBlock(®ion); 681 682 // Initialize the struct and set the execution mode value. 683 rewriter.setInsertionPoint(block, block->begin()); 684 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 685 IntegerAttr executionModeAttr = op.execution_modeAttr(); 686 Value executionMode = 687 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr); 688 structValue = rewriter.create<LLVM::InsertValueOp>( 689 loc, structType, structValue, executionMode, 690 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, 691 context)); 692 693 // Insert extra operands if they exist into execution mode info struct. 694 for (unsigned i = 0, e = values.size(); i < e; ++i) { 695 auto attr = values.getValue()[i]; 696 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 697 structValue = rewriter.create<LLVM::InsertValueOp>( 698 loc, structType, structValue, entry, 699 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), 700 rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, 701 context)); 702 } 703 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 704 rewriter.eraseOp(op); 705 return success(); 706 } 707 }; 708 709 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global 710 /// returns a pointer, whereas in LLVM dialect the global holds an actual value. 711 /// This difference is handled by `spv.mlir.addressof` and 712 /// `llvm.mlir.addressof`ops that both return a pointer. 713 class GlobalVariablePattern 714 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 715 public: 716 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion; 717 718 LogicalResult 719 matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands, 720 ConversionPatternRewriter &rewriter) const override { 721 // Currently, there is no support of initialization with a constant value in 722 // SPIR-V dialect. Specialization constants are not considered as well. 723 if (op.initializer()) 724 return failure(); 725 726 auto srcType = op.type().cast<spirv::PointerType>(); 727 auto dstType = typeConverter.convertType(srcType.getPointeeType()); 728 if (!dstType) 729 return failure(); 730 731 // Limit conversion to the current invocation only or `StorageBuffer` 732 // required by SPIR-V runner. 733 // This is okay because multiple invocations are not supported yet. 734 auto storageClass = srcType.getStorageClass(); 735 if (storageClass != spirv::StorageClass::Input && 736 storageClass != spirv::StorageClass::Private && 737 storageClass != spirv::StorageClass::Output && 738 storageClass != spirv::StorageClass::StorageBuffer) { 739 return failure(); 740 } 741 742 // LLVM dialect spec: "If the global value is a constant, storing into it is 743 // not allowed.". This corresponds to SPIR-V 'Input' storage class that is 744 // read-only. 745 bool isConstant = storageClass == spirv::StorageClass::Input; 746 // SPIR-V spec: "By default, functions and global variables are private to a 747 // module and cannot be accessed by other modules. However, a module may be 748 // written to export or import functions and global (module scope) 749 // variables.". Therefore, map 'Private' storage class to private linkage, 750 // 'Input' and 'Output' to external linkage. 751 auto linkage = storageClass == spirv::StorageClass::Private 752 ? LLVM::Linkage::Private 753 : LLVM::Linkage::External; 754 rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 755 op, dstType, isConstant, linkage, op.sym_name(), Attribute()); 756 return success(); 757 } 758 }; 759 760 /// Converts SPIR-V cast ops that do not have straightforward LLVM 761 /// equivalent in LLVM dialect. 762 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 763 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 764 public: 765 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 766 767 LogicalResult 768 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 769 ConversionPatternRewriter &rewriter) const override { 770 771 Type fromType = operation.operand().getType(); 772 Type toType = operation.getType(); 773 774 auto dstType = this->typeConverter.convertType(toType); 775 if (!dstType) 776 return failure(); 777 778 if (getBitWidth(fromType) < getBitWidth(toType)) { 779 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType, 780 operands); 781 return success(); 782 } 783 if (getBitWidth(fromType) > getBitWidth(toType)) { 784 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType, 785 operands); 786 return success(); 787 } 788 return failure(); 789 } 790 }; 791 792 class FunctionCallPattern 793 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 794 public: 795 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 796 797 LogicalResult 798 matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands, 799 ConversionPatternRewriter &rewriter) const override { 800 if (callOp.getNumResults() == 0) { 801 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands, 802 callOp.getAttrs()); 803 return success(); 804 } 805 806 // Function returns a single result. 807 auto dstType = typeConverter.convertType(callOp.getType(0)); 808 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands, 809 callOp.getAttrs()); 810 return success(); 811 } 812 }; 813 814 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 815 template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 816 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 817 public: 818 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 819 820 LogicalResult 821 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 822 ConversionPatternRewriter &rewriter) const override { 823 824 auto dstType = this->typeConverter.convertType(operation.getType()); 825 if (!dstType) 826 return failure(); 827 828 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 829 operation, dstType, 830 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), 831 operation.operand1(), operation.operand2(), 832 LLVM::FMFAttr::get({}, operation.getContext())); 833 return success(); 834 } 835 }; 836 837 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 838 template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 839 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 840 public: 841 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 842 843 LogicalResult 844 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 845 ConversionPatternRewriter &rewriter) const override { 846 847 auto dstType = this->typeConverter.convertType(operation.getType()); 848 if (!dstType) 849 return failure(); 850 851 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 852 operation, dstType, 853 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), 854 operation.operand1(), operation.operand2()); 855 return success(); 856 } 857 }; 858 859 class InverseSqrtPattern 860 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> { 861 public: 862 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion; 863 864 LogicalResult 865 matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands, 866 ConversionPatternRewriter &rewriter) const override { 867 auto srcType = op.getType(); 868 auto dstType = typeConverter.convertType(srcType); 869 if (!dstType) 870 return failure(); 871 872 Location loc = op.getLoc(); 873 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 874 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); 875 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 876 return success(); 877 } 878 }; 879 880 /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 881 template <typename SPIRVop> 882 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> { 883 public: 884 using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion; 885 886 LogicalResult 887 matchAndRewrite(SPIRVop op, ArrayRef<Value> operands, 888 ConversionPatternRewriter &rewriter) const override { 889 890 if (!op.memory_access().hasValue()) { 891 replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0, 892 /*isVolatile=*/false, /*isNonTemporal=*/false); 893 return success(); 894 } 895 auto memoryAccess = op.memory_access().getValue(); 896 switch (memoryAccess) { 897 case spirv::MemoryAccess::Aligned: 898 case spirv::MemoryAccess::None: 899 case spirv::MemoryAccess::Nontemporal: 900 case spirv::MemoryAccess::Volatile: { 901 unsigned alignment = 902 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; 903 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 904 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 905 replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment, 906 isVolatile, isNonTemporal); 907 return success(); 908 } 909 default: 910 // There is no support of other memory access attributes. 911 return failure(); 912 } 913 } 914 }; 915 916 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. 917 template <typename SPIRVOp> 918 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 919 public: 920 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 921 922 LogicalResult 923 matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands, 924 ConversionPatternRewriter &rewriter) const override { 925 926 auto srcType = notOp.getType(); 927 auto dstType = this->typeConverter.convertType(srcType); 928 if (!dstType) 929 return failure(); 930 931 Location loc = notOp.getLoc(); 932 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 933 auto mask = srcType.template isa<VectorType>() 934 ? rewriter.create<LLVM::ConstantOp>( 935 loc, dstType, 936 SplatElementsAttr::get( 937 srcType.template cast<VectorType>(), minusOne)) 938 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 939 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 940 notOp.operand(), mask); 941 return success(); 942 } 943 }; 944 945 /// A template pattern that erases the given `SPIRVOp`. 946 template <typename SPIRVOp> 947 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 948 public: 949 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 950 951 LogicalResult 952 matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands, 953 ConversionPatternRewriter &rewriter) const override { 954 rewriter.eraseOp(op); 955 return success(); 956 } 957 }; 958 959 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 960 public: 961 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 962 963 LogicalResult 964 matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands, 965 ConversionPatternRewriter &rewriter) const override { 966 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 967 ArrayRef<Value>()); 968 return success(); 969 } 970 }; 971 972 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 973 public: 974 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 975 976 LogicalResult 977 matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands, 978 ConversionPatternRewriter &rewriter) const override { 979 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 980 operands); 981 return success(); 982 } 983 }; 984 985 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be 986 /// reachable for conversion to succeed. 987 /// The structure of the loop in LLVM dialect will be the following: 988 /// 989 /// +------------------------------------+ 990 /// | <code before spv.loop> | 991 /// | llvm.br ^header | 992 /// +------------------------------------+ 993 /// | 994 /// +----------------+ | 995 /// | | | 996 /// | V V 997 /// | +------------------------------------+ 998 /// | | ^header: | 999 /// | | <header code> | 1000 /// | | llvm.cond_br %cond, ^body, ^exit | 1001 /// | +------------------------------------+ 1002 /// | | 1003 /// | |----------------------+ 1004 /// | | | 1005 /// | V | 1006 /// | +------------------------------------+ | 1007 /// | | ^body: | | 1008 /// | | <body code> | | 1009 /// | | llvm.br ^continue | | 1010 /// | +------------------------------------+ | 1011 /// | | | 1012 /// | V | 1013 /// | +------------------------------------+ | 1014 /// | | ^continue: | | 1015 /// | | <continue code> | | 1016 /// | | llvm.br ^header | | 1017 /// | +------------------------------------+ | 1018 /// | | | 1019 /// +---------------+ +----------------------+ 1020 /// | 1021 /// V 1022 /// +------------------------------------+ 1023 /// | ^exit: | 1024 /// | llvm.br ^remaining | 1025 /// +------------------------------------+ 1026 /// | 1027 /// V 1028 /// +------------------------------------+ 1029 /// | ^remaining: | 1030 /// | <code after spv.loop> | 1031 /// +------------------------------------+ 1032 /// 1033 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1034 public: 1035 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1036 1037 LogicalResult 1038 matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands, 1039 ConversionPatternRewriter &rewriter) const override { 1040 // There is no support of loop control at the moment. 1041 if (loopOp.loop_control() != spirv::LoopControl::None) 1042 return failure(); 1043 1044 Location loc = loopOp.getLoc(); 1045 1046 // Split the current block after `spv.loop`. The remaining ops will be used 1047 // in `endBlock`. 1048 Block *currentBlock = rewriter.getBlock(); 1049 auto position = Block::iterator(loopOp); 1050 Block *endBlock = rewriter.splitBlock(currentBlock, position); 1051 1052 // Remove entry block and create a branch in the current block going to the 1053 // header block. 1054 Block *entryBlock = loopOp.getEntryBlock(); 1055 assert(entryBlock->getOperations().size() == 1); 1056 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1057 if (!brOp) 1058 return failure(); 1059 Block *headerBlock = loopOp.getHeaderBlock(); 1060 rewriter.setInsertionPointToEnd(currentBlock); 1061 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1062 rewriter.eraseBlock(entryBlock); 1063 1064 // Branch from merge block to end block. 1065 Block *mergeBlock = loopOp.getMergeBlock(); 1066 Operation *terminator = mergeBlock->getTerminator(); 1067 ValueRange terminatorOperands = terminator->getOperands(); 1068 rewriter.setInsertionPointToEnd(mergeBlock); 1069 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1070 1071 rewriter.inlineRegionBefore(loopOp.body(), endBlock); 1072 rewriter.replaceOp(loopOp, endBlock->getArguments()); 1073 return success(); 1074 } 1075 }; 1076 1077 /// Converts `spv.selection` with `spv.BranchConditional` in its header block. 1078 /// All blocks within selection should be reachable for conversion to succeed. 1079 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1080 public: 1081 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1082 1083 LogicalResult 1084 matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands, 1085 ConversionPatternRewriter &rewriter) const override { 1086 // There is no support for `Flatten` or `DontFlatten` selection control at 1087 // the moment. This are just compiler hints and can be performed during the 1088 // optimization passes. 1089 if (op.selection_control() != spirv::SelectionControl::None) 1090 return failure(); 1091 1092 // `spv.selection` should have at least two blocks: one selection header 1093 // block and one merge block. If no blocks are present, or control flow 1094 // branches straight to merge block (two blocks are present), the op is 1095 // redundant and it is erased. 1096 if (op.body().getBlocks().size() <= 2) { 1097 rewriter.eraseOp(op); 1098 return success(); 1099 } 1100 1101 Location loc = op.getLoc(); 1102 1103 // Split the current block after `spv.selection`. The remaining ops will be 1104 // used in `continueBlock`. 1105 auto *currentBlock = rewriter.getInsertionBlock(); 1106 rewriter.setInsertionPointAfter(op); 1107 auto position = rewriter.getInsertionPoint(); 1108 auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1109 1110 // Extract conditional branch information from the header block. By SPIR-V 1111 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` 1112 // op. Note that `spv.Switch op` is not supported at the moment in the 1113 // SPIR-V dialect. Remove this block when finished. 1114 auto *headerBlock = op.getHeaderBlock(); 1115 assert(headerBlock->getOperations().size() == 1); 1116 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1117 headerBlock->getOperations().front()); 1118 if (!condBrOp) 1119 return failure(); 1120 rewriter.eraseBlock(headerBlock); 1121 1122 // Branch from merge block to continue block. 1123 auto *mergeBlock = op.getMergeBlock(); 1124 Operation *terminator = mergeBlock->getTerminator(); 1125 ValueRange terminatorOperands = terminator->getOperands(); 1126 rewriter.setInsertionPointToEnd(mergeBlock); 1127 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1128 1129 // Link current block to `true` and `false` blocks within the selection. 1130 Block *trueBlock = condBrOp.getTrueBlock(); 1131 Block *falseBlock = condBrOp.getFalseBlock(); 1132 rewriter.setInsertionPointToEnd(currentBlock); 1133 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock, 1134 condBrOp.trueTargetOperands(), falseBlock, 1135 condBrOp.falseTargetOperands()); 1136 1137 rewriter.inlineRegionBefore(op.body(), continueBlock); 1138 rewriter.replaceOp(op, continueBlock->getArguments()); 1139 return success(); 1140 } 1141 }; 1142 1143 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1144 /// puts a restriction on `Shift` and `Base` to have the same bit width, 1145 /// `Shift` is zero or sign extended to match this specification. Cases when 1146 /// `Shift` bit width > `Base` bit width are considered to be illegal. 1147 template <typename SPIRVOp, typename LLVMOp> 1148 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1149 public: 1150 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1151 1152 LogicalResult 1153 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 1154 ConversionPatternRewriter &rewriter) const override { 1155 1156 auto dstType = this->typeConverter.convertType(operation.getType()); 1157 if (!dstType) 1158 return failure(); 1159 1160 Type op1Type = operation.operand1().getType(); 1161 Type op2Type = operation.operand2().getType(); 1162 1163 if (op1Type == op2Type) { 1164 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, 1165 operands); 1166 return success(); 1167 } 1168 1169 Location loc = operation.getLoc(); 1170 Value extended; 1171 if (isUnsignedIntegerOrVector(op2Type)) { 1172 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType, 1173 operation.operand2()); 1174 } else { 1175 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType, 1176 operation.operand2()); 1177 } 1178 Value result = rewriter.template create<LLVMOp>( 1179 loc, dstType, operation.operand1(), extended); 1180 rewriter.replaceOp(operation, result); 1181 return success(); 1182 } 1183 }; 1184 1185 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> { 1186 public: 1187 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion; 1188 1189 LogicalResult 1190 matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands, 1191 ConversionPatternRewriter &rewriter) const override { 1192 auto dstType = typeConverter.convertType(tanOp.getType()); 1193 if (!dstType) 1194 return failure(); 1195 1196 Location loc = tanOp.getLoc(); 1197 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); 1198 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); 1199 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1200 return success(); 1201 } 1202 }; 1203 1204 /// Convert `spv.Tanh` to 1205 /// 1206 /// exp(2x) - 1 1207 /// ----------- 1208 /// exp(2x) + 1 1209 /// 1210 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> { 1211 public: 1212 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion; 1213 1214 LogicalResult 1215 matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands, 1216 ConversionPatternRewriter &rewriter) const override { 1217 auto srcType = tanhOp.getType(); 1218 auto dstType = typeConverter.convertType(srcType); 1219 if (!dstType) 1220 return failure(); 1221 1222 Location loc = tanhOp.getLoc(); 1223 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1224 Value multiplied = 1225 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); 1226 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1227 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1228 Value numerator = 1229 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1230 Value denominator = 1231 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1232 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1233 denominator); 1234 return success(); 1235 } 1236 }; 1237 1238 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1239 public: 1240 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1241 1242 LogicalResult 1243 matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands, 1244 ConversionPatternRewriter &rewriter) const override { 1245 auto srcType = varOp.getType(); 1246 // Initialization is supported for scalars and vectors only. 1247 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType(); 1248 auto init = varOp.initializer(); 1249 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>()) 1250 return failure(); 1251 1252 auto dstType = typeConverter.convertType(srcType); 1253 if (!dstType) 1254 return failure(); 1255 1256 Location loc = varOp.getLoc(); 1257 Value size = createI32ConstantOf(loc, rewriter, 1); 1258 if (!init) { 1259 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size); 1260 return success(); 1261 } 1262 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size); 1263 rewriter.create<LLVM::StoreOp>(loc, init, allocated); 1264 rewriter.replaceOp(varOp, allocated); 1265 return success(); 1266 } 1267 }; 1268 1269 //===----------------------------------------------------------------------===// 1270 // FuncOp conversion 1271 //===----------------------------------------------------------------------===// 1272 1273 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1274 public: 1275 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1276 1277 LogicalResult 1278 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands, 1279 ConversionPatternRewriter &rewriter) const override { 1280 1281 // Convert function signature. At the moment LLVMType converter is enough 1282 // for currently supported types. 1283 auto funcType = funcOp.getType(); 1284 TypeConverter::SignatureConversion signatureConverter( 1285 funcType.getNumInputs()); 1286 auto llvmType = typeConverter.convertFunctionSignature( 1287 funcOp.getType(), /*isVariadic=*/false, signatureConverter); 1288 if (!llvmType) 1289 return failure(); 1290 1291 // Create a new `LLVMFuncOp` 1292 Location loc = funcOp.getLoc(); 1293 StringRef name = funcOp.getName(); 1294 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1295 1296 // Convert SPIR-V Function Control to equivalent LLVM function attribute 1297 MLIRContext *context = funcOp.getContext(); 1298 switch (funcOp.function_control()) { 1299 #define DISPATCH(functionControl, llvmAttr) \ 1300 case functionControl: \ 1301 newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ 1302 break; 1303 1304 DISPATCH(spirv::FunctionControl::Inline, 1305 StringAttr::get("alwaysinline", context)); 1306 DISPATCH(spirv::FunctionControl::DontInline, 1307 StringAttr::get("noinline", context)); 1308 DISPATCH(spirv::FunctionControl::Pure, 1309 StringAttr::get("readonly", context)); 1310 DISPATCH(spirv::FunctionControl::Const, 1311 StringAttr::get("readnone", context)); 1312 1313 #undef DISPATCH 1314 1315 // Default: if `spirv::FunctionControl::None`, then no attributes are 1316 // needed. 1317 default: 1318 break; 1319 } 1320 1321 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1322 newFuncOp.end()); 1323 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, 1324 &signatureConverter))) { 1325 return failure(); 1326 } 1327 rewriter.eraseOp(funcOp); 1328 return success(); 1329 } 1330 }; 1331 1332 //===----------------------------------------------------------------------===// 1333 // ModuleOp conversion 1334 //===----------------------------------------------------------------------===// 1335 1336 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1337 public: 1338 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1339 1340 LogicalResult 1341 matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands, 1342 ConversionPatternRewriter &rewriter) const override { 1343 1344 auto newModuleOp = 1345 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 1346 rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody()); 1347 1348 // Remove the terminator block that was automatically added by builder 1349 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1350 rewriter.eraseOp(spvModuleOp); 1351 return success(); 1352 } 1353 }; 1354 1355 class ModuleEndConversionPattern 1356 : public SPIRVToLLVMConversion<spirv::ModuleEndOp> { 1357 public: 1358 using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion; 1359 1360 LogicalResult 1361 matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands, 1362 ConversionPatternRewriter &rewriter) const override { 1363 1364 rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp); 1365 return success(); 1366 } 1367 }; 1368 1369 } // namespace 1370 1371 //===----------------------------------------------------------------------===// 1372 // Pattern population 1373 //===----------------------------------------------------------------------===// 1374 1375 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) { 1376 typeConverter.addConversion([&](spirv::ArrayType type) { 1377 return convertArrayType(type, typeConverter); 1378 }); 1379 typeConverter.addConversion([&](spirv::PointerType type) { 1380 return convertPointerType(type, typeConverter); 1381 }); 1382 typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1383 return convertRuntimeArrayType(type, typeConverter); 1384 }); 1385 typeConverter.addConversion([&](spirv::StructType type) { 1386 return convertStructType(type, typeConverter); 1387 }); 1388 } 1389 1390 void mlir::populateSPIRVToLLVMConversionPatterns( 1391 MLIRContext *context, LLVMTypeConverter &typeConverter, 1392 OwningRewritePatternList &patterns) { 1393 patterns.insert< 1394 // Arithmetic ops 1395 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1396 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1397 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1398 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1399 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1400 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1401 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1402 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1403 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1404 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1405 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1406 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1407 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1408 1409 // Bitwise ops 1410 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1411 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1412 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1413 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1414 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1415 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1416 NotPattern<spirv::NotOp>, 1417 1418 // Cast ops 1419 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>, 1420 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1421 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1422 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1423 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1424 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1425 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1426 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1427 1428 // Comparison ops 1429 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1430 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1431 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1432 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1433 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1434 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1435 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1436 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1437 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1438 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1439 FComparePattern<spirv::FUnordGreaterThanEqualOp, 1440 LLVM::FCmpPredicate::uge>, 1441 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1442 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1443 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1444 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1445 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1446 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1447 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1448 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1449 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1450 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1451 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1452 1453 // Constant op 1454 ConstantScalarAndVectorPattern, 1455 1456 // Control Flow ops 1457 BranchConversionPattern, BranchConditionalConversionPattern, 1458 FunctionCallPattern, LoopPattern, SelectionPattern, 1459 ErasePattern<spirv::MergeOp>, 1460 1461 // Entry points and execution mode are handled separately. 1462 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1463 1464 // GLSL extended instruction set ops 1465 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, 1466 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, 1467 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, 1468 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, 1469 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>, 1470 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>, 1471 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>, 1472 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, 1473 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, 1474 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>, 1475 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>, 1476 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, 1477 InverseSqrtPattern, TanPattern, TanhPattern, 1478 1479 // Logical ops 1480 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1481 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1482 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1483 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1484 NotPattern<spirv::LogicalNotOp>, 1485 1486 // Memory ops 1487 AccessChainPattern, AddressOfPattern, GlobalVariablePattern, 1488 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>, 1489 VariablePattern, 1490 1491 // Miscellaneous ops 1492 CompositeExtractPattern, CompositeInsertPattern, 1493 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1494 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 1495 1496 // Shift ops 1497 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1498 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1499 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1500 1501 // Return ops 1502 ReturnPattern, ReturnValuePattern>(context, typeConverter); 1503 } 1504 1505 void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1506 MLIRContext *context, LLVMTypeConverter &typeConverter, 1507 OwningRewritePatternList &patterns) { 1508 patterns.insert<FuncConversionPattern>(context, typeConverter); 1509 } 1510 1511 void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1512 MLIRContext *context, LLVMTypeConverter &typeConverter, 1513 OwningRewritePatternList &patterns) { 1514 patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>( 1515 context, typeConverter); 1516 } 1517 1518 //===----------------------------------------------------------------------===// 1519 // Pre-conversion hooks 1520 //===----------------------------------------------------------------------===// 1521 1522 /// Hook for descriptor set and binding number encoding. 1523 static constexpr StringRef kBinding = "binding"; 1524 static constexpr StringRef kDescriptorSet = "descriptor_set"; 1525 void mlir::encodeBindAttribute(ModuleOp module) { 1526 auto spvModules = module.getOps<spirv::ModuleOp>(); 1527 for (auto spvModule : spvModules) { 1528 spvModule.walk([&](spirv::GlobalVariableOp op) { 1529 IntegerAttr descriptorSet = 1530 op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1531 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1532 // For every global variable in the module, get the ones with descriptor 1533 // set and binding numbers. 1534 if (descriptorSet && binding) { 1535 // Encode these numbers into the variable's symbolic name. If the 1536 // SPIR-V module has a name, add it at the beginning. 1537 auto moduleAndName = spvModule.getName().hasValue() 1538 ? spvModule.getName().getValue().str() + "_" + 1539 op.sym_name().str() 1540 : op.sym_name().str(); 1541 std::string name = 1542 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 1543 std::to_string(descriptorSet.getInt()), 1544 std::to_string(binding.getInt())); 1545 1546 // Replace all symbol uses and set the new symbol name. Finally, remove 1547 // descriptor set and binding attributes. 1548 if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule))) 1549 op.emitError("unable to replace all symbol uses for ") << name; 1550 SymbolTable::setSymbolName(op, name); 1551 op.removeAttr(kDescriptorSet); 1552 op.removeAttr(kBinding); 1553 } 1554 }); 1555 } 1556 } 1557