1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Support/LogicalResult.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 #define DEBUG_TYPE "spirv-to-llvm-pattern" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // Utility functions 34 //===----------------------------------------------------------------------===// 35 36 /// Returns true if the given type is a signed integer or vector type. 37 static bool isSignedIntegerOrVector(Type type) { 38 if (type.isSignedInteger()) 39 return true; 40 if (auto vecType = type.dyn_cast<VectorType>()) 41 return vecType.getElementType().isSignedInteger(); 42 return false; 43 } 44 45 /// Returns true if the given type is an unsigned integer or vector type 46 static bool isUnsignedIntegerOrVector(Type type) { 47 if (type.isUnsignedInteger()) 48 return true; 49 if (auto vecType = type.dyn_cast<VectorType>()) 50 return vecType.getElementType().isUnsignedInteger(); 51 return false; 52 } 53 54 /// Returns the bit width of integer, float or vector of float or integer values 55 static unsigned getBitWidth(Type type) { 56 assert((type.isIntOrFloat() || type.isa<VectorType>()) && 57 "bitwidth is not supported for this type"); 58 if (type.isIntOrFloat()) 59 return type.getIntOrFloatBitWidth(); 60 auto vecType = type.dyn_cast<VectorType>(); 61 auto elementType = vecType.getElementType(); 62 assert(elementType.isIntOrFloat() && 63 "only integers and floats have a bitwidth"); 64 return elementType.getIntOrFloatBitWidth(); 65 } 66 67 /// Returns the bit width of LLVMType integer or vector. 68 static unsigned getLLVMTypeBitWidth(Type type) { 69 return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type) 70 : type) 71 .cast<IntegerType>() 72 .getWidth(); 73 } 74 75 /// Creates `IntegerAttribute` with all bits set for given type 76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 77 if (auto vecType = type.dyn_cast<VectorType>()) { 78 auto integerType = vecType.getElementType().cast<IntegerType>(); 79 return builder.getIntegerAttr(integerType, -1); 80 } 81 auto integerType = type.cast<IntegerType>(); 82 return builder.getIntegerAttr(integerType, -1); 83 } 84 85 /// Creates `llvm.mlir.constant` with all bits set for the given type. 86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 87 PatternRewriter &rewriter) { 88 if (srcType.isa<VectorType>()) { 89 return rewriter.create<LLVM::ConstantOp>( 90 loc, dstType, 91 SplatElementsAttr::get(srcType.cast<ShapedType>(), 92 minusOneIntegerAttribute(srcType, rewriter))); 93 } 94 return rewriter.create<LLVM::ConstantOp>( 95 loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 96 } 97 98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 99 static Value createFPConstant(Location loc, Type srcType, Type dstType, 100 PatternRewriter &rewriter, double value) { 101 if (auto vecType = srcType.dyn_cast<VectorType>()) { 102 auto floatType = vecType.getElementType().cast<FloatType>(); 103 return rewriter.create<LLVM::ConstantOp>( 104 loc, dstType, 105 SplatElementsAttr::get(vecType, 106 rewriter.getFloatAttr(floatType, value))); 107 } 108 auto floatType = srcType.cast<FloatType>(); 109 return rewriter.create<LLVM::ConstantOp>( 110 loc, dstType, rewriter.getFloatAttr(floatType, value)); 111 } 112 113 /// Utility function for bitfield ops: 114 /// - `BitFieldInsert` 115 /// - `BitFieldSExtract` 116 /// - `BitFieldUExtract` 117 /// Truncates or extends the value. If the bitwidth of the value is the same as 118 /// `llvmType` bitwidth, the value remains unchanged. 119 static Value optionallyTruncateOrExtend(Location loc, Value value, 120 Type llvmType, 121 PatternRewriter &rewriter) { 122 auto srcType = value.getType(); 123 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 124 unsigned valueBitWidth = LLVM::isCompatibleType(srcType) 125 ? getLLVMTypeBitWidth(srcType) 126 : getBitWidth(srcType); 127 128 if (valueBitWidth < targetBitWidth) 129 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 130 // If the bit widths of `Count` and `Offset` are greater than the bit width 131 // of the target type, they are truncated. Truncation is safe since `Count` 132 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 133 // both values can be expressed in 8 bits. 134 if (valueBitWidth > targetBitWidth) 135 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 136 return value; 137 } 138 139 /// Broadcasts the value to vector with `numElements` number of elements. 140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 141 LLVMTypeConverter &typeConverter, 142 ConversionPatternRewriter &rewriter) { 143 auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 144 auto llvmVectorType = typeConverter.convertType(vectorType); 145 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 146 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 147 for (unsigned i = 0; i < numElements; ++i) { 148 auto index = rewriter.create<LLVM::ConstantOp>( 149 loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 150 broadcasted = rewriter.create<LLVM::InsertElementOp>( 151 loc, llvmVectorType, broadcasted, toBroadcast, index); 152 } 153 return broadcasted; 154 } 155 156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 157 static Value optionallyBroadcast(Location loc, Value value, Type srcType, 158 LLVMTypeConverter &typeConverter, 159 ConversionPatternRewriter &rewriter) { 160 if (auto vectorType = srcType.dyn_cast<VectorType>()) { 161 unsigned numElements = vectorType.getNumElements(); 162 return broadcast(loc, value, numElements, typeConverter, rewriter); 163 } 164 return value; 165 } 166 167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 168 /// `BitFieldUExtract`. 169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 170 /// a vector type, construct a vector that has: 171 /// - same number of elements as `Base` 172 /// - each element has the type that is the same as the type of `Offset` or 173 /// `Count` 174 /// - each element has the same value as `Offset` or `Count` 175 /// Then cast `Offset` and `Count` if their bit width is different 176 /// from `Base` bit width. 177 static Value processCountOrOffset(Location loc, Value value, Type srcType, 178 Type dstType, LLVMTypeConverter &converter, 179 ConversionPatternRewriter &rewriter) { 180 Value broadcasted = 181 optionallyBroadcast(loc, value, srcType, converter, rewriter); 182 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 183 } 184 185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 186 /// offset to LLVM struct. Otherwise, the conversion is not supported. 187 static Optional<Type> 188 convertStructTypeWithOffset(spirv::StructType type, 189 LLVMTypeConverter &converter) { 190 if (type != VulkanLayoutUtils::decorateType(type)) 191 return llvm::None; 192 193 auto elementsVector = llvm::to_vector<8>( 194 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 195 return converter.convertType(elementType); 196 })); 197 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 198 /*isPacked=*/false); 199 } 200 201 /// Converts SPIR-V struct with no offset to packed LLVM struct. 202 static Type convertStructTypePacked(spirv::StructType type, 203 LLVMTypeConverter &converter) { 204 auto elementsVector = llvm::to_vector<8>( 205 llvm::map_range(type.getElementTypes(), [&](Type elementType) { 206 return converter.convertType(elementType); 207 })); 208 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 209 /*isPacked=*/true); 210 } 211 212 /// Creates LLVM dialect constant with the given value. 213 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 214 unsigned value) { 215 return rewriter.create<LLVM::ConstantOp>( 216 loc, IntegerType::get(rewriter.getContext(), 32), 217 rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 218 } 219 220 /// Utility for `spv.Load` and `spv.Store` conversion. 221 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, 222 ConversionPatternRewriter &rewriter, 223 LLVMTypeConverter &typeConverter, 224 unsigned alignment, bool isVolatile, 225 bool isNonTemporal) { 226 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 227 auto dstType = typeConverter.convertType(loadOp.getType()); 228 if (!dstType) 229 return failure(); 230 rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 231 loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, 232 isVolatile, isNonTemporal); 233 return success(); 234 } 235 auto storeOp = cast<spirv::StoreOp>(op); 236 spirv::StoreOpAdaptor adaptor(operands); 237 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(), 238 adaptor.ptr(), alignment, 239 isVolatile, isNonTemporal); 240 return success(); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Type conversion 245 //===----------------------------------------------------------------------===// 246 247 /// Converts SPIR-V array type to LLVM array. Natural stride (according to 248 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 249 /// when converting ops that manipulate array types. 250 static Optional<Type> convertArrayType(spirv::ArrayType type, 251 TypeConverter &converter) { 252 unsigned stride = type.getArrayStride(); 253 Type elementType = type.getElementType(); 254 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes(); 255 if (stride != 0 && 256 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) 257 return llvm::None; 258 259 auto llvmElementType = converter.convertType(elementType); 260 unsigned numElements = type.getNumElements(); 261 return LLVM::LLVMArrayType::get(llvmElementType, numElements); 262 } 263 264 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 265 /// modelled at the moment. 266 static Type convertPointerType(spirv::PointerType type, 267 TypeConverter &converter) { 268 auto pointeeType = converter.convertType(type.getPointeeType()); 269 return LLVM::LLVMPointerType::get(pointeeType); 270 } 271 272 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 273 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 274 /// no modelling of array stride at the moment. 275 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 276 TypeConverter &converter) { 277 if (type.getArrayStride() != 0) 278 return llvm::None; 279 auto elementType = converter.convertType(type.getElementType()); 280 return LLVM::LLVMArrayType::get(elementType, 0); 281 } 282 283 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 284 /// member decorations. Also, only natural offset is supported. 285 static Optional<Type> convertStructType(spirv::StructType type, 286 LLVMTypeConverter &converter) { 287 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 288 type.getMemberDecorations(memberDecorations); 289 if (!memberDecorations.empty()) 290 return llvm::None; 291 if (type.hasOffset()) 292 return convertStructTypeWithOffset(type, converter); 293 return convertStructTypePacked(type, converter); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // Operation conversion 298 //===----------------------------------------------------------------------===// 299 300 namespace { 301 302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 303 public: 304 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 305 306 LogicalResult 307 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, 308 ConversionPatternRewriter &rewriter) const override { 309 auto dstType = typeConverter.convertType(op.component_ptr().getType()); 310 if (!dstType) 311 return failure(); 312 // To use GEP we need to add a first 0 index to go through the pointer. 313 auto indices = llvm::to_vector<4>(adaptor.indices()); 314 Type indexType = op.indices().front().getType(); 315 auto llvmIndexType = typeConverter.convertType(indexType); 316 if (!llvmIndexType) 317 return failure(); 318 Value zero = rewriter.create<LLVM::ConstantOp>( 319 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 320 indices.insert(indices.begin(), zero); 321 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(), 322 indices); 323 return success(); 324 } 325 }; 326 327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 328 public: 329 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 330 331 LogicalResult 332 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, 333 ConversionPatternRewriter &rewriter) const override { 334 auto dstType = typeConverter.convertType(op.pointer().getType()); 335 if (!dstType) 336 return failure(); 337 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable()); 338 return success(); 339 } 340 }; 341 342 class BitFieldInsertPattern 343 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 344 public: 345 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 346 347 LogicalResult 348 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, 349 ConversionPatternRewriter &rewriter) const override { 350 auto srcType = op.getType(); 351 auto dstType = typeConverter.convertType(srcType); 352 if (!dstType) 353 return failure(); 354 Location loc = op.getLoc(); 355 356 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 357 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 358 typeConverter, rewriter); 359 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 360 typeConverter, rewriter); 361 362 // Create a mask with bits set outside [Offset, Offset + Count - 1]. 363 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 364 Value maskShiftedByCount = 365 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 366 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 367 maskShiftedByCount, minusOne); 368 Value maskShiftedByCountAndOffset = 369 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 370 Value mask = rewriter.create<LLVM::XOrOp>( 371 loc, dstType, maskShiftedByCountAndOffset, minusOne); 372 373 // Extract unchanged bits from the `Base` that are outside of 374 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 375 Value baseAndMask = 376 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask); 377 Value insertShiftedByOffset = 378 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset); 379 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 380 insertShiftedByOffset); 381 return success(); 382 } 383 }; 384 385 /// Converts SPIR-V ConstantOp with scalar or vector type. 386 class ConstantScalarAndVectorPattern 387 : public SPIRVToLLVMConversion<spirv::ConstantOp> { 388 public: 389 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 390 391 LogicalResult 392 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, 393 ConversionPatternRewriter &rewriter) const override { 394 auto srcType = constOp.getType(); 395 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat()) 396 return failure(); 397 398 auto dstType = typeConverter.convertType(srcType); 399 if (!dstType) 400 return failure(); 401 402 // SPIR-V constant can be a signed/unsigned integer, which has to be 403 // casted to signless integer when converting to LLVM dialect. Removing the 404 // sign bit may have unexpected behaviour. However, it is better to handle 405 // it case-by-case, given that the purpose of the conversion is not to 406 // cover all possible corner cases. 407 if (isSignedIntegerOrVector(srcType) || 408 isUnsignedIntegerOrVector(srcType)) { 409 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 410 411 if (srcType.isa<VectorType>()) { 412 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); 413 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 414 constOp, dstType, 415 dstElementsAttr.mapValues( 416 signlessType, [&](const APInt &value) { return value; })); 417 return success(); 418 } 419 auto srcAttr = constOp.value().cast<IntegerAttr>(); 420 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 421 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 422 return success(); 423 } 424 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 425 constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); 426 return success(); 427 } 428 }; 429 430 class BitFieldSExtractPattern 431 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 432 public: 433 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 434 435 LogicalResult 436 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, 437 ConversionPatternRewriter &rewriter) const override { 438 auto srcType = op.getType(); 439 auto dstType = typeConverter.convertType(srcType); 440 if (!dstType) 441 return failure(); 442 Location loc = op.getLoc(); 443 444 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 445 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 446 typeConverter, rewriter); 447 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 448 typeConverter, rewriter); 449 450 // Create a constant that holds the size of the `Base`. 451 IntegerType integerType; 452 if (auto vecType = srcType.dyn_cast<VectorType>()) 453 integerType = vecType.getElementType().cast<IntegerType>(); 454 else 455 integerType = srcType.cast<IntegerType>(); 456 457 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 458 Value size = 459 srcType.isa<VectorType>() 460 ? rewriter.create<LLVM::ConstantOp>( 461 loc, dstType, 462 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize)) 463 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 464 465 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 466 // at Offset + Count - 1 is the most significant bit now. 467 Value countPlusOffset = 468 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 469 Value amountToShiftLeft = 470 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 471 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 472 loc, dstType, op.base(), amountToShiftLeft); 473 474 // Shift the result right, filling the bits with the sign bit. 475 Value amountToShiftRight = 476 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 477 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 478 amountToShiftRight); 479 return success(); 480 } 481 }; 482 483 class BitFieldUExtractPattern 484 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 485 public: 486 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 487 488 LogicalResult 489 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, 490 ConversionPatternRewriter &rewriter) const override { 491 auto srcType = op.getType(); 492 auto dstType = typeConverter.convertType(srcType); 493 if (!dstType) 494 return failure(); 495 Location loc = op.getLoc(); 496 497 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 498 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, 499 typeConverter, rewriter); 500 Value count = processCountOrOffset(loc, op.count(), srcType, dstType, 501 typeConverter, rewriter); 502 503 // Create a mask with bits set at [0, Count - 1]. 504 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 505 Value maskShiftedByCount = 506 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 507 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 508 minusOne); 509 510 // Shift `Base` by `Offset` and apply the mask on it. 511 Value shiftedBase = 512 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset); 513 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 514 return success(); 515 } 516 }; 517 518 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 519 public: 520 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 521 522 LogicalResult 523 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, 524 ConversionPatternRewriter &rewriter) const override { 525 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), 526 branchOp.getTarget()); 527 return success(); 528 } 529 }; 530 531 class BranchConditionalConversionPattern 532 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 533 public: 534 using SPIRVToLLVMConversion< 535 spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 536 537 LogicalResult 538 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, 539 ConversionPatternRewriter &rewriter) const override { 540 // If branch weights exist, map them to 32-bit integer vector. 541 ElementsAttr branchWeights = nullptr; 542 if (auto weights = op.branch_weights()) { 543 VectorType weightType = VectorType::get(2, rewriter.getI32Type()); 544 branchWeights = 545 DenseElementsAttr::get(weightType, weights.getValue().getValue()); 546 } 547 548 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 549 op, op.condition(), op.getTrueBlockArguments(), 550 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 551 op.getFalseBlock()); 552 return success(); 553 } 554 }; 555 556 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type 557 /// is an aggregate type (struct or array). Otherwise, converts to 558 /// `llvm.extractelement` that operates on vectors. 559 class CompositeExtractPattern 560 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 561 public: 562 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 563 564 LogicalResult 565 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, 566 ConversionPatternRewriter &rewriter) const override { 567 auto dstType = this->typeConverter.convertType(op.getType()); 568 if (!dstType) 569 return failure(); 570 571 Type containerType = op.composite().getType(); 572 if (containerType.isa<VectorType>()) { 573 Location loc = op.getLoc(); 574 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 575 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 576 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 577 op, dstType, adaptor.composite(), index); 578 return success(); 579 } 580 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 581 op, dstType, adaptor.composite(), op.indices()); 582 return success(); 583 } 584 }; 585 586 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type 587 /// is an aggregate type (struct or array). Otherwise, converts to 588 /// `llvm.insertelement` that operates on vectors. 589 class CompositeInsertPattern 590 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 591 public: 592 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 593 594 LogicalResult 595 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, 596 ConversionPatternRewriter &rewriter) const override { 597 auto dstType = this->typeConverter.convertType(op.getType()); 598 if (!dstType) 599 return failure(); 600 601 Type containerType = op.composite().getType(); 602 if (containerType.isa<VectorType>()) { 603 Location loc = op.getLoc(); 604 IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); 605 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 606 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 607 op, dstType, adaptor.composite(), adaptor.object(), index); 608 return success(); 609 } 610 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 611 op, dstType, adaptor.composite(), adaptor.object(), op.indices()); 612 return success(); 613 } 614 }; 615 616 /// Converts SPIR-V operations that have straightforward LLVM equivalent 617 /// into LLVM dialect operations. 618 template <typename SPIRVOp, typename LLVMOp> 619 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 620 public: 621 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 622 623 LogicalResult 624 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 625 ConversionPatternRewriter &rewriter) const override { 626 auto dstType = this->typeConverter.convertType(operation.getType()); 627 if (!dstType) 628 return failure(); 629 rewriter.template replaceOpWithNewOp<LLVMOp>( 630 operation, dstType, adaptor.getOperands(), operation->getAttrs()); 631 return success(); 632 } 633 }; 634 635 /// Converts `spv.ExecutionMode` into a global struct constant that holds 636 /// execution mode information. 637 class ExecutionModePattern 638 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 639 public: 640 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 641 642 LogicalResult 643 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, 644 ConversionPatternRewriter &rewriter) const override { 645 // First, create the global struct's name that would be associated with 646 // this entry point's execution mode. We set it to be: 647 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} 648 ModuleOp module = op->getParentOfType<ModuleOp>(); 649 IntegerAttr executionModeAttr = op.execution_modeAttr(); 650 std::string moduleName; 651 if (module.getName().hasValue()) 652 moduleName = "_" + module.getName().getValue().str(); 653 else 654 moduleName = ""; 655 std::string executionModeInfoName = 656 llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName, 657 op.fn().str(), executionModeAttr.getValue()); 658 659 MLIRContext *context = rewriter.getContext(); 660 OpBuilder::InsertionGuard guard(rewriter); 661 rewriter.setInsertionPointToStart(module.getBody()); 662 663 // Create a struct type, corresponding to the C struct below. 664 // struct { 665 // int32_t executionMode; 666 // int32_t values[]; // optional values 667 // }; 668 auto llvmI32Type = IntegerType::get(context, 32); 669 SmallVector<Type, 2> fields; 670 fields.push_back(llvmI32Type); 671 ArrayAttr values = op.values(); 672 if (!values.empty()) { 673 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 674 fields.push_back(arrayType); 675 } 676 auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 677 678 // Create `llvm.mlir.global` with initializer region containing one block. 679 auto global = rewriter.create<LLVM::GlobalOp>( 680 UnknownLoc::get(context), structType, /*isConstant=*/true, 681 LLVM::Linkage::External, executionModeInfoName, Attribute(), 682 /*alignment=*/0); 683 Location loc = global.getLoc(); 684 Region ®ion = global.getInitializerRegion(); 685 Block *block = rewriter.createBlock(®ion); 686 687 // Initialize the struct and set the execution mode value. 688 rewriter.setInsertionPoint(block, block->begin()); 689 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 690 Value executionMode = 691 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr); 692 structValue = rewriter.create<LLVM::InsertValueOp>( 693 loc, structType, structValue, executionMode, 694 ArrayAttr::get(context, 695 {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)})); 696 697 // Insert extra operands if they exist into execution mode info struct. 698 for (unsigned i = 0, e = values.size(); i < e; ++i) { 699 auto attr = values.getValue()[i]; 700 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 701 structValue = rewriter.create<LLVM::InsertValueOp>( 702 loc, structType, structValue, entry, 703 ArrayAttr::get(context, 704 {rewriter.getIntegerAttr(rewriter.getI32Type(), 1), 705 rewriter.getIntegerAttr(rewriter.getI32Type(), i)})); 706 } 707 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 708 rewriter.eraseOp(op); 709 return success(); 710 } 711 }; 712 713 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global 714 /// returns a pointer, whereas in LLVM dialect the global holds an actual value. 715 /// This difference is handled by `spv.mlir.addressof` and 716 /// `llvm.mlir.addressof`ops that both return a pointer. 717 class GlobalVariablePattern 718 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 719 public: 720 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion; 721 722 LogicalResult 723 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, 724 ConversionPatternRewriter &rewriter) const override { 725 // Currently, there is no support of initialization with a constant value in 726 // SPIR-V dialect. Specialization constants are not considered as well. 727 if (op.initializer()) 728 return failure(); 729 730 auto srcType = op.type().cast<spirv::PointerType>(); 731 auto dstType = typeConverter.convertType(srcType.getPointeeType()); 732 if (!dstType) 733 return failure(); 734 735 // Limit conversion to the current invocation only or `StorageBuffer` 736 // required by SPIR-V runner. 737 // This is okay because multiple invocations are not supported yet. 738 auto storageClass = srcType.getStorageClass(); 739 switch (storageClass) { 740 case spirv::StorageClass::Input: 741 case spirv::StorageClass::Private: 742 case spirv::StorageClass::Output: 743 case spirv::StorageClass::StorageBuffer: 744 case spirv::StorageClass::UniformConstant: 745 break; 746 default: 747 return failure(); 748 } 749 750 // LLVM dialect spec: "If the global value is a constant, storing into it is 751 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' 752 // storage class that is read-only. 753 bool isConstant = (storageClass == spirv::StorageClass::Input) || 754 (storageClass == spirv::StorageClass::UniformConstant); 755 // SPIR-V spec: "By default, functions and global variables are private to a 756 // module and cannot be accessed by other modules. However, a module may be 757 // written to export or import functions and global (module scope) 758 // variables.". Therefore, map 'Private' storage class to private linkage, 759 // 'Input' and 'Output' to external linkage. 760 auto linkage = storageClass == spirv::StorageClass::Private 761 ? LLVM::Linkage::Private 762 : LLVM::Linkage::External; 763 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 764 op, dstType, isConstant, linkage, op.sym_name(), Attribute(), 765 /*alignment=*/0); 766 767 // Attach location attribute if applicable 768 if (op.locationAttr()) 769 newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr()); 770 771 return success(); 772 } 773 }; 774 775 /// Converts SPIR-V cast ops that do not have straightforward LLVM 776 /// equivalent in LLVM dialect. 777 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 778 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 779 public: 780 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 781 782 LogicalResult 783 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 784 ConversionPatternRewriter &rewriter) const override { 785 786 Type fromType = operation.operand().getType(); 787 Type toType = operation.getType(); 788 789 auto dstType = this->typeConverter.convertType(toType); 790 if (!dstType) 791 return failure(); 792 793 if (getBitWidth(fromType) < getBitWidth(toType)) { 794 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType, 795 adaptor.getOperands()); 796 return success(); 797 } 798 if (getBitWidth(fromType) > getBitWidth(toType)) { 799 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType, 800 adaptor.getOperands()); 801 return success(); 802 } 803 return failure(); 804 } 805 }; 806 807 class FunctionCallPattern 808 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 809 public: 810 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 811 812 LogicalResult 813 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, 814 ConversionPatternRewriter &rewriter) const override { 815 if (callOp.getNumResults() == 0) { 816 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 817 callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs()); 818 return success(); 819 } 820 821 // Function returns a single result. 822 auto dstType = typeConverter.convertType(callOp.getType(0)); 823 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 824 callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); 825 return success(); 826 } 827 }; 828 829 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 830 template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 831 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 832 public: 833 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 834 835 LogicalResult 836 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 837 ConversionPatternRewriter &rewriter) const override { 838 839 auto dstType = this->typeConverter.convertType(operation.getType()); 840 if (!dstType) 841 return failure(); 842 843 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 844 operation, dstType, predicate, operation.operand1(), 845 operation.operand2()); 846 return success(); 847 } 848 }; 849 850 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 851 template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 852 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 853 public: 854 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 855 856 LogicalResult 857 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 858 ConversionPatternRewriter &rewriter) const override { 859 860 auto dstType = this->typeConverter.convertType(operation.getType()); 861 if (!dstType) 862 return failure(); 863 864 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 865 operation, dstType, predicate, operation.operand1(), 866 operation.operand2()); 867 return success(); 868 } 869 }; 870 871 class InverseSqrtPattern 872 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> { 873 public: 874 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion; 875 876 LogicalResult 877 matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor, 878 ConversionPatternRewriter &rewriter) const override { 879 auto srcType = op.getType(); 880 auto dstType = typeConverter.convertType(srcType); 881 if (!dstType) 882 return failure(); 883 884 Location loc = op.getLoc(); 885 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 886 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); 887 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 888 return success(); 889 } 890 }; 891 892 /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 893 template <typename SPIRVOp> 894 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { 895 public: 896 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 897 898 LogicalResult 899 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 900 ConversionPatternRewriter &rewriter) const override { 901 if (!op.memory_access().hasValue()) { 902 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 903 this->typeConverter, /*alignment=*/0, 904 /*isVolatile=*/false, 905 /*isNonTemporal=*/false); 906 } 907 auto memoryAccess = op.memory_access().getValue(); 908 switch (memoryAccess) { 909 case spirv::MemoryAccess::Aligned: 910 case spirv::MemoryAccess::None: 911 case spirv::MemoryAccess::Nontemporal: 912 case spirv::MemoryAccess::Volatile: { 913 unsigned alignment = 914 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; 915 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 916 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 917 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 918 this->typeConverter, alignment, isVolatile, 919 isNonTemporal); 920 } 921 default: 922 // There is no support of other memory access attributes. 923 return failure(); 924 } 925 } 926 }; 927 928 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. 929 template <typename SPIRVOp> 930 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 931 public: 932 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 933 934 LogicalResult 935 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, 936 ConversionPatternRewriter &rewriter) const override { 937 auto srcType = notOp.getType(); 938 auto dstType = this->typeConverter.convertType(srcType); 939 if (!dstType) 940 return failure(); 941 942 Location loc = notOp.getLoc(); 943 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 944 auto mask = srcType.template isa<VectorType>() 945 ? rewriter.create<LLVM::ConstantOp>( 946 loc, dstType, 947 SplatElementsAttr::get( 948 srcType.template cast<VectorType>(), minusOne)) 949 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 950 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 951 notOp.operand(), mask); 952 return success(); 953 } 954 }; 955 956 /// A template pattern that erases the given `SPIRVOp`. 957 template <typename SPIRVOp> 958 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 959 public: 960 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 961 962 LogicalResult 963 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 964 ConversionPatternRewriter &rewriter) const override { 965 rewriter.eraseOp(op); 966 return success(); 967 } 968 }; 969 970 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 971 public: 972 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 973 974 LogicalResult 975 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, 976 ConversionPatternRewriter &rewriter) const override { 977 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 978 ArrayRef<Value>()); 979 return success(); 980 } 981 }; 982 983 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 984 public: 985 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 986 987 LogicalResult 988 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, 989 ConversionPatternRewriter &rewriter) const override { 990 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 991 adaptor.getOperands()); 992 return success(); 993 } 994 }; 995 996 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should 997 /// be reachable for conversion to succeed. The structure of the loop in LLVM 998 /// dialect will be the following: 999 /// 1000 /// +------------------------------------+ 1001 /// | <code before spv.mlir.loop> | 1002 /// | llvm.br ^header | 1003 /// +------------------------------------+ 1004 /// | 1005 /// +----------------+ | 1006 /// | | | 1007 /// | V V 1008 /// | +------------------------------------+ 1009 /// | | ^header: | 1010 /// | | <header code> | 1011 /// | | llvm.cond_br %cond, ^body, ^exit | 1012 /// | +------------------------------------+ 1013 /// | | 1014 /// | |----------------------+ 1015 /// | | | 1016 /// | V | 1017 /// | +------------------------------------+ | 1018 /// | | ^body: | | 1019 /// | | <body code> | | 1020 /// | | llvm.br ^continue | | 1021 /// | +------------------------------------+ | 1022 /// | | | 1023 /// | V | 1024 /// | +------------------------------------+ | 1025 /// | | ^continue: | | 1026 /// | | <continue code> | | 1027 /// | | llvm.br ^header | | 1028 /// | +------------------------------------+ | 1029 /// | | | 1030 /// +---------------+ +----------------------+ 1031 /// | 1032 /// V 1033 /// +------------------------------------+ 1034 /// | ^exit: | 1035 /// | llvm.br ^remaining | 1036 /// +------------------------------------+ 1037 /// | 1038 /// V 1039 /// +------------------------------------+ 1040 /// | ^remaining: | 1041 /// | <code after spv.mlir.loop> | 1042 /// +------------------------------------+ 1043 /// 1044 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1045 public: 1046 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1047 1048 LogicalResult 1049 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, 1050 ConversionPatternRewriter &rewriter) const override { 1051 // There is no support of loop control at the moment. 1052 if (loopOp.loop_control() != spirv::LoopControl::None) 1053 return failure(); 1054 1055 Location loc = loopOp.getLoc(); 1056 1057 // Split the current block after `spv.mlir.loop`. The remaining ops will be 1058 // used in `endBlock`. 1059 Block *currentBlock = rewriter.getBlock(); 1060 auto position = Block::iterator(loopOp); 1061 Block *endBlock = rewriter.splitBlock(currentBlock, position); 1062 1063 // Remove entry block and create a branch in the current block going to the 1064 // header block. 1065 Block *entryBlock = loopOp.getEntryBlock(); 1066 assert(entryBlock->getOperations().size() == 1); 1067 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1068 if (!brOp) 1069 return failure(); 1070 Block *headerBlock = loopOp.getHeaderBlock(); 1071 rewriter.setInsertionPointToEnd(currentBlock); 1072 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1073 rewriter.eraseBlock(entryBlock); 1074 1075 // Branch from merge block to end block. 1076 Block *mergeBlock = loopOp.getMergeBlock(); 1077 Operation *terminator = mergeBlock->getTerminator(); 1078 ValueRange terminatorOperands = terminator->getOperands(); 1079 rewriter.setInsertionPointToEnd(mergeBlock); 1080 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1081 1082 rewriter.inlineRegionBefore(loopOp.body(), endBlock); 1083 rewriter.replaceOp(loopOp, endBlock->getArguments()); 1084 return success(); 1085 } 1086 }; 1087 1088 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header 1089 /// block. All blocks within selection should be reachable for conversion to 1090 /// succeed. 1091 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1092 public: 1093 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1094 1095 LogicalResult 1096 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, 1097 ConversionPatternRewriter &rewriter) const override { 1098 // There is no support for `Flatten` or `DontFlatten` selection control at 1099 // the moment. This are just compiler hints and can be performed during the 1100 // optimization passes. 1101 if (op.selection_control() != spirv::SelectionControl::None) 1102 return failure(); 1103 1104 // `spv.mlir.selection` should have at least two blocks: one selection 1105 // header block and one merge block. If no blocks are present, or control 1106 // flow branches straight to merge block (two blocks are present), the op is 1107 // redundant and it is erased. 1108 if (op.body().getBlocks().size() <= 2) { 1109 rewriter.eraseOp(op); 1110 return success(); 1111 } 1112 1113 Location loc = op.getLoc(); 1114 1115 // Split the current block after `spv.mlir.selection`. The remaining ops 1116 // will be used in `continueBlock`. 1117 auto *currentBlock = rewriter.getInsertionBlock(); 1118 rewriter.setInsertionPointAfter(op); 1119 auto position = rewriter.getInsertionPoint(); 1120 auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1121 1122 // Extract conditional branch information from the header block. By SPIR-V 1123 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` 1124 // op. Note that `spv.Switch op` is not supported at the moment in the 1125 // SPIR-V dialect. Remove this block when finished. 1126 auto *headerBlock = op.getHeaderBlock(); 1127 assert(headerBlock->getOperations().size() == 1); 1128 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1129 headerBlock->getOperations().front()); 1130 if (!condBrOp) 1131 return failure(); 1132 rewriter.eraseBlock(headerBlock); 1133 1134 // Branch from merge block to continue block. 1135 auto *mergeBlock = op.getMergeBlock(); 1136 Operation *terminator = mergeBlock->getTerminator(); 1137 ValueRange terminatorOperands = terminator->getOperands(); 1138 rewriter.setInsertionPointToEnd(mergeBlock); 1139 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1140 1141 // Link current block to `true` and `false` blocks within the selection. 1142 Block *trueBlock = condBrOp.getTrueBlock(); 1143 Block *falseBlock = condBrOp.getFalseBlock(); 1144 rewriter.setInsertionPointToEnd(currentBlock); 1145 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock, 1146 condBrOp.trueTargetOperands(), falseBlock, 1147 condBrOp.falseTargetOperands()); 1148 1149 rewriter.inlineRegionBefore(op.body(), continueBlock); 1150 rewriter.replaceOp(op, continueBlock->getArguments()); 1151 return success(); 1152 } 1153 }; 1154 1155 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1156 /// puts a restriction on `Shift` and `Base` to have the same bit width, 1157 /// `Shift` is zero or sign extended to match this specification. Cases when 1158 /// `Shift` bit width > `Base` bit width are considered to be illegal. 1159 template <typename SPIRVOp, typename LLVMOp> 1160 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1161 public: 1162 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1163 1164 LogicalResult 1165 matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, 1166 ConversionPatternRewriter &rewriter) const override { 1167 1168 auto dstType = this->typeConverter.convertType(operation.getType()); 1169 if (!dstType) 1170 return failure(); 1171 1172 Type op1Type = operation.operand1().getType(); 1173 Type op2Type = operation.operand2().getType(); 1174 1175 if (op1Type == op2Type) { 1176 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, 1177 adaptor.getOperands()); 1178 return success(); 1179 } 1180 1181 Location loc = operation.getLoc(); 1182 Value extended; 1183 if (isUnsignedIntegerOrVector(op2Type)) { 1184 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType, 1185 adaptor.operand2()); 1186 } else { 1187 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType, 1188 adaptor.operand2()); 1189 } 1190 Value result = rewriter.template create<LLVMOp>( 1191 loc, dstType, adaptor.operand1(), extended); 1192 rewriter.replaceOp(operation, result); 1193 return success(); 1194 } 1195 }; 1196 1197 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> { 1198 public: 1199 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion; 1200 1201 LogicalResult 1202 matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor, 1203 ConversionPatternRewriter &rewriter) const override { 1204 auto dstType = typeConverter.convertType(tanOp.getType()); 1205 if (!dstType) 1206 return failure(); 1207 1208 Location loc = tanOp.getLoc(); 1209 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); 1210 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); 1211 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1212 return success(); 1213 } 1214 }; 1215 1216 /// Convert `spv.Tanh` to 1217 /// 1218 /// exp(2x) - 1 1219 /// ----------- 1220 /// exp(2x) + 1 1221 /// 1222 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> { 1223 public: 1224 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion; 1225 1226 LogicalResult 1227 matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor, 1228 ConversionPatternRewriter &rewriter) const override { 1229 auto srcType = tanhOp.getType(); 1230 auto dstType = typeConverter.convertType(srcType); 1231 if (!dstType) 1232 return failure(); 1233 1234 Location loc = tanhOp.getLoc(); 1235 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1236 Value multiplied = 1237 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); 1238 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1239 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1240 Value numerator = 1241 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1242 Value denominator = 1243 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1244 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1245 denominator); 1246 return success(); 1247 } 1248 }; 1249 1250 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1251 public: 1252 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1253 1254 LogicalResult 1255 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, 1256 ConversionPatternRewriter &rewriter) const override { 1257 auto srcType = varOp.getType(); 1258 // Initialization is supported for scalars and vectors only. 1259 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType(); 1260 auto init = varOp.initializer(); 1261 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>()) 1262 return failure(); 1263 1264 auto dstType = typeConverter.convertType(srcType); 1265 if (!dstType) 1266 return failure(); 1267 1268 Location loc = varOp.getLoc(); 1269 Value size = createI32ConstantOf(loc, rewriter, 1); 1270 if (!init) { 1271 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size); 1272 return success(); 1273 } 1274 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size); 1275 rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated); 1276 rewriter.replaceOp(varOp, allocated); 1277 return success(); 1278 } 1279 }; 1280 1281 //===----------------------------------------------------------------------===// 1282 // FuncOp conversion 1283 //===----------------------------------------------------------------------===// 1284 1285 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1286 public: 1287 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1288 1289 LogicalResult 1290 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 1291 ConversionPatternRewriter &rewriter) const override { 1292 1293 // Convert function signature. At the moment LLVMType converter is enough 1294 // for currently supported types. 1295 auto funcType = funcOp.getType(); 1296 TypeConverter::SignatureConversion signatureConverter( 1297 funcType.getNumInputs()); 1298 auto llvmType = typeConverter.convertFunctionSignature( 1299 funcOp.getType(), /*isVariadic=*/false, signatureConverter); 1300 if (!llvmType) 1301 return failure(); 1302 1303 // Create a new `LLVMFuncOp` 1304 Location loc = funcOp.getLoc(); 1305 StringRef name = funcOp.getName(); 1306 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1307 1308 // Convert SPIR-V Function Control to equivalent LLVM function attribute 1309 MLIRContext *context = funcOp.getContext(); 1310 switch (funcOp.function_control()) { 1311 #define DISPATCH(functionControl, llvmAttr) \ 1312 case functionControl: \ 1313 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ 1314 break; 1315 1316 DISPATCH(spirv::FunctionControl::Inline, 1317 StringAttr::get(context, "alwaysinline")); 1318 DISPATCH(spirv::FunctionControl::DontInline, 1319 StringAttr::get(context, "noinline")); 1320 DISPATCH(spirv::FunctionControl::Pure, 1321 StringAttr::get(context, "readonly")); 1322 DISPATCH(spirv::FunctionControl::Const, 1323 StringAttr::get(context, "readnone")); 1324 1325 #undef DISPATCH 1326 1327 // Default: if `spirv::FunctionControl::None`, then no attributes are 1328 // needed. 1329 default: 1330 break; 1331 } 1332 1333 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1334 newFuncOp.end()); 1335 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, 1336 &signatureConverter))) { 1337 return failure(); 1338 } 1339 rewriter.eraseOp(funcOp); 1340 return success(); 1341 } 1342 }; 1343 1344 //===----------------------------------------------------------------------===// 1345 // ModuleOp conversion 1346 //===----------------------------------------------------------------------===// 1347 1348 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1349 public: 1350 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1351 1352 LogicalResult 1353 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, 1354 ConversionPatternRewriter &rewriter) const override { 1355 1356 auto newModuleOp = 1357 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 1358 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); 1359 1360 // Remove the terminator block that was automatically added by builder 1361 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1362 rewriter.eraseOp(spvModuleOp); 1363 return success(); 1364 } 1365 }; 1366 1367 //===----------------------------------------------------------------------===// 1368 // VectorShuffleOp conversion 1369 //===----------------------------------------------------------------------===// 1370 1371 class VectorShufflePattern 1372 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { 1373 public: 1374 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; 1375 LogicalResult 1376 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, 1377 ConversionPatternRewriter &rewriter) const override { 1378 Location loc = op.getLoc(); 1379 auto components = adaptor.components(); 1380 auto vector1 = adaptor.vector1(); 1381 auto vector2 = adaptor.vector2(); 1382 int vector1Size = vector1.getType().cast<VectorType>().getNumElements(); 1383 int vector2Size = vector2.getType().cast<VectorType>().getNumElements(); 1384 if (vector1Size == vector2Size) { 1385 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2, 1386 components); 1387 return success(); 1388 } 1389 1390 auto dstType = typeConverter.convertType(op.getType()); 1391 auto scalarType = dstType.cast<VectorType>().getElementType(); 1392 auto componentsArray = components.getValue(); 1393 auto context = rewriter.getContext(); 1394 auto llvmI32Type = IntegerType::get(context, 32); 1395 Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); 1396 for (unsigned i = 0; i < componentsArray.size(); i++) { 1397 if (componentsArray[i].isa<IntegerAttr>()) 1398 op.emitError("unable to support non-constant component"); 1399 1400 int indexVal = componentsArray[i].cast<IntegerAttr>().getInt(); 1401 if (indexVal == -1) 1402 continue; 1403 1404 int offsetVal = 0; 1405 Value baseVector = vector1; 1406 if (indexVal >= vector1Size) { 1407 offsetVal = vector1Size; 1408 baseVector = vector2; 1409 } 1410 1411 Value dstIndex = rewriter.create<LLVM::ConstantOp>( 1412 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); 1413 Value index = rewriter.create<LLVM::ConstantOp>( 1414 loc, llvmI32Type, 1415 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); 1416 1417 auto extractOp = rewriter.create<LLVM::ExtractElementOp>( 1418 loc, scalarType, baseVector, index); 1419 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, 1420 extractOp, dstIndex); 1421 } 1422 rewriter.replaceOp(op, targetOp); 1423 return success(); 1424 } 1425 }; 1426 } // namespace 1427 1428 //===----------------------------------------------------------------------===// 1429 // Pattern population 1430 //===----------------------------------------------------------------------===// 1431 1432 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) { 1433 typeConverter.addConversion([&](spirv::ArrayType type) { 1434 return convertArrayType(type, typeConverter); 1435 }); 1436 typeConverter.addConversion([&](spirv::PointerType type) { 1437 return convertPointerType(type, typeConverter); 1438 }); 1439 typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1440 return convertRuntimeArrayType(type, typeConverter); 1441 }); 1442 typeConverter.addConversion([&](spirv::StructType type) { 1443 return convertStructType(type, typeConverter); 1444 }); 1445 } 1446 1447 void mlir::populateSPIRVToLLVMConversionPatterns( 1448 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1449 patterns.add< 1450 // Arithmetic ops 1451 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1452 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1453 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1454 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1455 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1456 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1457 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1458 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1459 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1460 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1461 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1462 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1463 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1464 1465 // Bitwise ops 1466 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1467 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1468 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1469 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1470 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1471 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1472 NotPattern<spirv::NotOp>, 1473 1474 // Cast ops 1475 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>, 1476 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1477 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1478 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1479 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1480 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1481 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1482 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1483 1484 // Comparison ops 1485 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1486 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1487 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1488 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1489 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1490 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1491 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1492 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1493 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1494 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1495 FComparePattern<spirv::FUnordGreaterThanEqualOp, 1496 LLVM::FCmpPredicate::uge>, 1497 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1498 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1499 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1500 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1501 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1502 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1503 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1504 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1505 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1506 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1507 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1508 1509 // Constant op 1510 ConstantScalarAndVectorPattern, 1511 1512 // Control Flow ops 1513 BranchConversionPattern, BranchConditionalConversionPattern, 1514 FunctionCallPattern, LoopPattern, SelectionPattern, 1515 ErasePattern<spirv::MergeOp>, 1516 1517 // Entry points and execution mode are handled separately. 1518 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1519 1520 // GLSL extended instruction set ops 1521 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, 1522 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, 1523 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, 1524 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, 1525 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>, 1526 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>, 1527 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>, 1528 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, 1529 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, 1530 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>, 1531 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>, 1532 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, 1533 InverseSqrtPattern, TanPattern, TanhPattern, 1534 1535 // Logical ops 1536 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1537 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1538 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1539 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1540 NotPattern<spirv::LogicalNotOp>, 1541 1542 // Memory ops 1543 AccessChainPattern, AddressOfPattern, GlobalVariablePattern, 1544 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>, 1545 VariablePattern, 1546 1547 // Miscellaneous ops 1548 CompositeExtractPattern, CompositeInsertPattern, 1549 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1550 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 1551 VectorShufflePattern, 1552 1553 // Shift ops 1554 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1555 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1556 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1557 1558 // Return ops 1559 ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter); 1560 } 1561 1562 void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1563 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1564 patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter); 1565 } 1566 1567 void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1568 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1569 patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter); 1570 } 1571 1572 //===----------------------------------------------------------------------===// 1573 // Pre-conversion hooks 1574 //===----------------------------------------------------------------------===// 1575 1576 /// Hook for descriptor set and binding number encoding. 1577 static constexpr StringRef kBinding = "binding"; 1578 static constexpr StringRef kDescriptorSet = "descriptor_set"; 1579 void mlir::encodeBindAttribute(ModuleOp module) { 1580 auto spvModules = module.getOps<spirv::ModuleOp>(); 1581 for (auto spvModule : spvModules) { 1582 spvModule.walk([&](spirv::GlobalVariableOp op) { 1583 IntegerAttr descriptorSet = 1584 op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1585 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1586 // For every global variable in the module, get the ones with descriptor 1587 // set and binding numbers. 1588 if (descriptorSet && binding) { 1589 // Encode these numbers into the variable's symbolic name. If the 1590 // SPIR-V module has a name, add it at the beginning. 1591 auto moduleAndName = spvModule.getName().hasValue() 1592 ? spvModule.getName().getValue().str() + "_" + 1593 op.sym_name().str() 1594 : op.sym_name().str(); 1595 std::string name = 1596 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 1597 std::to_string(descriptorSet.getInt()), 1598 std::to_string(binding.getInt())); 1599 auto nameAttr = StringAttr::get(op->getContext(), name); 1600 1601 // Replace all symbol uses and set the new symbol name. Finally, remove 1602 // descriptor set and binding attributes. 1603 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) 1604 op.emitError("unable to replace all symbol uses for ") << name; 1605 SymbolTable::setSymbolName(op, nameAttr); 1606 op->removeAttr(kDescriptorSet); 1607 op->removeAttr(kBinding); 1608 } 1609 }); 1610 } 1611 } 1612