1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 #define DEBUG_TYPE "spirv-to-llvm-pattern" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Utility functions 34 //===----------------------------------------------------------------------===// 35 36 /// Returns true if the given type is a signed integer or vector type. 37 static bool isSignedIntegerOrVector(Type type) { 38 if (type.isSignedInteger()) 39 return true; 40 if (auto vecType = type.dyn_cast<VectorType>()) 41 return vecType.getElementType().isSignedInteger(); 42 return false; 43 } 44 45 /// Returns true if the given type is an unsigned integer or vector type 46 static bool isUnsignedIntegerOrVector(Type type) { 47 if (type.isUnsignedInteger()) 48 return true; 49 if (auto vecType = type.dyn_cast<VectorType>()) 50 return vecType.getElementType().isUnsignedInteger(); 51 return false; 52 } 53 54 /// Returns the bit width of integer, float or vector of float or integer values 55 static unsigned getBitWidth(Type type) { 56 assert((type.isIntOrFloat() || type.isa<VectorType>()) && 57 "bitwidth is not supported for this type"); 58 if (type.isIntOrFloat()) 59 return type.getIntOrFloatBitWidth(); 60 auto vecType = type.dyn_cast<VectorType>(); 61 auto elementType = vecType.getElementType(); 62 assert(elementType.isIntOrFloat() && 63 "only integers and floats have a bitwidth"); 64 return elementType.getIntOrFloatBitWidth(); 65 } 66 67 /// Returns the bit width of LLVMType integer or vector. 68 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { 69 auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>(); 70 return (vectorType ? vectorType.getElementType() : type) 71 .cast<LLVM::LLVMIntegerType>() 72 .getBitWidth(); 73 } 74 75 /// Creates `IntegerAttribute` with all bits set for given type 76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 77 if (auto vecType = type.dyn_cast<VectorType>()) { 78 auto integerType = vecType.getElementType().cast<IntegerType>(); 79 return builder.getIntegerAttr(integerType, -1); 80 } 81 auto integerType = type.cast<IntegerType>(); 82 return builder.getIntegerAttr(integerType, -1); 83 } 84 85 /// Creates `llvm.mlir.constant` with all bits set for the given type. 86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 87 PatternRewriter &rewriter) { 88 if (srcType.isa<VectorType>()) { 89 return rewriter.create<LLVM::ConstantOp>( 90 loc, dstType, 91 SplatElementsAttr::get(srcType.cast<ShapedType>(), 92 minusOneIntegerAttribute(srcType, rewriter))); 93 } 94 return rewriter.create<LLVM::ConstantOp>( 95 loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 96 } 97 98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 99 static Value createFPConstant(Location loc, Type srcType, Type dstType, 100 PatternRewriter &rewriter, double value) { 101 if (auto vecType = srcType.dyn_cast<VectorType>()) { 102 auto floatType = vecType.getElementType().cast<FloatType>(); 103 return rewriter.create<LLVM::ConstantOp>( 104 loc, dstType, 105 SplatElementsAttr::get(vecType, 106 rewriter.getFloatAttr(floatType, value))); 107 } 108 auto floatType = srcType.cast<FloatType>(); 109 return rewriter.create<LLVM::ConstantOp>( 110 loc, dstType, rewriter.getFloatAttr(floatType, value)); 111 } 112 113 /// Utility function for bitfield ops: 114 /// - `BitFieldInsert` 115 /// - `BitFieldSExtract` 116 /// - `BitFieldUExtract` 117 /// Truncates or extends the value. If the bitwidth of the value is the same as 118 /// `dstType` bitwidth, the value remains unchanged. 119 static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, 120 PatternRewriter &rewriter) { 121 auto srcType = value.getType(); 122 auto llvmType = dstType.cast<LLVM::LLVMType>(); 123 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 124 unsigned valueBitWidth = 125 srcType.isa<LLVM::LLVMType>() 126 ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>()) 127 : getBitWidth(srcType); 128 129 if (valueBitWidth < targetBitWidth) 130 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 131 // If the bit widths of `Count` and `Offset` are greater than the bit width 132 // of the target type, they are truncated. Truncation is safe since `Count` 133 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 134 // both values can be expressed in 8 bits. 135 if (valueBitWidth > targetBitWidth) 136 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 137 return value; 138 } 139 140 /// Broadcasts the value to vector with `numElements` number of elements. 141 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 142 LLVMTypeConverter &typeConverter, 143 ConversionPatternRewriter &rewriter) { 144 auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 145 auto llvmVectorType = typeConverter.convertType(vectorType); 146 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 147 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 148 for (unsigned i = 0; i < numElements; ++i) { 149 auto index = rewriter.create<LLVM::ConstantOp>( 150 loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 151 broadcasted = rewriter.create<LLVM::InsertElementOp>( 152 loc, llvmVectorType, broadcasted, toBroadcast, index); 153 } 154 return broadcasted; 155 } 156 157 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 158 static Value optionallyBroadcast(Location loc, Value value, Type srcType, 159 LLVMTypeConverter &typeConverter, 160 ConversionPatternRewriter &rewriter) { 161 if (auto vectorType = srcType.dyn_cast<VectorType>()) { 162 unsigned numElements = vectorType.getNumElements(); 163 return broadcast(loc, value, numElements, typeConverter, rewriter); 164 } 165 return value; 166 } 167 168 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 169 /// `BitFieldUExtract`. 170 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 171 /// a vector type, construct a vector that has: 172 /// - same number of elements as `Base` 173 /// - each element has the type that is the same as the type of `Offset` or 174 /// `Count` 175 /// - each element has the same value as `Offset` or `Count` 176 /// Then cast `Offset` and `Count` if their bit width is different 177 /// from `Base` bit width. 178 static Value processCountOrOffset(Location loc, Value value, Type srcType, 179 Type dstType, LLVMTypeConverter &converter, 180 ConversionPatternRewriter &rewriter) { 181 Value broadcasted = 182 optionallyBroadcast(loc, value, srcType, converter, rewriter); 183 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 184 } 185 186 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 187 /// offset to LLVM struct. Otherwise, the conversion is not supported. 188 static Optional<Type> 189 convertStructTypeWithOffset(spirv::StructType type, 190 LLVMTypeConverter &converter) { 191 if (type != VulkanLayoutUtils::decorateType(type)) 192 return llvm::None; 193 194 auto elementsVector = llvm::to_vector<8>( 195 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 196 return converter.convertType(elementType).cast<LLVM::LLVMType>(); 197 })); 198 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 199 /*isPacked=*/false); 200 } 201 202 /// Converts SPIR-V struct with no offset to packed LLVM struct. 203 static Type convertStructTypePacked(spirv::StructType type, 204 LLVMTypeConverter &converter) { 205 auto elementsVector = llvm::to_vector<8>( 206 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 207 return converter.convertType(elementType).cast<LLVM::LLVMType>(); 208 })); 209 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 210 /*isPacked=*/true); 211 } 212 213 /// Creates LLVM dialect constant with the given value. 214 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 215 unsigned value) { 216 return rewriter.create<LLVM::ConstantOp>( 217 loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 32), 218 rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 219 } 220 221 /// Utility for `spv.Load` and `spv.Store` conversion. 222 static LogicalResult replaceWithLoadOrStore(Operation *op, 223 ConversionPatternRewriter &rewriter, 224 LLVMTypeConverter &typeConverter, 225 unsigned alignment, bool isVolatile, 226 bool isNonTemporal) { 227 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 228 auto dstType = typeConverter.convertType(loadOp.getType()); 229 if (!dstType) 230 return failure(); 231 rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 232 loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal); 233 return success(); 234 } 235 auto storeOp = cast<spirv::StoreOp>(op); 236 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(), 237 storeOp.ptr(), alignment, 238 isVolatile, isNonTemporal); 239 return success(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // Type conversion 244 //===----------------------------------------------------------------------===// 245 246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to 247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 248 /// when converting ops that manipulate array types. 249 static Optional<Type> convertArrayType(spirv::ArrayType type, 250 TypeConverter &converter) { 251 unsigned stride = type.getArrayStride(); 252 Type elementType = type.getElementType(); 253 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes(); 254 if (stride != 0 && 255 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) 256 return llvm::None; 257 258 auto llvmElementType = 259 converter.convertType(elementType).cast<LLVM::LLVMType>(); 260 unsigned numElements = type.getNumElements(); 261 return LLVM::LLVMArrayType::get(llvmElementType, numElements); 262 } 263 264 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 265 /// modelled at the moment. 266 static Type convertPointerType(spirv::PointerType type, 267 TypeConverter &converter) { 268 auto pointeeType = 269 converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>(); 270 return LLVM::LLVMPointerType::get(pointeeType); 271 } 272 273 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 274 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 275 /// no modelling of array stride at the moment. 276 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 277 TypeConverter &converter) { 278 if (type.getArrayStride() != 0) 279 return llvm::None; 280 auto elementType = 281 converter.convertType(type.getElementType()).cast<LLVM::LLVMType>(); 282 return LLVM::LLVMArrayType::get(elementType, 0); 283 } 284 285 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 286 /// member decorations. Also, only natural offset is supported. 287 static Optional<Type> convertStructType(spirv::StructType type, 288 LLVMTypeConverter &converter) { 289 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 290 type.getMemberDecorations(memberDecorations); 291 if (!memberDecorations.empty()) 292 return llvm::None; 293 if (type.hasOffset()) 294 return convertStructTypeWithOffset(type, converter); 295 return convertStructTypePacked(type, converter); 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Operation conversion 300 //===----------------------------------------------------------------------===// 301 302 namespace { 303 304 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 305 public: 306 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 307 308 LogicalResult 309 matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands, 310 ConversionPatternRewriter &rewriter) const override { 311 auto dstType = typeConverter.convertType(op.component_ptr().getType()); 312 if (!dstType) 313 return failure(); 314 // To use GEP we need to add a first 0 index to go through the pointer. 315 auto indices = llvm::to_vector<4>(op.indices()); 316 Type indexType = op.indices().front().getType(); 317 auto llvmIndexType = typeConverter.convertType(indexType); 318 if (!llvmIndexType) 319 return failure(); 320 Value zero = rewriter.create<LLVM::ConstantOp>( 321 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 322 indices.insert(indices.begin(), zero); 323 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(), 324 indices); 325 return success(); 326 } 327 }; 328 329 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 330 public: 331 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 332 333 LogicalResult 334 matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands, 335 ConversionPatternRewriter &rewriter) const override { 336 auto dstType = typeConverter.convertType(op.pointer().getType()); 337 if (!dstType) 338 return failure(); 339 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>( 340 op, dstType.cast<LLVM::LLVMType>(), op.variable()); 341 return success(); 342 } 343 }; 344 345 class BitFieldInsertPattern 346 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 347 public: 348 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 349 350 LogicalResult 351 matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands, 352 ConversionPatternRewriter &rewriter) const override { 353 auto srcType = op.getType(); 354 auto dstType = typeConverter.convertType(srcType); 355 if (!dstType) 356 return failure(); 357 Location loc = op.getLoc(); 358 359 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 360 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 361 typeConverter, rewriter); 362 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 363 typeConverter, rewriter); 364 365 // Create a mask with bits set outside [Offset, Offset + Count - 1]. 366 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 367 Value maskShiftedByCount = 368 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 369 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 370 maskShiftedByCount, minusOne); 371 Value maskShiftedByCountAndOffset = 372 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 373 Value mask = rewriter.create<LLVM::XOrOp>( 374 loc, dstType, maskShiftedByCountAndOffset, minusOne); 375 376 // Extract unchanged bits from the `Base` that are outside of 377 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 378 Value baseAndMask = 379 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask); 380 Value insertShiftedByOffset = 381 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset); 382 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 383 insertShiftedByOffset); 384 return success(); 385 } 386 }; 387 388 /// Converts SPIR-V ConstantOp with scalar or vector type. 389 class ConstantScalarAndVectorPattern 390 : public SPIRVToLLVMConversion<spirv::ConstantOp> { 391 public: 392 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 393 394 LogicalResult 395 matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands, 396 ConversionPatternRewriter &rewriter) const override { 397 auto srcType = constOp.getType(); 398 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat()) 399 return failure(); 400 401 auto dstType = typeConverter.convertType(srcType); 402 if (!dstType) 403 return failure(); 404 405 // SPIR-V constant can be a signed/unsigned integer, which has to be 406 // casted to signless integer when converting to LLVM dialect. Removing the 407 // sign bit may have unexpected behaviour. However, it is better to handle 408 // it case-by-case, given that the purpose of the conversion is not to 409 // cover all possible corner cases. 410 if (isSignedIntegerOrVector(srcType) || 411 isUnsignedIntegerOrVector(srcType)) { 412 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 413 414 if (srcType.isa<VectorType>()) { 415 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); 416 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 417 constOp, dstType, 418 dstElementsAttr.mapValues( 419 signlessType, [&](const APInt &value) { return value; })); 420 return success(); 421 } 422 auto srcAttr = constOp.value().cast<IntegerAttr>(); 423 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 424 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 425 return success(); 426 } 427 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands, 428 constOp.getAttrs()); 429 return success(); 430 } 431 }; 432 433 class BitFieldSExtractPattern 434 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 435 public: 436 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 437 438 LogicalResult 439 matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands, 440 ConversionPatternRewriter &rewriter) const override { 441 auto srcType = op.getType(); 442 auto dstType = typeConverter.convertType(srcType); 443 if (!dstType) 444 return failure(); 445 Location loc = op.getLoc(); 446 447 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 448 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 449 typeConverter, rewriter); 450 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 451 typeConverter, rewriter); 452 453 // Create a constant that holds the size of the `Base`. 454 IntegerType integerType; 455 if (auto vecType = srcType.dyn_cast<VectorType>()) 456 integerType = vecType.getElementType().cast<IntegerType>(); 457 else 458 integerType = srcType.cast<IntegerType>(); 459 460 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 461 Value size = 462 srcType.isa<VectorType>() 463 ? rewriter.create<LLVM::ConstantOp>( 464 loc, dstType, 465 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize)) 466 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 467 468 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 469 // at Offset + Count - 1 is the most significant bit now. 470 Value countPlusOffset = 471 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 472 Value amountToShiftLeft = 473 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 474 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 475 loc, dstType, op.base(), amountToShiftLeft); 476 477 // Shift the result right, filling the bits with the sign bit. 478 Value amountToShiftRight = 479 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 480 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 481 amountToShiftRight); 482 return success(); 483 } 484 }; 485 486 class BitFieldUExtractPattern 487 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 488 public: 489 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 490 491 LogicalResult 492 matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands, 493 ConversionPatternRewriter &rewriter) const override { 494 auto srcType = op.getType(); 495 auto dstType = typeConverter.convertType(srcType); 496 if (!dstType) 497 return failure(); 498 Location loc = op.getLoc(); 499 500 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 501 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 502 typeConverter, rewriter); 503 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 504 typeConverter, rewriter); 505 506 // Create a mask with bits set at [0, Count - 1]. 507 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 508 Value maskShiftedByCount = 509 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 510 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 511 minusOne); 512 513 // Shift `Base` by `Offset` and apply the mask on it. 514 Value shiftedBase = 515 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset); 516 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 517 return success(); 518 } 519 }; 520 521 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 522 public: 523 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 524 525 LogicalResult 526 matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands, 527 ConversionPatternRewriter &rewriter) const override { 528 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands, 529 branchOp.getTarget()); 530 return success(); 531 } 532 }; 533 534 class BranchConditionalConversionPattern 535 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 536 public: 537 using SPIRVToLLVMConversion< 538 spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 539 540 LogicalResult 541 matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands, 542 ConversionPatternRewriter &rewriter) const override { 543 // If branch weights exist, map them to 32-bit integer vector. 544 ElementsAttr branchWeights = nullptr; 545 if (auto weights = op.branch_weights()) { 546 VectorType weightType = VectorType::get(2, rewriter.getI32Type()); 547 branchWeights = 548 DenseElementsAttr::get(weightType, weights.getValue().getValue()); 549 } 550 551 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 552 op, op.condition(), op.getTrueBlockArguments(), 553 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 554 op.getFalseBlock()); 555 return success(); 556 } 557 }; 558 559 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type 560 /// is an aggregate type (struct or array). Otherwise, converts to 561 /// `llvm.extractelement` that operates on vectors. 562 class CompositeExtractPattern 563 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 564 public: 565 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 566 567 LogicalResult 568 matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands, 569 ConversionPatternRewriter &rewriter) const override { 570 auto dstType = this->typeConverter.convertType(op.getType()); 571 if (!dstType) 572 return failure(); 573 574 Type containerType = op.composite().getType(); 575 if (containerType.isa<VectorType>()) { 576 Location loc = op.getLoc(); 577 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 578 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 579 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 580 op, dstType, op.composite(), index); 581 return success(); 582 } 583 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 584 op, dstType, op.composite(), op.indices()); 585 return success(); 586 } 587 }; 588 589 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type 590 /// is an aggregate type (struct or array). Otherwise, converts to 591 /// `llvm.insertelement` that operates on vectors. 592 class CompositeInsertPattern 593 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 594 public: 595 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 596 597 LogicalResult 598 matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands, 599 ConversionPatternRewriter &rewriter) const override { 600 auto dstType = this->typeConverter.convertType(op.getType()); 601 if (!dstType) 602 return failure(); 603 604 Type containerType = op.composite().getType(); 605 if (containerType.isa<VectorType>()) { 606 Location loc = op.getLoc(); 607 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 608 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 609 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 610 op, dstType, op.composite(), op.object(), index); 611 return success(); 612 } 613 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 614 op, dstType, op.composite(), op.object(), op.indices()); 615 return success(); 616 } 617 }; 618 619 /// Converts SPIR-V operations that have straightforward LLVM equivalent 620 /// into LLVM dialect operations. 621 template <typename SPIRVOp, typename LLVMOp> 622 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 623 public: 624 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 625 626 LogicalResult 627 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 628 ConversionPatternRewriter &rewriter) const override { 629 auto dstType = this->typeConverter.convertType(operation.getType()); 630 if (!dstType) 631 return failure(); 632 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands, 633 operation.getAttrs()); 634 return success(); 635 } 636 }; 637 638 /// Converts `spv.ExecutionMode` into a global struct constant that holds 639 /// execution mode information. 640 class ExecutionModePattern 641 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 642 public: 643 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 644 645 LogicalResult 646 matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands, 647 ConversionPatternRewriter &rewriter) const override { 648 // First, create the global struct's name that would be associated with 649 // this entry point's execution mode. We set it to be: 650 // __spv__{SPIR-V module name}_{function name}_execution_mode_info 651 ModuleOp module = op->getParentOfType<ModuleOp>(); 652 std::string moduleName; 653 if (module.getName().hasValue()) 654 moduleName = "_" + module.getName().getValue().str(); 655 else 656 moduleName = ""; 657 std::string executionModeInfoName = llvm::formatv( 658 "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str()); 659 660 MLIRContext *context = rewriter.getContext(); 661 OpBuilder::InsertionGuard guard(rewriter); 662 rewriter.setInsertionPointToStart(module.getBody()); 663 664 // Create a struct type, corresponding to the C struct below. 665 // struct { 666 // int32_t executionMode; 667 // int32_t values[]; // optional values 668 // }; 669 auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32); 670 SmallVector<LLVM::LLVMType, 2> fields; 671 fields.push_back(llvmI32Type); 672 ArrayAttr values = op.values(); 673 if (!values.empty()) { 674 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 675 fields.push_back(arrayType); 676 } 677 auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 678 679 // Create `llvm.mlir.global` with initializer region containing one block. 680 auto global = rewriter.create<LLVM::GlobalOp>( 681 UnknownLoc::get(context), structType, /*isConstant=*/true, 682 LLVM::Linkage::External, executionModeInfoName, Attribute()); 683 Location loc = global.getLoc(); 684 Region ®ion = global.getInitializerRegion(); 685 Block *block = rewriter.createBlock(®ion); 686 687 // Initialize the struct and set the execution mode value. 688 rewriter.setInsertionPoint(block, block->begin()); 689 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 690 IntegerAttr executionModeAttr = op.execution_modeAttr(); 691 Value executionMode = 692 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr); 693 structValue = rewriter.create<LLVM::InsertValueOp>( 694 loc, structType, structValue, executionMode, 695 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, 696 context)); 697 698 // Insert extra operands if they exist into execution mode info struct. 699 for (unsigned i = 0, e = values.size(); i < e; ++i) { 700 auto attr = values.getValue()[i]; 701 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 702 structValue = rewriter.create<LLVM::InsertValueOp>( 703 loc, structType, structValue, entry, 704 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), 705 rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, 706 context)); 707 } 708 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 709 rewriter.eraseOp(op); 710 return success(); 711 } 712 }; 713 714 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global 715 /// returns a pointer, whereas in LLVM dialect the global holds an actual value. 716 /// This difference is handled by `spv.mlir.addressof` and 717 /// `llvm.mlir.addressof`ops that both return a pointer. 718 class GlobalVariablePattern 719 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 720 public: 721 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion; 722 723 LogicalResult 724 matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands, 725 ConversionPatternRewriter &rewriter) const override { 726 // Currently, there is no support of initialization with a constant value in 727 // SPIR-V dialect. Specialization constants are not considered as well. 728 if (op.initializer()) 729 return failure(); 730 731 auto srcType = op.type().cast<spirv::PointerType>(); 732 auto dstType = typeConverter.convertType(srcType.getPointeeType()); 733 if (!dstType) 734 return failure(); 735 736 // Limit conversion to the current invocation only or `StorageBuffer` 737 // required by SPIR-V runner. 738 // This is okay because multiple invocations are not supported yet. 739 auto storageClass = srcType.getStorageClass(); 740 if (storageClass != spirv::StorageClass::Input && 741 storageClass != spirv::StorageClass::Private && 742 storageClass != spirv::StorageClass::Output && 743 storageClass != spirv::StorageClass::StorageBuffer) { 744 return failure(); 745 } 746 747 // LLVM dialect spec: "If the global value is a constant, storing into it is 748 // not allowed.". This corresponds to SPIR-V 'Input' storage class that is 749 // read-only. 750 bool isConstant = storageClass == spirv::StorageClass::Input; 751 // SPIR-V spec: "By default, functions and global variables are private to a 752 // module and cannot be accessed by other modules. However, a module may be 753 // written to export or import functions and global (module scope) 754 // variables.". Therefore, map 'Private' storage class to private linkage, 755 // 'Input' and 'Output' to external linkage. 756 auto linkage = storageClass == spirv::StorageClass::Private 757 ? LLVM::Linkage::Private 758 : LLVM::Linkage::External; 759 rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 760 op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(), 761 Attribute()); 762 return success(); 763 } 764 }; 765 766 /// Converts SPIR-V cast ops that do not have straightforward LLVM 767 /// equivalent in LLVM dialect. 768 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 769 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 770 public: 771 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 772 773 LogicalResult 774 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 775 ConversionPatternRewriter &rewriter) const override { 776 777 Type fromType = operation.operand().getType(); 778 Type toType = operation.getType(); 779 780 auto dstType = this->typeConverter.convertType(toType); 781 if (!dstType) 782 return failure(); 783 784 if (getBitWidth(fromType) < getBitWidth(toType)) { 785 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType, 786 operands); 787 return success(); 788 } 789 if (getBitWidth(fromType) > getBitWidth(toType)) { 790 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType, 791 operands); 792 return success(); 793 } 794 return failure(); 795 } 796 }; 797 798 class FunctionCallPattern 799 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 800 public: 801 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 802 803 LogicalResult 804 matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands, 805 ConversionPatternRewriter &rewriter) const override { 806 if (callOp.getNumResults() == 0) { 807 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands, 808 callOp.getAttrs()); 809 return success(); 810 } 811 812 // Function returns a single result. 813 auto dstType = typeConverter.convertType(callOp.getType(0)); 814 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands, 815 callOp.getAttrs()); 816 return success(); 817 } 818 }; 819 820 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 821 template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 822 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 823 public: 824 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 825 826 LogicalResult 827 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 828 ConversionPatternRewriter &rewriter) const override { 829 830 auto dstType = this->typeConverter.convertType(operation.getType()); 831 if (!dstType) 832 return failure(); 833 834 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 835 operation, dstType, 836 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), 837 operation.operand1(), operation.operand2()); 838 return success(); 839 } 840 }; 841 842 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 843 template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 844 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 845 public: 846 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 847 848 LogicalResult 849 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 850 ConversionPatternRewriter &rewriter) const override { 851 852 auto dstType = this->typeConverter.convertType(operation.getType()); 853 if (!dstType) 854 return failure(); 855 856 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 857 operation, dstType, 858 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), 859 operation.operand1(), operation.operand2()); 860 return success(); 861 } 862 }; 863 864 class InverseSqrtPattern 865 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> { 866 public: 867 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion; 868 869 LogicalResult 870 matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands, 871 ConversionPatternRewriter &rewriter) const override { 872 auto srcType = op.getType(); 873 auto dstType = typeConverter.convertType(srcType); 874 if (!dstType) 875 return failure(); 876 877 Location loc = op.getLoc(); 878 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 879 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); 880 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 881 return success(); 882 } 883 }; 884 885 /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 886 template <typename SPIRVop> 887 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> { 888 public: 889 using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion; 890 891 LogicalResult 892 matchAndRewrite(SPIRVop op, ArrayRef<Value> operands, 893 ConversionPatternRewriter &rewriter) const override { 894 895 if (!op.memory_access().hasValue()) { 896 replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0, 897 /*isVolatile=*/false, /*isNonTemporal=*/false); 898 return success(); 899 } 900 auto memoryAccess = op.memory_access().getValue(); 901 switch (memoryAccess) { 902 case spirv::MemoryAccess::Aligned: 903 case spirv::MemoryAccess::None: 904 case spirv::MemoryAccess::Nontemporal: 905 case spirv::MemoryAccess::Volatile: { 906 unsigned alignment = 907 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; 908 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 909 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 910 replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment, 911 isVolatile, isNonTemporal); 912 return success(); 913 } 914 default: 915 // There is no support of other memory access attributes. 916 return failure(); 917 } 918 } 919 }; 920 921 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. 922 template <typename SPIRVOp> 923 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 924 public: 925 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 926 927 LogicalResult 928 matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands, 929 ConversionPatternRewriter &rewriter) const override { 930 931 auto srcType = notOp.getType(); 932 auto dstType = this->typeConverter.convertType(srcType); 933 if (!dstType) 934 return failure(); 935 936 Location loc = notOp.getLoc(); 937 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 938 auto mask = srcType.template isa<VectorType>() 939 ? rewriter.create<LLVM::ConstantOp>( 940 loc, dstType, 941 SplatElementsAttr::get( 942 srcType.template cast<VectorType>(), minusOne)) 943 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 944 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 945 notOp.operand(), mask); 946 return success(); 947 } 948 }; 949 950 /// A template pattern that erases the given `SPIRVOp`. 951 template <typename SPIRVOp> 952 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 953 public: 954 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 955 956 LogicalResult 957 matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands, 958 ConversionPatternRewriter &rewriter) const override { 959 rewriter.eraseOp(op); 960 return success(); 961 } 962 }; 963 964 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 965 public: 966 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 967 968 LogicalResult 969 matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands, 970 ConversionPatternRewriter &rewriter) const override { 971 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 972 ArrayRef<Value>()); 973 return success(); 974 } 975 }; 976 977 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 978 public: 979 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 980 981 LogicalResult 982 matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands, 983 ConversionPatternRewriter &rewriter) const override { 984 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 985 operands); 986 return success(); 987 } 988 }; 989 990 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be 991 /// reachable for conversion to succeed. 992 /// The structure of the loop in LLVM dialect will be the following: 993 /// 994 /// +------------------------------------+ 995 /// | <code before spv.loop> | 996 /// | llvm.br ^header | 997 /// +------------------------------------+ 998 /// | 999 /// +----------------+ | 1000 /// | | | 1001 /// | V V 1002 /// | +------------------------------------+ 1003 /// | | ^header: | 1004 /// | | <header code> | 1005 /// | | llvm.cond_br %cond, ^body, ^exit | 1006 /// | +------------------------------------+ 1007 /// | | 1008 /// | |----------------------+ 1009 /// | | | 1010 /// | V | 1011 /// | +------------------------------------+ | 1012 /// | | ^body: | | 1013 /// | | <body code> | | 1014 /// | | llvm.br ^continue | | 1015 /// | +------------------------------------+ | 1016 /// | | | 1017 /// | V | 1018 /// | +------------------------------------+ | 1019 /// | | ^continue: | | 1020 /// | | <continue code> | | 1021 /// | | llvm.br ^header | | 1022 /// | +------------------------------------+ | 1023 /// | | | 1024 /// +---------------+ +----------------------+ 1025 /// | 1026 /// V 1027 /// +------------------------------------+ 1028 /// | ^exit: | 1029 /// | llvm.br ^remaining | 1030 /// +------------------------------------+ 1031 /// | 1032 /// V 1033 /// +------------------------------------+ 1034 /// | ^remaining: | 1035 /// | <code after spv.loop> | 1036 /// +------------------------------------+ 1037 /// 1038 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1039 public: 1040 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1041 1042 LogicalResult 1043 matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands, 1044 ConversionPatternRewriter &rewriter) const override { 1045 // There is no support of loop control at the moment. 1046 if (loopOp.loop_control() != spirv::LoopControl::None) 1047 return failure(); 1048 1049 Location loc = loopOp.getLoc(); 1050 1051 // Split the current block after `spv.loop`. The remaining ops will be used 1052 // in `endBlock`. 1053 Block *currentBlock = rewriter.getBlock(); 1054 auto position = Block::iterator(loopOp); 1055 Block *endBlock = rewriter.splitBlock(currentBlock, position); 1056 1057 // Remove entry block and create a branch in the current block going to the 1058 // header block. 1059 Block *entryBlock = loopOp.getEntryBlock(); 1060 assert(entryBlock->getOperations().size() == 1); 1061 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1062 if (!brOp) 1063 return failure(); 1064 Block *headerBlock = loopOp.getHeaderBlock(); 1065 rewriter.setInsertionPointToEnd(currentBlock); 1066 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1067 rewriter.eraseBlock(entryBlock); 1068 1069 // Branch from merge block to end block. 1070 Block *mergeBlock = loopOp.getMergeBlock(); 1071 Operation *terminator = mergeBlock->getTerminator(); 1072 ValueRange terminatorOperands = terminator->getOperands(); 1073 rewriter.setInsertionPointToEnd(mergeBlock); 1074 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1075 1076 rewriter.inlineRegionBefore(loopOp.body(), endBlock); 1077 rewriter.replaceOp(loopOp, endBlock->getArguments()); 1078 return success(); 1079 } 1080 }; 1081 1082 /// Converts `spv.selection` with `spv.BranchConditional` in its header block. 1083 /// All blocks within selection should be reachable for conversion to succeed. 1084 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1085 public: 1086 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1087 1088 LogicalResult 1089 matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands, 1090 ConversionPatternRewriter &rewriter) const override { 1091 // There is no support for `Flatten` or `DontFlatten` selection control at 1092 // the moment. This are just compiler hints and can be performed during the 1093 // optimization passes. 1094 if (op.selection_control() != spirv::SelectionControl::None) 1095 return failure(); 1096 1097 // `spv.selection` should have at least two blocks: one selection header 1098 // block and one merge block. If no blocks are present, or control flow 1099 // branches straight to merge block (two blocks are present), the op is 1100 // redundant and it is erased. 1101 if (op.body().getBlocks().size() <= 2) { 1102 rewriter.eraseOp(op); 1103 return success(); 1104 } 1105 1106 Location loc = op.getLoc(); 1107 1108 // Split the current block after `spv.selection`. The remaining ops will be 1109 // used in `continueBlock`. 1110 auto *currentBlock = rewriter.getInsertionBlock(); 1111 rewriter.setInsertionPointAfter(op); 1112 auto position = rewriter.getInsertionPoint(); 1113 auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1114 1115 // Extract conditional branch information from the header block. By SPIR-V 1116 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` 1117 // op. Note that `spv.Switch op` is not supported at the moment in the 1118 // SPIR-V dialect. Remove this block when finished. 1119 auto *headerBlock = op.getHeaderBlock(); 1120 assert(headerBlock->getOperations().size() == 1); 1121 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1122 headerBlock->getOperations().front()); 1123 if (!condBrOp) 1124 return failure(); 1125 rewriter.eraseBlock(headerBlock); 1126 1127 // Branch from merge block to continue block. 1128 auto *mergeBlock = op.getMergeBlock(); 1129 Operation *terminator = mergeBlock->getTerminator(); 1130 ValueRange terminatorOperands = terminator->getOperands(); 1131 rewriter.setInsertionPointToEnd(mergeBlock); 1132 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1133 1134 // Link current block to `true` and `false` blocks within the selection. 1135 Block *trueBlock = condBrOp.getTrueBlock(); 1136 Block *falseBlock = condBrOp.getFalseBlock(); 1137 rewriter.setInsertionPointToEnd(currentBlock); 1138 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock, 1139 condBrOp.trueTargetOperands(), falseBlock, 1140 condBrOp.falseTargetOperands()); 1141 1142 rewriter.inlineRegionBefore(op.body(), continueBlock); 1143 rewriter.replaceOp(op, continueBlock->getArguments()); 1144 return success(); 1145 } 1146 }; 1147 1148 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1149 /// puts a restriction on `Shift` and `Base` to have the same bit width, 1150 /// `Shift` is zero or sign extended to match this specification. Cases when 1151 /// `Shift` bit width > `Base` bit width are considered to be illegal. 1152 template <typename SPIRVOp, typename LLVMOp> 1153 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1154 public: 1155 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1156 1157 LogicalResult 1158 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, 1159 ConversionPatternRewriter &rewriter) const override { 1160 1161 auto dstType = this->typeConverter.convertType(operation.getType()); 1162 if (!dstType) 1163 return failure(); 1164 1165 Type op1Type = operation.operand1().getType(); 1166 Type op2Type = operation.operand2().getType(); 1167 1168 if (op1Type == op2Type) { 1169 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, 1170 operands); 1171 return success(); 1172 } 1173 1174 Location loc = operation.getLoc(); 1175 Value extended; 1176 if (isUnsignedIntegerOrVector(op2Type)) { 1177 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType, 1178 operation.operand2()); 1179 } else { 1180 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType, 1181 operation.operand2()); 1182 } 1183 Value result = rewriter.template create<LLVMOp>( 1184 loc, dstType, operation.operand1(), extended); 1185 rewriter.replaceOp(operation, result); 1186 return success(); 1187 } 1188 }; 1189 1190 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> { 1191 public: 1192 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion; 1193 1194 LogicalResult 1195 matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands, 1196 ConversionPatternRewriter &rewriter) const override { 1197 auto dstType = typeConverter.convertType(tanOp.getType()); 1198 if (!dstType) 1199 return failure(); 1200 1201 Location loc = tanOp.getLoc(); 1202 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); 1203 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); 1204 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1205 return success(); 1206 } 1207 }; 1208 1209 /// Convert `spv.Tanh` to 1210 /// 1211 /// exp(2x) - 1 1212 /// ----------- 1213 /// exp(2x) + 1 1214 /// 1215 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> { 1216 public: 1217 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion; 1218 1219 LogicalResult 1220 matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands, 1221 ConversionPatternRewriter &rewriter) const override { 1222 auto srcType = tanhOp.getType(); 1223 auto dstType = typeConverter.convertType(srcType); 1224 if (!dstType) 1225 return failure(); 1226 1227 Location loc = tanhOp.getLoc(); 1228 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1229 Value multiplied = 1230 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); 1231 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1232 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1233 Value numerator = 1234 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1235 Value denominator = 1236 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1237 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1238 denominator); 1239 return success(); 1240 } 1241 }; 1242 1243 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1244 public: 1245 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1246 1247 LogicalResult 1248 matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands, 1249 ConversionPatternRewriter &rewriter) const override { 1250 auto srcType = varOp.getType(); 1251 // Initialization is supported for scalars and vectors only. 1252 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType(); 1253 auto init = varOp.initializer(); 1254 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>()) 1255 return failure(); 1256 1257 auto dstType = typeConverter.convertType(srcType); 1258 if (!dstType) 1259 return failure(); 1260 1261 Location loc = varOp.getLoc(); 1262 Value size = createI32ConstantOf(loc, rewriter, 1); 1263 if (!init) { 1264 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size); 1265 return success(); 1266 } 1267 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size); 1268 rewriter.create<LLVM::StoreOp>(loc, init, allocated); 1269 rewriter.replaceOp(varOp, allocated); 1270 return success(); 1271 } 1272 }; 1273 1274 //===----------------------------------------------------------------------===// 1275 // FuncOp conversion 1276 //===----------------------------------------------------------------------===// 1277 1278 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1279 public: 1280 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1281 1282 LogicalResult 1283 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands, 1284 ConversionPatternRewriter &rewriter) const override { 1285 1286 // Convert function signature. At the moment LLVMType converter is enough 1287 // for currently supported types. 1288 auto funcType = funcOp.getType(); 1289 TypeConverter::SignatureConversion signatureConverter( 1290 funcType.getNumInputs()); 1291 auto llvmType = typeConverter.convertFunctionSignature( 1292 funcOp.getType(), /*isVariadic=*/false, signatureConverter); 1293 if (!llvmType) 1294 return failure(); 1295 1296 // Create a new `LLVMFuncOp` 1297 Location loc = funcOp.getLoc(); 1298 StringRef name = funcOp.getName(); 1299 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1300 1301 // Convert SPIR-V Function Control to equivalent LLVM function attribute 1302 MLIRContext *context = funcOp.getContext(); 1303 switch (funcOp.function_control()) { 1304 #define DISPATCH(functionControl, llvmAttr) \ 1305 case functionControl: \ 1306 newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ 1307 break; 1308 1309 DISPATCH(spirv::FunctionControl::Inline, 1310 StringAttr::get("alwaysinline", context)); 1311 DISPATCH(spirv::FunctionControl::DontInline, 1312 StringAttr::get("noinline", context)); 1313 DISPATCH(spirv::FunctionControl::Pure, 1314 StringAttr::get("readonly", context)); 1315 DISPATCH(spirv::FunctionControl::Const, 1316 StringAttr::get("readnone", context)); 1317 1318 #undef DISPATCH 1319 1320 // Default: if `spirv::FunctionControl::None`, then no attributes are 1321 // needed. 1322 default: 1323 break; 1324 } 1325 1326 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1327 newFuncOp.end()); 1328 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, 1329 &signatureConverter))) { 1330 return failure(); 1331 } 1332 rewriter.eraseOp(funcOp); 1333 return success(); 1334 } 1335 }; 1336 1337 //===----------------------------------------------------------------------===// 1338 // ModuleOp conversion 1339 //===----------------------------------------------------------------------===// 1340 1341 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1342 public: 1343 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1344 1345 LogicalResult 1346 matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands, 1347 ConversionPatternRewriter &rewriter) const override { 1348 1349 auto newModuleOp = 1350 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 1351 rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody()); 1352 1353 // Remove the terminator block that was automatically added by builder 1354 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1355 rewriter.eraseOp(spvModuleOp); 1356 return success(); 1357 } 1358 }; 1359 1360 class ModuleEndConversionPattern 1361 : public SPIRVToLLVMConversion<spirv::ModuleEndOp> { 1362 public: 1363 using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion; 1364 1365 LogicalResult 1366 matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands, 1367 ConversionPatternRewriter &rewriter) const override { 1368 1369 rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp); 1370 return success(); 1371 } 1372 }; 1373 1374 } // namespace 1375 1376 //===----------------------------------------------------------------------===// 1377 // Pattern population 1378 //===----------------------------------------------------------------------===// 1379 1380 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) { 1381 typeConverter.addConversion([&](spirv::ArrayType type) { 1382 return convertArrayType(type, typeConverter); 1383 }); 1384 typeConverter.addConversion([&](spirv::PointerType type) { 1385 return convertPointerType(type, typeConverter); 1386 }); 1387 typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1388 return convertRuntimeArrayType(type, typeConverter); 1389 }); 1390 typeConverter.addConversion([&](spirv::StructType type) { 1391 return convertStructType(type, typeConverter); 1392 }); 1393 } 1394 1395 void mlir::populateSPIRVToLLVMConversionPatterns( 1396 MLIRContext *context, LLVMTypeConverter &typeConverter, 1397 OwningRewritePatternList &patterns) { 1398 patterns.insert< 1399 // Arithmetic ops 1400 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1401 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1402 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1403 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1404 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1405 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1406 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1407 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1408 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1409 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1410 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1411 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1412 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1413 1414 // Bitwise ops 1415 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1416 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1417 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1418 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1419 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1420 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1421 NotPattern<spirv::NotOp>, 1422 1423 // Cast ops 1424 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>, 1425 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1426 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1427 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1428 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1429 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1430 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1431 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1432 1433 // Comparison ops 1434 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1435 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1436 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1437 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1438 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1439 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1440 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1441 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1442 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1443 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1444 FComparePattern<spirv::FUnordGreaterThanEqualOp, 1445 LLVM::FCmpPredicate::uge>, 1446 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1447 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1448 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1449 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1450 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1451 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1452 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1453 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1454 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1455 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1456 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1457 1458 // Constant op 1459 ConstantScalarAndVectorPattern, 1460 1461 // Control Flow ops 1462 BranchConversionPattern, BranchConditionalConversionPattern, 1463 FunctionCallPattern, LoopPattern, SelectionPattern, 1464 ErasePattern<spirv::MergeOp>, 1465 1466 // Entry points and execution mode are handled separately. 1467 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1468 1469 // GLSL extended instruction set ops 1470 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, 1471 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, 1472 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, 1473 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, 1474 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>, 1475 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>, 1476 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>, 1477 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, 1478 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, 1479 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>, 1480 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>, 1481 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, 1482 InverseSqrtPattern, TanPattern, TanhPattern, 1483 1484 // Logical ops 1485 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1486 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1487 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1488 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1489 NotPattern<spirv::LogicalNotOp>, 1490 1491 // Memory ops 1492 AccessChainPattern, AddressOfPattern, GlobalVariablePattern, 1493 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>, 1494 VariablePattern, 1495 1496 // Miscellaneous ops 1497 CompositeExtractPattern, CompositeInsertPattern, 1498 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1499 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 1500 1501 // Shift ops 1502 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1503 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1504 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1505 1506 // Return ops 1507 ReturnPattern, ReturnValuePattern>(context, typeConverter); 1508 } 1509 1510 void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1511 MLIRContext *context, LLVMTypeConverter &typeConverter, 1512 OwningRewritePatternList &patterns) { 1513 patterns.insert<FuncConversionPattern>(context, typeConverter); 1514 } 1515 1516 void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1517 MLIRContext *context, LLVMTypeConverter &typeConverter, 1518 OwningRewritePatternList &patterns) { 1519 patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>( 1520 context, typeConverter); 1521 } 1522 1523 //===----------------------------------------------------------------------===// 1524 // Pre-conversion hooks 1525 //===----------------------------------------------------------------------===// 1526 1527 /// Hook for descriptor set and binding number encoding. 1528 static constexpr StringRef kBinding = "binding"; 1529 static constexpr StringRef kDescriptorSet = "descriptor_set"; 1530 void mlir::encodeBindAttribute(ModuleOp module) { 1531 auto spvModules = module.getOps<spirv::ModuleOp>(); 1532 for (auto spvModule : spvModules) { 1533 spvModule.walk([&](spirv::GlobalVariableOp op) { 1534 IntegerAttr descriptorSet = 1535 op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1536 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1537 // For every global variable in the module, get the ones with descriptor 1538 // set and binding numbers. 1539 if (descriptorSet && binding) { 1540 // Encode these numbers into the variable's symbolic name. If the 1541 // SPIR-V module has a name, add it at the beginning. 1542 auto moduleAndName = spvModule.getName().hasValue() 1543 ? spvModule.getName().getValue().str() + "_" + 1544 op.sym_name().str() 1545 : op.sym_name().str(); 1546 std::string name = 1547 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 1548 std::to_string(descriptorSet.getInt()), 1549 std::to_string(binding.getInt())); 1550 1551 // Replace all symbol uses and set the new symbol name. Finally, remove 1552 // descriptor set and binding attributes. 1553 if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule))) 1554 op.emitError("unable to replace all symbol uses for ") << name; 1555 SymbolTable::setSymbolName(op, name); 1556 op.removeAttr(kDescriptorSet); 1557 op.removeAttr(kBinding); 1558 } 1559 }); 1560 } 1561 } 1562