1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 20 #include "mlir/IR/BuiltinOps.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Support/LogicalResult.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 #define DEBUG_TYPE "spirv-to-llvm-pattern" 28 29 using namespace mlir; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// Returns true if the given type is a signed integer or vector type. 36 static bool isSignedIntegerOrVector(Type type) { 37 if (type.isSignedInteger()) 38 return true; 39 if (auto vecType = type.dyn_cast<VectorType>()) 40 return vecType.getElementType().isSignedInteger(); 41 return false; 42 } 43 44 /// Returns true if the given type is an unsigned integer or vector type 45 static bool isUnsignedIntegerOrVector(Type type) { 46 if (type.isUnsignedInteger()) 47 return true; 48 if (auto vecType = type.dyn_cast<VectorType>()) 49 return vecType.getElementType().isUnsignedInteger(); 50 return false; 51 } 52 53 /// Returns the bit width of integer, float or vector of float or integer values 54 static unsigned getBitWidth(Type type) { 55 assert((type.isIntOrFloat() || type.isa<VectorType>()) && 56 "bitwidth is not supported for this type"); 57 if (type.isIntOrFloat()) 58 return type.getIntOrFloatBitWidth(); 59 auto vecType = type.dyn_cast<VectorType>(); 60 auto elementType = vecType.getElementType(); 61 assert(elementType.isIntOrFloat() && 62 "only integers and floats have a bitwidth"); 63 return elementType.getIntOrFloatBitWidth(); 64 } 65 66 /// Returns the bit width of LLVMType integer or vector. 67 static unsigned getLLVMTypeBitWidth(Type type) { 68 return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type) 69 : type) 70 .cast<IntegerType>() 71 .getWidth(); 72 } 73 74 /// Creates `IntegerAttribute` with all bits set for given type 75 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 76 if (auto vecType = type.dyn_cast<VectorType>()) { 77 auto integerType = vecType.getElementType().cast<IntegerType>(); 78 return builder.getIntegerAttr(integerType, -1); 79 } 80 auto integerType = type.cast<IntegerType>(); 81 return builder.getIntegerAttr(integerType, -1); 82 } 83 84 /// Creates `llvm.mlir.constant` with all bits set for the given type. 85 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 86 PatternRewriter &rewriter) { 87 if (srcType.isa<VectorType>()) { 88 return rewriter.create<LLVM::ConstantOp>( 89 loc, dstType, 90 SplatElementsAttr::get(srcType.cast<ShapedType>(), 91 minusOneIntegerAttribute(srcType, rewriter))); 92 } 93 return rewriter.create<LLVM::ConstantOp>( 94 loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 95 } 96 97 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 98 static Value createFPConstant(Location loc, Type srcType, Type dstType, 99 PatternRewriter &rewriter, double value) { 100 if (auto vecType = srcType.dyn_cast<VectorType>()) { 101 auto floatType = vecType.getElementType().cast<FloatType>(); 102 return rewriter.create<LLVM::ConstantOp>( 103 loc, dstType, 104 SplatElementsAttr::get(vecType, 105 rewriter.getFloatAttr(floatType, value))); 106 } 107 auto floatType = srcType.cast<FloatType>(); 108 return rewriter.create<LLVM::ConstantOp>( 109 loc, dstType, rewriter.getFloatAttr(floatType, value)); 110 } 111 112 /// Utility function for bitfield ops: 113 /// - `BitFieldInsert` 114 /// - `BitFieldSExtract` 115 /// - `BitFieldUExtract` 116 /// Truncates or extends the value. If the bitwidth of the value is the same as 117 /// `llvmType` bitwidth, the value remains unchanged. 118 static Value optionallyTruncateOrExtend(Location loc, Value value, 119 Type llvmType, 120 PatternRewriter &rewriter) { 121 auto srcType = value.getType(); 122 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 123 unsigned valueBitWidth = LLVM::isCompatibleType(srcType) 124 ? getLLVMTypeBitWidth(srcType) 125 : getBitWidth(srcType); 126 127 if (valueBitWidth < targetBitWidth) 128 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 129 // If the bit widths of `Count` and `Offset` are greater than the bit width 130 // of the target type, they are truncated. Truncation is safe since `Count` 131 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 132 // both values can be expressed in 8 bits. 133 if (valueBitWidth > targetBitWidth) 134 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 135 return value; 136 } 137 138 /// Broadcasts the value to vector with `numElements` number of elements. 139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 140 LLVMTypeConverter &typeConverter, 141 ConversionPatternRewriter &rewriter) { 142 auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 143 auto llvmVectorType = typeConverter.convertType(vectorType); 144 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 145 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 146 for (unsigned i = 0; i < numElements; ++i) { 147 auto index = rewriter.create<LLVM::ConstantOp>( 148 loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 149 broadcasted = rewriter.create<LLVM::InsertElementOp>( 150 loc, llvmVectorType, broadcasted, toBroadcast, index); 151 } 152 return broadcasted; 153 } 154 155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 156 static Value optionallyBroadcast(Location loc, Value value, Type srcType, 157 LLVMTypeConverter &typeConverter, 158 ConversionPatternRewriter &rewriter) { 159 if (auto vectorType = srcType.dyn_cast<VectorType>()) { 160 unsigned numElements = vectorType.getNumElements(); 161 return broadcast(loc, value, numElements, typeConverter, rewriter); 162 } 163 return value; 164 } 165 166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 167 /// `BitFieldUExtract`. 168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 169 /// a vector type, construct a vector that has: 170 /// - same number of elements as `Base` 171 /// - each element has the type that is the same as the type of `Offset` or 172 /// `Count` 173 /// - each element has the same value as `Offset` or `Count` 174 /// Then cast `Offset` and `Count` if their bit width is different 175 /// from `Base` bit width. 176 static Value processCountOrOffset(Location loc, Value value, Type srcType, 177 Type dstType, LLVMTypeConverter &converter, 178 ConversionPatternRewriter &rewriter) { 179 Value broadcasted = 180 optionallyBroadcast(loc, value, srcType, converter, rewriter); 181 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 182 } 183 184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 185 /// offset to LLVM struct. Otherwise, the conversion is not supported. 186 static Optional<Type> 187 convertStructTypeWithOffset(spirv::StructType type, 188 LLVMTypeConverter &converter) { 189 if (type != VulkanLayoutUtils::decorateType(type)) 190 return llvm::None; 191 192 auto elementsVector = llvm::to_vector<8>( 193 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 194 return converter.convertType(elementType); 195 })); 196 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 197 /*isPacked=*/false); 198 } 199 200 /// Converts SPIR-V struct with no offset to packed LLVM struct. 201 static Type convertStructTypePacked(spirv::StructType type, 202 LLVMTypeConverter &converter) { 203 auto elementsVector = llvm::to_vector<8>( 204 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 205 return converter.convertType(elementType); 206 })); 207 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 208 /*isPacked=*/true); 209 } 210 211 /// Creates LLVM dialect constant with the given value. 212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 213 unsigned value) { 214 return rewriter.create<LLVM::ConstantOp>( 215 loc, IntegerType::get(rewriter.getContext(), 32), 216 rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 217 } 218 219 /// Utility for `spv.Load` and `spv.Store` conversion. 220 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, 221 ConversionPatternRewriter &rewriter, 222 LLVMTypeConverter &typeConverter, 223 unsigned alignment, bool isVolatile, 224 bool isNonTemporal) { 225 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 226 auto dstType = typeConverter.convertType(loadOp.getType()); 227 if (!dstType) 228 return failure(); 229 rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 230 loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, 231 isVolatile, isNonTemporal); 232 return success(); 233 } 234 auto storeOp = cast<spirv::StoreOp>(op); 235 spirv::StoreOpAdaptor adaptor(operands); 236 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(), 237 adaptor.ptr(), alignment, 238 isVolatile, isNonTemporal); 239 return success(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // Type conversion 244 //===----------------------------------------------------------------------===// 245 246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to 247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 248 /// when converting ops that manipulate array types. 249 static Optional<Type> convertArrayType(spirv::ArrayType type, 250 TypeConverter &converter) { 251 unsigned stride = type.getArrayStride(); 252 Type elementType = type.getElementType(); 253 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes(); 254 if (stride != 0 && 255 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) 256 return llvm::None; 257 258 auto llvmElementType = converter.convertType(elementType); 259 unsigned numElements = type.getNumElements(); 260 return LLVM::LLVMArrayType::get(llvmElementType, numElements); 261 } 262 263 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 264 /// modelled at the moment. 265 static Type convertPointerType(spirv::PointerType type, 266 TypeConverter &converter) { 267 auto pointeeType = converter.convertType(type.getPointeeType()); 268 return LLVM::LLVMPointerType::get(pointeeType); 269 } 270 271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 273 /// no modelling of array stride at the moment. 274 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 275 TypeConverter &converter) { 276 if (type.getArrayStride() != 0) 277 return llvm::None; 278 auto elementType = converter.convertType(type.getElementType()); 279 return LLVM::LLVMArrayType::get(elementType, 0); 280 } 281 282 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 283 /// member decorations. Also, only natural offset is supported. 284 static Optional<Type> convertStructType(spirv::StructType type, 285 LLVMTypeConverter &converter) { 286 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 287 type.getMemberDecorations(memberDecorations); 288 if (!memberDecorations.empty()) 289 return llvm::None; 290 if (type.hasOffset()) 291 return convertStructTypeWithOffset(type, converter); 292 return convertStructTypePacked(type, converter); 293 } 294 295 //===----------------------------------------------------------------------===// 296 // Operation conversion 297 //===----------------------------------------------------------------------===// 298 299 namespace { 300 301 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 302 public: 303 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 304 305 LogicalResult 306 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, 307 ConversionPatternRewriter &rewriter) const override { 308 auto dstType = typeConverter.convertType(op.component_ptr().getType()); 309 if (!dstType) 310 return failure(); 311 // To use GEP we need to add a first 0 index to go through the pointer. 312 auto indices = llvm::to_vector<4>(adaptor.indices()); 313 Type indexType = op.indices().front().getType(); 314 auto llvmIndexType = typeConverter.convertType(indexType); 315 if (!llvmIndexType) 316 return failure(); 317 Value zero = rewriter.create<LLVM::ConstantOp>( 318 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 319 indices.insert(indices.begin(), zero); 320 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(), 321 indices); 322 return success(); 323 } 324 }; 325 326 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 327 public: 328 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 329 330 LogicalResult 331 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, 332 ConversionPatternRewriter &rewriter) const override { 333 auto dstType = typeConverter.convertType(op.pointer().getType()); 334 if (!dstType) 335 return failure(); 336 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable()); 337 return success(); 338 } 339 }; 340 341 class BitFieldInsertPattern 342 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 343 public: 344 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 345 346 LogicalResult 347 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, 348 ConversionPatternRewriter &rewriter) const override { 349 auto srcType = op.getType(); 350 auto dstType = typeConverter.convertType(srcType); 351 if (!dstType) 352 return failure(); 353 Location loc = op.getLoc(); 354 355 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 356 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 357 typeConverter, rewriter); 358 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 359 typeConverter, rewriter); 360 361 // Create a mask with bits set outside [Offset, Offset + Count - 1]. 362 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 363 Value maskShiftedByCount = 364 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 365 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 366 maskShiftedByCount, minusOne); 367 Value maskShiftedByCountAndOffset = 368 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 369 Value mask = rewriter.create<LLVM::XOrOp>( 370 loc, dstType, maskShiftedByCountAndOffset, minusOne); 371 372 // Extract unchanged bits from the `Base` that are outside of 373 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 374 Value baseAndMask = 375 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask); 376 Value insertShiftedByOffset = 377 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset); 378 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 379 insertShiftedByOffset); 380 return success(); 381 } 382 }; 383 384 /// Converts SPIR-V ConstantOp with scalar or vector type. 385 class ConstantScalarAndVectorPattern 386 : public SPIRVToLLVMConversion<spirv::ConstantOp> { 387 public: 388 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 389 390 LogicalResult 391 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, 392 ConversionPatternRewriter &rewriter) const override { 393 auto srcType = constOp.getType(); 394 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat()) 395 return failure(); 396 397 auto dstType = typeConverter.convertType(srcType); 398 if (!dstType) 399 return failure(); 400 401 // SPIR-V constant can be a signed/unsigned integer, which has to be 402 // casted to signless integer when converting to LLVM dialect. Removing the 403 // sign bit may have unexpected behaviour. However, it is better to handle 404 // it case-by-case, given that the purpose of the conversion is not to 405 // cover all possible corner cases. 406 if (isSignedIntegerOrVector(srcType) || 407 isUnsignedIntegerOrVector(srcType)) { 408 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 409 410 if (srcType.isa<VectorType>()) { 411 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); 412 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 413 constOp, dstType, 414 dstElementsAttr.mapValues( 415 signlessType, [&](const APInt &value) { return value; })); 416 return success(); 417 } 418 auto srcAttr = constOp.value().cast<IntegerAttr>(); 419 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 420 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 421 return success(); 422 } 423 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 424 constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); 425 return success(); 426 } 427 }; 428 429 class BitFieldSExtractPattern 430 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 431 public: 432 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 433 434 LogicalResult 435 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, 436 ConversionPatternRewriter &rewriter) const override { 437 auto srcType = op.getType(); 438 auto dstType = typeConverter.convertType(srcType); 439 if (!dstType) 440 return failure(); 441 Location loc = op.getLoc(); 442 443 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 444 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 445 typeConverter, rewriter); 446 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 447 typeConverter, rewriter); 448 449 // Create a constant that holds the size of the `Base`. 450 IntegerType integerType; 451 if (auto vecType = srcType.dyn_cast<VectorType>()) 452 integerType = vecType.getElementType().cast<IntegerType>(); 453 else 454 integerType = srcType.cast<IntegerType>(); 455 456 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 457 Value size = 458 srcType.isa<VectorType>() 459 ? rewriter.create<LLVM::ConstantOp>( 460 loc, dstType, 461 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize)) 462 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 463 464 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 465 // at Offset + Count - 1 is the most significant bit now. 466 Value countPlusOffset = 467 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 468 Value amountToShiftLeft = 469 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 470 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 471 loc, dstType, op.base(), amountToShiftLeft); 472 473 // Shift the result right, filling the bits with the sign bit. 474 Value amountToShiftRight = 475 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 476 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 477 amountToShiftRight); 478 return success(); 479 } 480 }; 481 482 class BitFieldUExtractPattern 483 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 484 public: 485 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 486 487 LogicalResult 488 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, 489 ConversionPatternRewriter &rewriter) const override { 490 auto srcType = op.getType(); 491 auto dstType = typeConverter.convertType(srcType); 492 if (!dstType) 493 return failure(); 494 Location loc = op.getLoc(); 495 496 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 497 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 498 typeConverter, rewriter); 499 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 500 typeConverter, rewriter); 501 502 // Create a mask with bits set at [0, Count - 1]. 503 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 504 Value maskShiftedByCount = 505 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 506 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 507 minusOne); 508 509 // Shift `Base` by `Offset` and apply the mask on it. 510 Value shiftedBase = 511 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset); 512 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 513 return success(); 514 } 515 }; 516 517 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 518 public: 519 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 520 521 LogicalResult 522 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, 523 ConversionPatternRewriter &rewriter) const override { 524 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), 525 branchOp.getTarget()); 526 return success(); 527 } 528 }; 529 530 class BranchConditionalConversionPattern 531 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 532 public: 533 using SPIRVToLLVMConversion< 534 spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 535 536 LogicalResult 537 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, 538 ConversionPatternRewriter &rewriter) const override { 539 // If branch weights exist, map them to 32-bit integer vector. 540 ElementsAttr branchWeights = nullptr; 541 if (auto weights = op.branch_weights()) { 542 VectorType weightType = VectorType::get(2, rewriter.getI32Type()); 543 branchWeights = 544 DenseElementsAttr::get(weightType, weights.getValue().getValue()); 545 } 546 547 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 548 op, op.condition(), op.getTrueBlockArguments(), 549 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 550 op.getFalseBlock()); 551 return success(); 552 } 553 }; 554 555 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type 556 /// is an aggregate type (struct or array). Otherwise, converts to 557 /// `llvm.extractelement` that operates on vectors. 558 class CompositeExtractPattern 559 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 560 public: 561 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 562 563 LogicalResult 564 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, 565 ConversionPatternRewriter &rewriter) const override { 566 auto dstType = this->typeConverter.convertType(op.getType()); 567 if (!dstType) 568 return failure(); 569 570 Type containerType = op.composite().getType(); 571 if (containerType.isa<VectorType>()) { 572 Location loc = op.getLoc(); 573 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 574 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 575 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 576 op, dstType, adaptor.composite(), index); 577 return success(); 578 } 579 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 580 op, dstType, adaptor.composite(), op.indices()); 581 return success(); 582 } 583 }; 584 585 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type 586 /// is an aggregate type (struct or array). Otherwise, converts to 587 /// `llvm.insertelement` that operates on vectors. 588 class CompositeInsertPattern 589 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 590 public: 591 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 592 593 LogicalResult 594 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, 595 ConversionPatternRewriter &rewriter) const override { 596 auto dstType = this->typeConverter.convertType(op.getType()); 597 if (!dstType) 598 return failure(); 599 600 Type containerType = op.composite().getType(); 601 if (containerType.isa<VectorType>()) { 602 Location loc = op.getLoc(); 603 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 604 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 605 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 606 op, dstType, adaptor.composite(), adaptor.object(), index); 607 return success(); 608 } 609 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 610 op, dstType, adaptor.composite(), adaptor.object(), op.indices()); 611 return success(); 612 } 613 }; 614 615 /// Converts SPIR-V operations that have straightforward LLVM equivalent 616 /// into LLVM dialect operations. 617 template <typename SPIRVOp, typename LLVMOp> 618 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 619 public: 620 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 621 622 LogicalResult 623 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 624 ConversionPatternRewriter &rewriter) const override { 625 auto dstType = this->typeConverter.convertType(operation.getType()); 626 if (!dstType) 627 return failure(); 628 rewriter.template replaceOpWithNewOp<LLVMOp>( 629 operation, dstType, adaptor.getOperands(), operation->getAttrs()); 630 return success(); 631 } 632 }; 633 634 /// Converts `spv.ExecutionMode` into a global struct constant that holds 635 /// execution mode information. 636 class ExecutionModePattern 637 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 638 public: 639 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 640 641 LogicalResult 642 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, 643 ConversionPatternRewriter &rewriter) const override { 644 // First, create the global struct's name that would be associated with 645 // this entry point's execution mode. We set it to be: 646 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} 647 ModuleOp module = op->getParentOfType<ModuleOp>(); 648 IntegerAttr executionModeAttr = op.execution_modeAttr(); 649 std::string moduleName; 650 if (module.getName().hasValue()) 651 moduleName = "_" + module.getName().getValue().str(); 652 else 653 moduleName = ""; 654 std::string executionModeInfoName = 655 llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName, 656 op.fn().str(), executionModeAttr.getValue()); 657 658 MLIRContext *context = rewriter.getContext(); 659 OpBuilder::InsertionGuard guard(rewriter); 660 rewriter.setInsertionPointToStart(module.getBody()); 661 662 // Create a struct type, corresponding to the C struct below. 663 // struct { 664 // int32_t executionMode; 665 // int32_t values[]; // optional values 666 // }; 667 auto llvmI32Type = IntegerType::get(context, 32); 668 SmallVector<Type, 2> fields; 669 fields.push_back(llvmI32Type); 670 ArrayAttr values = op.values(); 671 if (!values.empty()) { 672 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 673 fields.push_back(arrayType); 674 } 675 auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 676 677 // Create `llvm.mlir.global` with initializer region containing one block. 678 auto global = rewriter.create<LLVM::GlobalOp>( 679 UnknownLoc::get(context), structType, /*isConstant=*/true, 680 LLVM::Linkage::External, executionModeInfoName, Attribute(), 681 /*alignment=*/0); 682 Location loc = global.getLoc(); 683 Region ®ion = global.getInitializerRegion(); 684 Block *block = rewriter.createBlock(®ion); 685 686 // Initialize the struct and set the execution mode value. 687 rewriter.setInsertionPoint(block, block->begin()); 688 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 689 Value executionMode = 690 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr); 691 structValue = rewriter.create<LLVM::InsertValueOp>( 692 loc, structType, structValue, executionMode, 693 ArrayAttr::get(context, 694 {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)})); 695 696 // Insert extra operands if they exist into execution mode info struct. 697 for (unsigned i = 0, e = values.size(); i < e; ++i) { 698 auto attr = values.getValue()[i]; 699 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 700 structValue = rewriter.create<LLVM::InsertValueOp>( 701 loc, structType, structValue, entry, 702 ArrayAttr::get(context, 703 {rewriter.getIntegerAttr(rewriter.getI32Type(), 1), 704 rewriter.getIntegerAttr(rewriter.getI32Type(), i)})); 705 } 706 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 707 rewriter.eraseOp(op); 708 return success(); 709 } 710 }; 711 712 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global 713 /// returns a pointer, whereas in LLVM dialect the global holds an actual value. 714 /// This difference is handled by `spv.mlir.addressof` and 715 /// `llvm.mlir.addressof`ops that both return a pointer. 716 class GlobalVariablePattern 717 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 718 public: 719 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion; 720 721 LogicalResult 722 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, 723 ConversionPatternRewriter &rewriter) const override { 724 // Currently, there is no support of initialization with a constant value in 725 // SPIR-V dialect. Specialization constants are not considered as well. 726 if (op.initializer()) 727 return failure(); 728 729 auto srcType = op.type().cast<spirv::PointerType>(); 730 auto dstType = typeConverter.convertType(srcType.getPointeeType()); 731 if (!dstType) 732 return failure(); 733 734 // Limit conversion to the current invocation only or `StorageBuffer` 735 // required by SPIR-V runner. 736 // This is okay because multiple invocations are not supported yet. 737 auto storageClass = srcType.getStorageClass(); 738 switch (storageClass) { 739 case spirv::StorageClass::Input: 740 case spirv::StorageClass::Private: 741 case spirv::StorageClass::Output: 742 case spirv::StorageClass::StorageBuffer: 743 case spirv::StorageClass::UniformConstant: 744 break; 745 default: 746 return failure(); 747 } 748 749 // LLVM dialect spec: "If the global value is a constant, storing into it is 750 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' 751 // storage class that is read-only. 752 bool isConstant = (storageClass == spirv::StorageClass::Input) || 753 (storageClass == spirv::StorageClass::UniformConstant); 754 // SPIR-V spec: "By default, functions and global variables are private to a 755 // module and cannot be accessed by other modules. However, a module may be 756 // written to export or import functions and global (module scope) 757 // variables.". Therefore, map 'Private' storage class to private linkage, 758 // 'Input' and 'Output' to external linkage. 759 auto linkage = storageClass == spirv::StorageClass::Private 760 ? LLVM::Linkage::Private 761 : LLVM::Linkage::External; 762 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 763 op, dstType, isConstant, linkage, op.sym_name(), Attribute(), 764 /*alignment=*/0); 765 766 // Attach location attribute if applicable 767 if (op.locationAttr()) 768 newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr()); 769 770 return success(); 771 } 772 }; 773 774 /// Converts SPIR-V cast ops that do not have straightforward LLVM 775 /// equivalent in LLVM dialect. 776 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 777 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 778 public: 779 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 780 781 LogicalResult 782 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 783 ConversionPatternRewriter &rewriter) const override { 784 785 Type fromType = operation.operand().getType(); 786 Type toType = operation.getType(); 787 788 auto dstType = this->typeConverter.convertType(toType); 789 if (!dstType) 790 return failure(); 791 792 if (getBitWidth(fromType) < getBitWidth(toType)) { 793 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType, 794 adaptor.getOperands()); 795 return success(); 796 } 797 if (getBitWidth(fromType) > getBitWidth(toType)) { 798 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType, 799 adaptor.getOperands()); 800 return success(); 801 } 802 return failure(); 803 } 804 }; 805 806 class FunctionCallPattern 807 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 808 public: 809 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 810 811 LogicalResult 812 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, 813 ConversionPatternRewriter &rewriter) const override { 814 if (callOp.getNumResults() == 0) { 815 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 816 callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs()); 817 return success(); 818 } 819 820 // Function returns a single result. 821 auto dstType = typeConverter.convertType(callOp.getType(0)); 822 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 823 callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); 824 return success(); 825 } 826 }; 827 828 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 829 template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 830 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 831 public: 832 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 833 834 LogicalResult 835 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 836 ConversionPatternRewriter &rewriter) const override { 837 838 auto dstType = this->typeConverter.convertType(operation.getType()); 839 if (!dstType) 840 return failure(); 841 842 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 843 operation, dstType, predicate, operation.operand1(), 844 operation.operand2()); 845 return success(); 846 } 847 }; 848 849 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 850 template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 851 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 852 public: 853 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 854 855 LogicalResult 856 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 857 ConversionPatternRewriter &rewriter) const override { 858 859 auto dstType = this->typeConverter.convertType(operation.getType()); 860 if (!dstType) 861 return failure(); 862 863 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 864 operation, dstType, predicate, operation.operand1(), 865 operation.operand2()); 866 return success(); 867 } 868 }; 869 870 class InverseSqrtPattern 871 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> { 872 public: 873 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion; 874 875 LogicalResult 876 matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor, 877 ConversionPatternRewriter &rewriter) const override { 878 auto srcType = op.getType(); 879 auto dstType = typeConverter.convertType(srcType); 880 if (!dstType) 881 return failure(); 882 883 Location loc = op.getLoc(); 884 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 885 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); 886 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 887 return success(); 888 } 889 }; 890 891 /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 892 template <typename SPIRVOp> 893 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { 894 public: 895 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 896 897 LogicalResult 898 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 899 ConversionPatternRewriter &rewriter) const override { 900 if (!op.memory_access().hasValue()) { 901 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 902 this->typeConverter, /*alignment=*/0, 903 /*isVolatile=*/false, 904 /*isNonTemporal=*/false); 905 } 906 auto memoryAccess = op.memory_access().getValue(); 907 switch (memoryAccess) { 908 case spirv::MemoryAccess::Aligned: 909 case spirv::MemoryAccess::None: 910 case spirv::MemoryAccess::Nontemporal: 911 case spirv::MemoryAccess::Volatile: { 912 unsigned alignment = 913 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; 914 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 915 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 916 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 917 this->typeConverter, alignment, isVolatile, 918 isNonTemporal); 919 } 920 default: 921 // There is no support of other memory access attributes. 922 return failure(); 923 } 924 } 925 }; 926 927 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. 928 template <typename SPIRVOp> 929 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 930 public: 931 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 932 933 LogicalResult 934 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, 935 ConversionPatternRewriter &rewriter) const override { 936 auto srcType = notOp.getType(); 937 auto dstType = this->typeConverter.convertType(srcType); 938 if (!dstType) 939 return failure(); 940 941 Location loc = notOp.getLoc(); 942 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 943 auto mask = srcType.template isa<VectorType>() 944 ? rewriter.create<LLVM::ConstantOp>( 945 loc, dstType, 946 SplatElementsAttr::get( 947 srcType.template cast<VectorType>(), minusOne)) 948 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 949 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 950 notOp.operand(), mask); 951 return success(); 952 } 953 }; 954 955 /// A template pattern that erases the given `SPIRVOp`. 956 template <typename SPIRVOp> 957 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 958 public: 959 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 960 961 LogicalResult 962 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 963 ConversionPatternRewriter &rewriter) const override { 964 rewriter.eraseOp(op); 965 return success(); 966 } 967 }; 968 969 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 970 public: 971 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 972 973 LogicalResult 974 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, 975 ConversionPatternRewriter &rewriter) const override { 976 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 977 ArrayRef<Value>()); 978 return success(); 979 } 980 }; 981 982 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 983 public: 984 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 985 986 LogicalResult 987 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, 988 ConversionPatternRewriter &rewriter) const override { 989 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 990 adaptor.getOperands()); 991 return success(); 992 } 993 }; 994 995 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should 996 /// be reachable for conversion to succeed. The structure of the loop in LLVM 997 /// dialect will be the following: 998 /// 999 /// +------------------------------------+ 1000 /// | <code before spv.mlir.loop> | 1001 /// | llvm.br ^header | 1002 /// +------------------------------------+ 1003 /// | 1004 /// +----------------+ | 1005 /// | | | 1006 /// | V V 1007 /// | +------------------------------------+ 1008 /// | | ^header: | 1009 /// | | <header code> | 1010 /// | | llvm.cond_br %cond, ^body, ^exit | 1011 /// | +------------------------------------+ 1012 /// | | 1013 /// | |----------------------+ 1014 /// | | | 1015 /// | V | 1016 /// | +------------------------------------+ | 1017 /// | | ^body: | | 1018 /// | | <body code> | | 1019 /// | | llvm.br ^continue | | 1020 /// | +------------------------------------+ | 1021 /// | | | 1022 /// | V | 1023 /// | +------------------------------------+ | 1024 /// | | ^continue: | | 1025 /// | | <continue code> | | 1026 /// | | llvm.br ^header | | 1027 /// | +------------------------------------+ | 1028 /// | | | 1029 /// +---------------+ +----------------------+ 1030 /// | 1031 /// V 1032 /// +------------------------------------+ 1033 /// | ^exit: | 1034 /// | llvm.br ^remaining | 1035 /// +------------------------------------+ 1036 /// | 1037 /// V 1038 /// +------------------------------------+ 1039 /// | ^remaining: | 1040 /// | <code after spv.mlir.loop> | 1041 /// +------------------------------------+ 1042 /// 1043 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1044 public: 1045 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1046 1047 LogicalResult 1048 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, 1049 ConversionPatternRewriter &rewriter) const override { 1050 // There is no support of loop control at the moment. 1051 if (loopOp.loop_control() != spirv::LoopControl::None) 1052 return failure(); 1053 1054 Location loc = loopOp.getLoc(); 1055 1056 // Split the current block after `spv.mlir.loop`. The remaining ops will be 1057 // used in `endBlock`. 1058 Block *currentBlock = rewriter.getBlock(); 1059 auto position = Block::iterator(loopOp); 1060 Block *endBlock = rewriter.splitBlock(currentBlock, position); 1061 1062 // Remove entry block and create a branch in the current block going to the 1063 // header block. 1064 Block *entryBlock = loopOp.getEntryBlock(); 1065 assert(entryBlock->getOperations().size() == 1); 1066 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1067 if (!brOp) 1068 return failure(); 1069 Block *headerBlock = loopOp.getHeaderBlock(); 1070 rewriter.setInsertionPointToEnd(currentBlock); 1071 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1072 rewriter.eraseBlock(entryBlock); 1073 1074 // Branch from merge block to end block. 1075 Block *mergeBlock = loopOp.getMergeBlock(); 1076 Operation *terminator = mergeBlock->getTerminator(); 1077 ValueRange terminatorOperands = terminator->getOperands(); 1078 rewriter.setInsertionPointToEnd(mergeBlock); 1079 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1080 1081 rewriter.inlineRegionBefore(loopOp.body(), endBlock); 1082 rewriter.replaceOp(loopOp, endBlock->getArguments()); 1083 return success(); 1084 } 1085 }; 1086 1087 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header 1088 /// block. All blocks within selection should be reachable for conversion to 1089 /// succeed. 1090 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1091 public: 1092 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1093 1094 LogicalResult 1095 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, 1096 ConversionPatternRewriter &rewriter) const override { 1097 // There is no support for `Flatten` or `DontFlatten` selection control at 1098 // the moment. This are just compiler hints and can be performed during the 1099 // optimization passes. 1100 if (op.selection_control() != spirv::SelectionControl::None) 1101 return failure(); 1102 1103 // `spv.mlir.selection` should have at least two blocks: one selection 1104 // header block and one merge block. If no blocks are present, or control 1105 // flow branches straight to merge block (two blocks are present), the op is 1106 // redundant and it is erased. 1107 if (op.body().getBlocks().size() <= 2) { 1108 rewriter.eraseOp(op); 1109 return success(); 1110 } 1111 1112 Location loc = op.getLoc(); 1113 1114 // Split the current block after `spv.mlir.selection`. The remaining ops 1115 // will be used in `continueBlock`. 1116 auto *currentBlock = rewriter.getInsertionBlock(); 1117 rewriter.setInsertionPointAfter(op); 1118 auto position = rewriter.getInsertionPoint(); 1119 auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1120 1121 // Extract conditional branch information from the header block. By SPIR-V 1122 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` 1123 // op. Note that `spv.Switch op` is not supported at the moment in the 1124 // SPIR-V dialect. Remove this block when finished. 1125 auto *headerBlock = op.getHeaderBlock(); 1126 assert(headerBlock->getOperations().size() == 1); 1127 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1128 headerBlock->getOperations().front()); 1129 if (!condBrOp) 1130 return failure(); 1131 rewriter.eraseBlock(headerBlock); 1132 1133 // Branch from merge block to continue block. 1134 auto *mergeBlock = op.getMergeBlock(); 1135 Operation *terminator = mergeBlock->getTerminator(); 1136 ValueRange terminatorOperands = terminator->getOperands(); 1137 rewriter.setInsertionPointToEnd(mergeBlock); 1138 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1139 1140 // Link current block to `true` and `false` blocks within the selection. 1141 Block *trueBlock = condBrOp.getTrueBlock(); 1142 Block *falseBlock = condBrOp.getFalseBlock(); 1143 rewriter.setInsertionPointToEnd(currentBlock); 1144 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock, 1145 condBrOp.trueTargetOperands(), falseBlock, 1146 condBrOp.falseTargetOperands()); 1147 1148 rewriter.inlineRegionBefore(op.body(), continueBlock); 1149 rewriter.replaceOp(op, continueBlock->getArguments()); 1150 return success(); 1151 } 1152 }; 1153 1154 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1155 /// puts a restriction on `Shift` and `Base` to have the same bit width, 1156 /// `Shift` is zero or sign extended to match this specification. Cases when 1157 /// `Shift` bit width > `Base` bit width are considered to be illegal. 1158 template <typename SPIRVOp, typename LLVMOp> 1159 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1160 public: 1161 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1162 1163 LogicalResult 1164 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 1165 ConversionPatternRewriter &rewriter) const override { 1166 1167 auto dstType = this->typeConverter.convertType(operation.getType()); 1168 if (!dstType) 1169 return failure(); 1170 1171 Type op1Type = operation.operand1().getType(); 1172 Type op2Type = operation.operand2().getType(); 1173 1174 if (op1Type == op2Type) { 1175 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, 1176 adaptor.getOperands()); 1177 return success(); 1178 } 1179 1180 Location loc = operation.getLoc(); 1181 Value extended; 1182 if (isUnsignedIntegerOrVector(op2Type)) { 1183 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType, 1184 adaptor.operand2()); 1185 } else { 1186 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType, 1187 adaptor.operand2()); 1188 } 1189 Value result = rewriter.template create<LLVMOp>( 1190 loc, dstType, adaptor.operand1(), extended); 1191 rewriter.replaceOp(operation, result); 1192 return success(); 1193 } 1194 }; 1195 1196 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> { 1197 public: 1198 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion; 1199 1200 LogicalResult 1201 matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor, 1202 ConversionPatternRewriter &rewriter) const override { 1203 auto dstType = typeConverter.convertType(tanOp.getType()); 1204 if (!dstType) 1205 return failure(); 1206 1207 Location loc = tanOp.getLoc(); 1208 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); 1209 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); 1210 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1211 return success(); 1212 } 1213 }; 1214 1215 /// Convert `spv.Tanh` to 1216 /// 1217 /// exp(2x) - 1 1218 /// ----------- 1219 /// exp(2x) + 1 1220 /// 1221 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> { 1222 public: 1223 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion; 1224 1225 LogicalResult 1226 matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor, 1227 ConversionPatternRewriter &rewriter) const override { 1228 auto srcType = tanhOp.getType(); 1229 auto dstType = typeConverter.convertType(srcType); 1230 if (!dstType) 1231 return failure(); 1232 1233 Location loc = tanhOp.getLoc(); 1234 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1235 Value multiplied = 1236 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); 1237 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1238 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1239 Value numerator = 1240 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1241 Value denominator = 1242 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1243 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1244 denominator); 1245 return success(); 1246 } 1247 }; 1248 1249 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1250 public: 1251 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1252 1253 LogicalResult 1254 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, 1255 ConversionPatternRewriter &rewriter) const override { 1256 auto srcType = varOp.getType(); 1257 // Initialization is supported for scalars and vectors only. 1258 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType(); 1259 auto init = varOp.initializer(); 1260 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>()) 1261 return failure(); 1262 1263 auto dstType = typeConverter.convertType(srcType); 1264 if (!dstType) 1265 return failure(); 1266 1267 Location loc = varOp.getLoc(); 1268 Value size = createI32ConstantOf(loc, rewriter, 1); 1269 if (!init) { 1270 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size); 1271 return success(); 1272 } 1273 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size); 1274 rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated); 1275 rewriter.replaceOp(varOp, allocated); 1276 return success(); 1277 } 1278 }; 1279 1280 //===----------------------------------------------------------------------===// 1281 // FuncOp conversion 1282 //===----------------------------------------------------------------------===// 1283 1284 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1285 public: 1286 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1287 1288 LogicalResult 1289 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 1290 ConversionPatternRewriter &rewriter) const override { 1291 1292 // Convert function signature. At the moment LLVMType converter is enough 1293 // for currently supported types. 1294 auto funcType = funcOp.getFunctionType(); 1295 TypeConverter::SignatureConversion signatureConverter( 1296 funcType.getNumInputs()); 1297 auto llvmType = typeConverter.convertFunctionSignature( 1298 funcType, /*isVariadic=*/false, signatureConverter); 1299 if (!llvmType) 1300 return failure(); 1301 1302 // Create a new `LLVMFuncOp` 1303 Location loc = funcOp.getLoc(); 1304 StringRef name = funcOp.getName(); 1305 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1306 1307 // Convert SPIR-V Function Control to equivalent LLVM function attribute 1308 MLIRContext *context = funcOp.getContext(); 1309 switch (funcOp.function_control()) { 1310 #define DISPATCH(functionControl, llvmAttr) \ 1311 case functionControl: \ 1312 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ 1313 break; 1314 1315 DISPATCH(spirv::FunctionControl::Inline, 1316 StringAttr::get(context, "alwaysinline")); 1317 DISPATCH(spirv::FunctionControl::DontInline, 1318 StringAttr::get(context, "noinline")); 1319 DISPATCH(spirv::FunctionControl::Pure, 1320 StringAttr::get(context, "readonly")); 1321 DISPATCH(spirv::FunctionControl::Const, 1322 StringAttr::get(context, "readnone")); 1323 1324 #undef DISPATCH 1325 1326 // Default: if `spirv::FunctionControl::None`, then no attributes are 1327 // needed. 1328 default: 1329 break; 1330 } 1331 1332 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1333 newFuncOp.end()); 1334 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, 1335 &signatureConverter))) { 1336 return failure(); 1337 } 1338 rewriter.eraseOp(funcOp); 1339 return success(); 1340 } 1341 }; 1342 1343 //===----------------------------------------------------------------------===// 1344 // ModuleOp conversion 1345 //===----------------------------------------------------------------------===// 1346 1347 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1348 public: 1349 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1350 1351 LogicalResult 1352 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, 1353 ConversionPatternRewriter &rewriter) const override { 1354 1355 auto newModuleOp = 1356 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 1357 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); 1358 1359 // Remove the terminator block that was automatically added by builder 1360 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1361 rewriter.eraseOp(spvModuleOp); 1362 return success(); 1363 } 1364 }; 1365 1366 //===----------------------------------------------------------------------===// 1367 // VectorShuffleOp conversion 1368 //===----------------------------------------------------------------------===// 1369 1370 class VectorShufflePattern 1371 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { 1372 public: 1373 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; 1374 LogicalResult 1375 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, 1376 ConversionPatternRewriter &rewriter) const override { 1377 Location loc = op.getLoc(); 1378 auto components = adaptor.components(); 1379 auto vector1 = adaptor.vector1(); 1380 auto vector2 = adaptor.vector2(); 1381 int vector1Size = vector1.getType().cast<VectorType>().getNumElements(); 1382 int vector2Size = vector2.getType().cast<VectorType>().getNumElements(); 1383 if (vector1Size == vector2Size) { 1384 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2, 1385 components); 1386 return success(); 1387 } 1388 1389 auto dstType = typeConverter.convertType(op.getType()); 1390 auto scalarType = dstType.cast<VectorType>().getElementType(); 1391 auto componentsArray = components.getValue(); 1392 auto *context = rewriter.getContext(); 1393 auto llvmI32Type = IntegerType::get(context, 32); 1394 Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); 1395 for (unsigned i = 0; i < componentsArray.size(); i++) { 1396 if (componentsArray[i].isa<IntegerAttr>()) 1397 op.emitError("unable to support non-constant component"); 1398 1399 int indexVal = componentsArray[i].cast<IntegerAttr>().getInt(); 1400 if (indexVal == -1) 1401 continue; 1402 1403 int offsetVal = 0; 1404 Value baseVector = vector1; 1405 if (indexVal >= vector1Size) { 1406 offsetVal = vector1Size; 1407 baseVector = vector2; 1408 } 1409 1410 Value dstIndex = rewriter.create<LLVM::ConstantOp>( 1411 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); 1412 Value index = rewriter.create<LLVM::ConstantOp>( 1413 loc, llvmI32Type, 1414 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); 1415 1416 auto extractOp = rewriter.create<LLVM::ExtractElementOp>( 1417 loc, scalarType, baseVector, index); 1418 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, 1419 extractOp, dstIndex); 1420 } 1421 rewriter.replaceOp(op, targetOp); 1422 return success(); 1423 } 1424 }; 1425 } // namespace 1426 1427 //===----------------------------------------------------------------------===// 1428 // Pattern population 1429 //===----------------------------------------------------------------------===// 1430 1431 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) { 1432 typeConverter.addConversion([&](spirv::ArrayType type) { 1433 return convertArrayType(type, typeConverter); 1434 }); 1435 typeConverter.addConversion([&](spirv::PointerType type) { 1436 return convertPointerType(type, typeConverter); 1437 }); 1438 typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1439 return convertRuntimeArrayType(type, typeConverter); 1440 }); 1441 typeConverter.addConversion([&](spirv::StructType type) { 1442 return convertStructType(type, typeConverter); 1443 }); 1444 } 1445 1446 void mlir::populateSPIRVToLLVMConversionPatterns( 1447 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1448 patterns.add< 1449 // Arithmetic ops 1450 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1451 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1452 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1453 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1454 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1455 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1456 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1457 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1458 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1459 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1460 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1461 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1462 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1463 1464 // Bitwise ops 1465 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1466 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1467 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1468 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1469 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1470 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1471 NotPattern<spirv::NotOp>, 1472 1473 // Cast ops 1474 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>, 1475 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1476 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1477 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1478 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1479 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1480 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1481 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1482 1483 // Comparison ops 1484 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1485 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1486 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1487 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1488 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1489 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1490 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1491 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1492 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1493 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1494 FComparePattern<spirv::FUnordGreaterThanEqualOp, 1495 LLVM::FCmpPredicate::uge>, 1496 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1497 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1498 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1499 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1500 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1501 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1502 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1503 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1504 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1505 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1506 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1507 1508 // Constant op 1509 ConstantScalarAndVectorPattern, 1510 1511 // Control Flow ops 1512 BranchConversionPattern, BranchConditionalConversionPattern, 1513 FunctionCallPattern, LoopPattern, SelectionPattern, 1514 ErasePattern<spirv::MergeOp>, 1515 1516 // Entry points and execution mode are handled separately. 1517 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1518 1519 // GLSL extended instruction set ops 1520 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, 1521 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, 1522 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, 1523 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, 1524 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>, 1525 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>, 1526 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>, 1527 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, 1528 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, 1529 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>, 1530 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>, 1531 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, 1532 InverseSqrtPattern, TanPattern, TanhPattern, 1533 1534 // Logical ops 1535 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1536 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1537 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1538 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1539 NotPattern<spirv::LogicalNotOp>, 1540 1541 // Memory ops 1542 AccessChainPattern, AddressOfPattern, GlobalVariablePattern, 1543 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>, 1544 VariablePattern, 1545 1546 // Miscellaneous ops 1547 CompositeExtractPattern, CompositeInsertPattern, 1548 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1549 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 1550 VectorShufflePattern, 1551 1552 // Shift ops 1553 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1554 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1555 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1556 1557 // Return ops 1558 ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter); 1559 } 1560 1561 void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1562 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1563 patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter); 1564 } 1565 1566 void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1567 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1568 patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter); 1569 } 1570 1571 //===----------------------------------------------------------------------===// 1572 // Pre-conversion hooks 1573 //===----------------------------------------------------------------------===// 1574 1575 /// Hook for descriptor set and binding number encoding. 1576 static constexpr StringRef kBinding = "binding"; 1577 static constexpr StringRef kDescriptorSet = "descriptor_set"; 1578 void mlir::encodeBindAttribute(ModuleOp module) { 1579 auto spvModules = module.getOps<spirv::ModuleOp>(); 1580 for (auto spvModule : spvModules) { 1581 spvModule.walk([&](spirv::GlobalVariableOp op) { 1582 IntegerAttr descriptorSet = 1583 op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1584 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1585 // For every global variable in the module, get the ones with descriptor 1586 // set and binding numbers. 1587 if (descriptorSet && binding) { 1588 // Encode these numbers into the variable's symbolic name. If the 1589 // SPIR-V module has a name, add it at the beginning. 1590 auto moduleAndName = spvModule.getName().hasValue() 1591 ? spvModule.getName().getValue().str() + "_" + 1592 op.sym_name().str() 1593 : op.sym_name().str(); 1594 std::string name = 1595 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 1596 std::to_string(descriptorSet.getInt()), 1597 std::to_string(binding.getInt())); 1598 auto nameAttr = StringAttr::get(op->getContext(), name); 1599 1600 // Replace all symbol uses and set the new symbol name. Finally, remove 1601 // descriptor set and binding attributes. 1602 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) 1603 op.emitError("unable to replace all symbol uses for ") << name; 1604 SymbolTable::setSymbolName(op, nameAttr); 1605 op->removeAttr(kDescriptorSet); 1606 op->removeAttr(kBinding); 1607 } 1608 }); 1609 } 1610 } 1611