1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38   if (type.isSignedInteger())
39     return true;
40   if (auto vecType = type.dyn_cast<VectorType>())
41     return vecType.getElementType().isSignedInteger();
42   return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47   if (type.isUnsignedInteger())
48     return true;
49   if (auto vecType = type.dyn_cast<VectorType>())
50     return vecType.getElementType().isUnsignedInteger();
51   return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57          "bitwidth is not supported for this type");
58   if (type.isIntOrFloat())
59     return type.getIntOrFloatBitWidth();
60   auto vecType = type.dyn_cast<VectorType>();
61   auto elementType = vecType.getElementType();
62   assert(elementType.isIntOrFloat() &&
63          "only integers and floats have a bitwidth");
64   return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of LLVMType integer or vector.
68 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
69   auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
70   return (vectorType ? vectorType.getElementType() : type)
71       .cast<LLVM::LLVMIntegerType>()
72       .getBitWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77   if (auto vecType = type.dyn_cast<VectorType>()) {
78     auto integerType = vecType.getElementType().cast<IntegerType>();
79     return builder.getIntegerAttr(integerType, -1);
80   }
81   auto integerType = type.cast<IntegerType>();
82   return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87                                       PatternRewriter &rewriter) {
88   if (srcType.isa<VectorType>()) {
89     return rewriter.create<LLVM::ConstantOp>(
90         loc, dstType,
91         SplatElementsAttr::get(srcType.cast<ShapedType>(),
92                                minusOneIntegerAttribute(srcType, rewriter)));
93   }
94   return rewriter.create<LLVM::ConstantOp>(
95       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100                               PatternRewriter &rewriter, double value) {
101   if (auto vecType = srcType.dyn_cast<VectorType>()) {
102     auto floatType = vecType.getElementType().cast<FloatType>();
103     return rewriter.create<LLVM::ConstantOp>(
104         loc, dstType,
105         SplatElementsAttr::get(vecType,
106                                rewriter.getFloatAttr(floatType, value)));
107   }
108   auto floatType = srcType.cast<FloatType>();
109   return rewriter.create<LLVM::ConstantOp>(
110       loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 ///   - `BitFieldInsert`
115 ///   - `BitFieldSExtract`
116 ///   - `BitFieldUExtract`
117 /// Truncates or extends the value. If the bitwidth of the value is the same as
118 /// `dstType` bitwidth, the value remains unchanged.
119 static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
120                                         PatternRewriter &rewriter) {
121   auto srcType = value.getType();
122   auto llvmType = dstType.cast<LLVM::LLVMType>();
123   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124   unsigned valueBitWidth =
125       srcType.isa<LLVM::LLVMType>()
126           ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
127           : getBitWidth(srcType);
128 
129   if (valueBitWidth < targetBitWidth)
130     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
131   // If the bit widths of `Count` and `Offset` are greater than the bit width
132   // of the target type, they are truncated. Truncation is safe since `Count`
133   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
134   // both values can be expressed in 8 bits.
135   if (valueBitWidth > targetBitWidth)
136     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
137   return value;
138 }
139 
140 /// Broadcasts the value to vector with `numElements` number of elements.
141 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
142                        LLVMTypeConverter &typeConverter,
143                        ConversionPatternRewriter &rewriter) {
144   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
145   auto llvmVectorType = typeConverter.convertType(vectorType);
146   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
147   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
148   for (unsigned i = 0; i < numElements; ++i) {
149     auto index = rewriter.create<LLVM::ConstantOp>(
150         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
151     broadcasted = rewriter.create<LLVM::InsertElementOp>(
152         loc, llvmVectorType, broadcasted, toBroadcast, index);
153   }
154   return broadcasted;
155 }
156 
157 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
158 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
159                                  LLVMTypeConverter &typeConverter,
160                                  ConversionPatternRewriter &rewriter) {
161   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
162     unsigned numElements = vectorType.getNumElements();
163     return broadcast(loc, value, numElements, typeConverter, rewriter);
164   }
165   return value;
166 }
167 
168 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
169 /// `BitFieldUExtract`.
170 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
171 /// a vector type, construct a vector that has:
172 ///  - same number of elements as `Base`
173 ///  - each element has the type that is the same as the type of `Offset` or
174 ///    `Count`
175 ///  - each element has the same value as `Offset` or `Count`
176 /// Then cast `Offset` and `Count` if their bit width is different
177 /// from `Base` bit width.
178 static Value processCountOrOffset(Location loc, Value value, Type srcType,
179                                   Type dstType, LLVMTypeConverter &converter,
180                                   ConversionPatternRewriter &rewriter) {
181   Value broadcasted =
182       optionallyBroadcast(loc, value, srcType, converter, rewriter);
183   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
184 }
185 
186 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
187 /// offset to LLVM struct. Otherwise, the conversion is not supported.
188 static Optional<Type>
189 convertStructTypeWithOffset(spirv::StructType type,
190                             LLVMTypeConverter &converter) {
191   if (type != VulkanLayoutUtils::decorateType(type))
192     return llvm::None;
193 
194   auto elementsVector = llvm::to_vector<8>(
195       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
196         return converter.convertType(elementType).cast<LLVM::LLVMType>();
197       }));
198   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
199                                           /*isPacked=*/false);
200 }
201 
202 /// Converts SPIR-V struct with no offset to packed LLVM struct.
203 static Type convertStructTypePacked(spirv::StructType type,
204                                     LLVMTypeConverter &converter) {
205   auto elementsVector = llvm::to_vector<8>(
206       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
207         return converter.convertType(elementType).cast<LLVM::LLVMType>();
208       }));
209   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
210                                           /*isPacked=*/true);
211 }
212 
213 /// Creates LLVM dialect constant with the given value.
214 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
215                                  unsigned value) {
216   return rewriter.create<LLVM::ConstantOp>(
217       loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 32),
218       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
219 }
220 
221 /// Utility for `spv.Load` and `spv.Store` conversion.
222 static LogicalResult replaceWithLoadOrStore(Operation *op,
223                                             ConversionPatternRewriter &rewriter,
224                                             LLVMTypeConverter &typeConverter,
225                                             unsigned alignment, bool isVolatile,
226                                             bool isNonTemporal) {
227   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
228     auto dstType = typeConverter.convertType(loadOp.getType());
229     if (!dstType)
230       return failure();
231     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
232         loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
233     return success();
234   }
235   auto storeOp = cast<spirv::StoreOp>(op);
236   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
237                                              storeOp.ptr(), alignment,
238                                              isVolatile, isNonTemporal);
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Type conversion
244 //===----------------------------------------------------------------------===//
245 
246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
248 /// when converting ops that manipulate array types.
249 static Optional<Type> convertArrayType(spirv::ArrayType type,
250                                        TypeConverter &converter) {
251   unsigned stride = type.getArrayStride();
252   Type elementType = type.getElementType();
253   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
254   if (stride != 0 &&
255       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
256     return llvm::None;
257 
258   auto llvmElementType =
259       converter.convertType(elementType).cast<LLVM::LLVMType>();
260   unsigned numElements = type.getNumElements();
261   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
262 }
263 
264 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
265 /// modelled at the moment.
266 static Type convertPointerType(spirv::PointerType type,
267                                TypeConverter &converter) {
268   auto pointeeType =
269       converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
270   return LLVM::LLVMPointerType::get(pointeeType);
271 }
272 
273 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
274 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
275 /// no modelling of array stride at the moment.
276 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
277                                               TypeConverter &converter) {
278   if (type.getArrayStride() != 0)
279     return llvm::None;
280   auto elementType =
281       converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
282   return LLVM::LLVMArrayType::get(elementType, 0);
283 }
284 
285 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
286 /// member decorations. Also, only natural offset is supported.
287 static Optional<Type> convertStructType(spirv::StructType type,
288                                         LLVMTypeConverter &converter) {
289   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
290   type.getMemberDecorations(memberDecorations);
291   if (!memberDecorations.empty())
292     return llvm::None;
293   if (type.hasOffset())
294     return convertStructTypeWithOffset(type, converter);
295   return convertStructTypePacked(type, converter);
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Operation conversion
300 //===----------------------------------------------------------------------===//
301 
302 namespace {
303 
304 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
305 public:
306   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
307 
308   LogicalResult
309   matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
310                   ConversionPatternRewriter &rewriter) const override {
311     auto dstType = typeConverter.convertType(op.component_ptr().getType());
312     if (!dstType)
313       return failure();
314     // To use GEP we need to add a first 0 index to go through the pointer.
315     auto indices = llvm::to_vector<4>(op.indices());
316     Type indexType = op.indices().front().getType();
317     auto llvmIndexType = typeConverter.convertType(indexType);
318     if (!llvmIndexType)
319       return failure();
320     Value zero = rewriter.create<LLVM::ConstantOp>(
321         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
322     indices.insert(indices.begin(), zero);
323     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
324                                              indices);
325     return success();
326   }
327 };
328 
329 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
330 public:
331   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
332 
333   LogicalResult
334   matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
335                   ConversionPatternRewriter &rewriter) const override {
336     auto dstType = typeConverter.convertType(op.pointer().getType());
337     if (!dstType)
338       return failure();
339     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
340         op, dstType.cast<LLVM::LLVMType>(), op.variable());
341     return success();
342   }
343 };
344 
345 class BitFieldInsertPattern
346     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
347 public:
348   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
349 
350   LogicalResult
351   matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
352                   ConversionPatternRewriter &rewriter) const override {
353     auto srcType = op.getType();
354     auto dstType = typeConverter.convertType(srcType);
355     if (!dstType)
356       return failure();
357     Location loc = op.getLoc();
358 
359     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
360     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
361                                         typeConverter, rewriter);
362     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
363                                        typeConverter, rewriter);
364 
365     // Create a mask with bits set outside [Offset, Offset + Count - 1].
366     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
367     Value maskShiftedByCount =
368         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
369     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
370                                                  maskShiftedByCount, minusOne);
371     Value maskShiftedByCountAndOffset =
372         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
373     Value mask = rewriter.create<LLVM::XOrOp>(
374         loc, dstType, maskShiftedByCountAndOffset, minusOne);
375 
376     // Extract unchanged bits from the `Base`  that are outside of
377     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
378     Value baseAndMask =
379         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
380     Value insertShiftedByOffset =
381         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
382     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
383                                             insertShiftedByOffset);
384     return success();
385   }
386 };
387 
388 /// Converts SPIR-V ConstantOp with scalar or vector type.
389 class ConstantScalarAndVectorPattern
390     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
391 public:
392   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
393 
394   LogicalResult
395   matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
396                   ConversionPatternRewriter &rewriter) const override {
397     auto srcType = constOp.getType();
398     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
399       return failure();
400 
401     auto dstType = typeConverter.convertType(srcType);
402     if (!dstType)
403       return failure();
404 
405     // SPIR-V constant can be a signed/unsigned integer, which has to be
406     // casted to signless integer when converting to LLVM dialect. Removing the
407     // sign bit may have unexpected behaviour. However, it is better to handle
408     // it case-by-case, given that the purpose of the conversion is not to
409     // cover all possible corner cases.
410     if (isSignedIntegerOrVector(srcType) ||
411         isUnsignedIntegerOrVector(srcType)) {
412       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
413 
414       if (srcType.isa<VectorType>()) {
415         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
416         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
417             constOp, dstType,
418             dstElementsAttr.mapValues(
419                 signlessType, [&](const APInt &value) { return value; }));
420         return success();
421       }
422       auto srcAttr = constOp.value().cast<IntegerAttr>();
423       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
424       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
425       return success();
426     }
427     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
428                                                   constOp.getAttrs());
429     return success();
430   }
431 };
432 
/// Converts `spv.BitFieldSExtract`. The result holds bits
/// [Offset, Offset + Count - 1] of `Base`, moved down to bit 0 and
/// sign-extended from bit Count - 1. It is lowered as
///   ashr(shl(Base, size - (Count + Offset)), size - Count)
/// where `size` is the bit width of `Base` (of one element for vectors).
class BitFieldSExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();

    // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
    Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
                                        typeConverter, rewriter);
    Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
                                       typeConverter, rewriter);

    // Create a constant that holds the bit width of `Base` (splatted for
    // vectors). The attribute uses the scalar or element integer type of
    // `Base`.
    IntegerType integerType;
    if (auto vecType = srcType.dyn_cast<VectorType>())
      integerType = vecType.getElementType().cast<IntegerType>();
    else
      integerType = srcType.cast<IntegerType>();

    auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
    Value size =
        srcType.isa<VectorType>()
            ? rewriter.create<LLVM::ConstantOp>(
                  loc, dstType,
                  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
            : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);

    // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
    // at Offset + Count - 1 is the most significant bit now.
    Value countPlusOffset =
        rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
    Value amountToShiftLeft =
        rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
    Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
        loc, dstType, op.base(), amountToShiftLeft);

    // Shift the result right, filling the bits with the sign bit. The amount
    // is Offset + (size - (Count + Offset)) = size - Count, which brings the
    // field down to bit 0.
    // NOTE(review): when Count == 0 the right shift amount equals the bit
    // width, for which LLVM `ashr` returns poison. Confirm whether that input
    // can reach this pattern.
    Value amountToShiftRight =
        rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
    rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
                                              amountToShiftRight);
    return success();
  }
};
485 
486 class BitFieldUExtractPattern
487     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
488 public:
489   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
490 
491   LogicalResult
492   matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
493                   ConversionPatternRewriter &rewriter) const override {
494     auto srcType = op.getType();
495     auto dstType = typeConverter.convertType(srcType);
496     if (!dstType)
497       return failure();
498     Location loc = op.getLoc();
499 
500     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
501     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
502                                         typeConverter, rewriter);
503     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
504                                        typeConverter, rewriter);
505 
506     // Create a mask with bits set at [0, Count - 1].
507     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
508     Value maskShiftedByCount =
509         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
510     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
511                                               minusOne);
512 
513     // Shift `Base` by `Offset` and apply the mask on it.
514     Value shiftedBase =
515         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
516     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
517     return success();
518   }
519 };
520 
521 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
522 public:
523   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
524 
525   LogicalResult
526   matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
527                   ConversionPatternRewriter &rewriter) const override {
528     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
529                                             branchOp.getTarget());
530     return success();
531   }
532 };
533 
534 class BranchConditionalConversionPattern
535     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
536 public:
537   using SPIRVToLLVMConversion<
538       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
539 
540   LogicalResult
541   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
542                   ConversionPatternRewriter &rewriter) const override {
543     // If branch weights exist, map them to 32-bit integer vector.
544     ElementsAttr branchWeights = nullptr;
545     if (auto weights = op.branch_weights()) {
546       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
547       branchWeights =
548           DenseElementsAttr::get(weightType, weights.getValue().getValue());
549     }
550 
551     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
552         op, op.condition(), op.getTrueBlockArguments(),
553         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
554         op.getFalseBlock());
555     return success();
556   }
557 };
558 
559 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
560 /// is an aggregate type (struct or array). Otherwise, converts to
561 /// `llvm.extractelement` that operates on vectors.
562 class CompositeExtractPattern
563     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
564 public:
565   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
566 
567   LogicalResult
568   matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
569                   ConversionPatternRewriter &rewriter) const override {
570     auto dstType = this->typeConverter.convertType(op.getType());
571     if (!dstType)
572       return failure();
573 
574     Type containerType = op.composite().getType();
575     if (containerType.isa<VectorType>()) {
576       Location loc = op.getLoc();
577       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
578       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
579       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
580           op, dstType, op.composite(), index);
581       return success();
582     }
583     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
584         op, dstType, op.composite(), op.indices());
585     return success();
586   }
587 };
588 
589 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
590 /// is an aggregate type (struct or array). Otherwise, converts to
591 /// `llvm.insertelement` that operates on vectors.
592 class CompositeInsertPattern
593     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
594 public:
595   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
596 
597   LogicalResult
598   matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
599                   ConversionPatternRewriter &rewriter) const override {
600     auto dstType = this->typeConverter.convertType(op.getType());
601     if (!dstType)
602       return failure();
603 
604     Type containerType = op.composite().getType();
605     if (containerType.isa<VectorType>()) {
606       Location loc = op.getLoc();
607       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
608       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
609       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
610           op, dstType, op.composite(), op.object(), index);
611       return success();
612     }
613     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
614         op, dstType, op.composite(), op.object(), op.indices());
615     return success();
616   }
617 };
618 
619 /// Converts SPIR-V operations that have straightforward LLVM equivalent
620 /// into LLVM dialect operations.
621 template <typename SPIRVOp, typename LLVMOp>
622 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
623 public:
624   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
625 
626   LogicalResult
627   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
628                   ConversionPatternRewriter &rewriter) const override {
629     auto dstType = this->typeConverter.convertType(operation.getType());
630     if (!dstType)
631       return failure();
632     rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
633                                                  operation.getAttrs());
634     return success();
635   }
636 };
637 
/// Converts `spv.ExecutionMode` into a constant, externally linked
/// `llvm.mlir.global` struct that holds the execution mode information. The
/// global is inserted at the start of the enclosing builtin module, and the
/// original op is erased.
class ExecutionModePattern
    : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
public:
  using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // First, create the global struct's name that would be associated with
    // this entry point's execution mode. We set it to be:
    //   __spv__{SPIR-V module name}_{function name}_execution_mode_info
    // If the module has no name, this becomes
    //   __spv__{function name}_execution_mode_info
    ModuleOp module = op->getParentOfType<ModuleOp>();
    std::string moduleName;
    if (module.getName().hasValue())
      moduleName = "_" + module.getName().getValue().str();
    else
      moduleName = "";
    std::string executionModeInfoName = llvm::formatv(
        "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str());

    MLIRContext *context = rewriter.getContext();
    // Restore the rewriter's insertion point when this pattern returns.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());

    // Create a struct type, corresponding to the C struct below.
    // struct {
    //   int32_t executionMode;
    //   int32_t values[];          // optional values
    // };
    // The array field is only added when the op carries extra values, and
    // then it has exactly `values.size()` elements.
    auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32);
    SmallVector<LLVM::LLVMType, 2> fields;
    fields.push_back(llvmI32Type);
    ArrayAttr values = op.values();
    if (!values.empty()) {
      auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
      fields.push_back(arrayType);
    }
    auto structType = LLVM::LLVMStructType::getLiteral(context, fields);

    // Create `llvm.mlir.global` with initializer region containing one block.
    auto global = rewriter.create<LLVM::GlobalOp>(
        UnknownLoc::get(context), structType, /*isConstant=*/true,
        LLVM::Linkage::External, executionModeInfoName, Attribute());
    Location loc = global.getLoc();
    Region &region = global.getInitializerRegion();
    Block *block = rewriter.createBlock(&region);

    // Initialize the struct and set the execution mode value.
    rewriter.setInsertionPoint(block, block->begin());
    Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
    IntegerAttr executionModeAttr = op.execution_modeAttr();
    Value executionMode =
        rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
    // Field 0 holds the execution mode itself.
    structValue = rewriter.create<LLVM::InsertValueOp>(
        loc, structType, structValue, executionMode,
        ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
                       context));

    // Insert extra operands if they exist into execution mode info struct.
    // The i-th value goes to position [1, i], i.e. element i of the array
    // field.
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      auto attr = values.getValue()[i];
      Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
      structValue = rewriter.create<LLVM::InsertValueOp>(
          loc, structType, structValue, entry,
          ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
                          rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
                         context));
    }
    // The initializer region yields the fully built struct.
    rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
    rewriter.eraseOp(op);
    return success();
  }
};
713 
714 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
715 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
716 /// This difference is handled by `spv.mlir.addressof` and
717 /// `llvm.mlir.addressof`ops that both return a pointer.
718 class GlobalVariablePattern
719     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
720 public:
721   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
722 
723   LogicalResult
724   matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
725                   ConversionPatternRewriter &rewriter) const override {
726     // Currently, there is no support of initialization with a constant value in
727     // SPIR-V dialect. Specialization constants are not considered as well.
728     if (op.initializer())
729       return failure();
730 
731     auto srcType = op.type().cast<spirv::PointerType>();
732     auto dstType = typeConverter.convertType(srcType.getPointeeType());
733     if (!dstType)
734       return failure();
735 
736     // Limit conversion to the current invocation only or `StorageBuffer`
737     // required by SPIR-V runner.
738     // This is okay because multiple invocations are not supported yet.
739     auto storageClass = srcType.getStorageClass();
740     if (storageClass != spirv::StorageClass::Input &&
741         storageClass != spirv::StorageClass::Private &&
742         storageClass != spirv::StorageClass::Output &&
743         storageClass != spirv::StorageClass::StorageBuffer) {
744       return failure();
745     }
746 
747     // LLVM dialect spec: "If the global value is a constant, storing into it is
748     // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
749     // read-only.
750     bool isConstant = storageClass == spirv::StorageClass::Input;
751     // SPIR-V spec: "By default, functions and global variables are private to a
752     // module and cannot be accessed by other modules. However, a module may be
753     // written to export or import functions and global (module scope)
754     // variables.". Therefore, map 'Private' storage class to private linkage,
755     // 'Input' and 'Output' to external linkage.
756     auto linkage = storageClass == spirv::StorageClass::Private
757                        ? LLVM::Linkage::Private
758                        : LLVM::Linkage::External;
759     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
760         op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
761         Attribute());
762     return success();
763   }
764 };
765 
766 /// Converts SPIR-V cast ops that do not have straightforward LLVM
767 /// equivalent in LLVM dialect.
768 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
769 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
770 public:
771   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
772 
773   LogicalResult
774   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
775                   ConversionPatternRewriter &rewriter) const override {
776 
777     Type fromType = operation.operand().getType();
778     Type toType = operation.getType();
779 
780     auto dstType = this->typeConverter.convertType(toType);
781     if (!dstType)
782       return failure();
783 
784     if (getBitWidth(fromType) < getBitWidth(toType)) {
785       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
786                                                       operands);
787       return success();
788     }
789     if (getBitWidth(fromType) > getBitWidth(toType)) {
790       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
791                                                         operands);
792       return success();
793     }
794     return failure();
795   }
796 };
797 
798 class FunctionCallPattern
799     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
800 public:
801   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
802 
803   LogicalResult
804   matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
805                   ConversionPatternRewriter &rewriter) const override {
806     if (callOp.getNumResults() == 0) {
807       rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
808                                                 callOp.getAttrs());
809       return success();
810     }
811 
812     // Function returns a single result.
813     auto dstType = typeConverter.convertType(callOp.getType(0));
814     rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
815                                               callOp.getAttrs());
816     return success();
817   }
818 };
819 
820 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
821 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
822 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
823 public:
824   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
825 
826   LogicalResult
827   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
828                   ConversionPatternRewriter &rewriter) const override {
829 
830     auto dstType = this->typeConverter.convertType(operation.getType());
831     if (!dstType)
832       return failure();
833 
834     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
835         operation, dstType,
836         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
837         operation.operand1(), operation.operand2());
838     return success();
839   }
840 };
841 
842 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
843 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
844 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
845 public:
846   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
847 
848   LogicalResult
849   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
850                   ConversionPatternRewriter &rewriter) const override {
851 
852     auto dstType = this->typeConverter.convertType(operation.getType());
853     if (!dstType)
854       return failure();
855 
856     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
857         operation, dstType,
858         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
859         operation.operand1(), operation.operand2());
860     return success();
861   }
862 };
863 
864 class InverseSqrtPattern
865     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
866 public:
867   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
868 
869   LogicalResult
870   matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
871                   ConversionPatternRewriter &rewriter) const override {
872     auto srcType = op.getType();
873     auto dstType = typeConverter.convertType(srcType);
874     if (!dstType)
875       return failure();
876 
877     Location loc = op.getLoc();
878     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
879     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
880     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
881     return success();
882   }
883 };
884 
885 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
886 template <typename SPIRVop>
887 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
888 public:
889   using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
890 
891   LogicalResult
892   matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
893                   ConversionPatternRewriter &rewriter) const override {
894 
895     if (!op.memory_access().hasValue()) {
896       replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0,
897                              /*isVolatile=*/false, /*isNonTemporal=*/false);
898       return success();
899     }
900     auto memoryAccess = op.memory_access().getValue();
901     switch (memoryAccess) {
902     case spirv::MemoryAccess::Aligned:
903     case spirv::MemoryAccess::None:
904     case spirv::MemoryAccess::Nontemporal:
905     case spirv::MemoryAccess::Volatile: {
906       unsigned alignment =
907           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
908       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
909       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
910       replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment,
911                              isVolatile, isNonTemporal);
912       return success();
913     }
914     default:
915       // There is no support of other memory access attributes.
916       return failure();
917     }
918   }
919 };
920 
921 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
922 template <typename SPIRVOp>
923 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
924 public:
925   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
926 
927   LogicalResult
928   matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
929                   ConversionPatternRewriter &rewriter) const override {
930 
931     auto srcType = notOp.getType();
932     auto dstType = this->typeConverter.convertType(srcType);
933     if (!dstType)
934       return failure();
935 
936     Location loc = notOp.getLoc();
937     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
938     auto mask = srcType.template isa<VectorType>()
939                     ? rewriter.create<LLVM::ConstantOp>(
940                           loc, dstType,
941                           SplatElementsAttr::get(
942                               srcType.template cast<VectorType>(), minusOne))
943                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
944     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
945                                                       notOp.operand(), mask);
946     return success();
947   }
948 };
949 
950 /// A template pattern that erases the given `SPIRVOp`.
951 template <typename SPIRVOp>
952 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
953 public:
954   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
955 
956   LogicalResult
957   matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
958                   ConversionPatternRewriter &rewriter) const override {
959     rewriter.eraseOp(op);
960     return success();
961   }
962 };
963 
964 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
965 public:
966   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
967 
968   LogicalResult
969   matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
970                   ConversionPatternRewriter &rewriter) const override {
971     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
972                                                 ArrayRef<Value>());
973     return success();
974   }
975 };
976 
977 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
978 public:
979   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
980 
981   LogicalResult
982   matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
983                   ConversionPatternRewriter &rewriter) const override {
984     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
985                                                 operands);
986     return success();
987   }
988 };
989 
/// Converts `spv.loop` to LLVM dialect. All blocks within the loop should be
/// reachable for conversion to succeed. Loops whose loop control is not
/// `None` are left unconverted (the pattern fails).
/// The structure of the loop in LLVM dialect will be the following:
///
///      +------------------------------------+
///      | <code before spv.loop>             |
///      | llvm.br ^header                    |
///      +------------------------------------+
///                           |
///   +----------------+      |
///   |                |      |
///   |                V      V
///   |  +------------------------------------+
///   |  | ^header:                           |
///   |  |   <header code>                    |
///   |  |   llvm.cond_br %cond, ^body, ^exit |
///   |  +------------------------------------+
///   |                    |
///   |                    |----------------------+
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^body:                             |   |
///   |  |   <body code>                      |   |
///   |  |   llvm.br ^continue                |   |
///   |  +------------------------------------+   |
///   |                    |                      |
///   |                    V                      |
///   |  +------------------------------------+   |
///   |  | ^continue:                         |   |
///   |  |   <continue code>                  |   |
///   |  |   llvm.br ^header                  |   |
///   |  +------------------------------------+   |
///   |               |                           |
///   +---------------+    +----------------------+
///                        |
///                        V
///      +------------------------------------+
///      | ^exit:                             |
///      |   llvm.br ^remaining               |
///      +------------------------------------+
///                        |
///                        V
///      +------------------------------------+
///      | ^remaining:                        |
///      |   <code after spv.loop>            |
///      +------------------------------------+
///
class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
public:
  using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // There is no support of loop control at the moment.
    if (loopOp.loop_control() != spirv::LoopControl::None)
      return failure();

    Location loc = loopOp.getLoc();

    // Split the current block (assumed to be the block holding `spv.loop`) at
    // the loop op. The loop op and every op after it move into `endBlock`.
    // The loop op itself is erased by `replaceOp` at the end, so `endBlock`
    // ends up holding only the code after the loop.
    Block *currentBlock = rewriter.getBlock();
    auto position = Block::iterator(loopOp);
    Block *endBlock = rewriter.splitBlock(currentBlock, position);

    // Remove the entry block and create a branch from the current block to
    // the header block. The entry block is expected to hold a single
    // `spv.Branch`, whose block arguments are forwarded to the header.
    // NOTE(review): the failure() below happens after `splitBlock`; this
    // relies on the conversion rewriter rolling back the split — confirm.
    Block *entryBlock = loopOp.getEntryBlock();
    assert(entryBlock->getOperations().size() == 1);
    auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
    if (!brOp)
      return failure();
    Block *headerBlock = loopOp.getHeaderBlock();
    rewriter.setInsertionPointToEnd(currentBlock);
    rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
    rewriter.eraseBlock(entryBlock);

    // Branch from the merge block to the end block, forwarding the operands
    // of the merge block's terminator. NOTE(review): the original SPIR-V
    // terminator is not erased here and the new `llvm.br` is appended after
    // it; presumably another pattern removes the SPIR-V terminator — confirm.
    Block *mergeBlock = loopOp.getMergeBlock();
    Operation *terminator = mergeBlock->getTerminator();
    ValueRange terminatorOperands = terminator->getOperands();
    rewriter.setInsertionPointToEnd(mergeBlock);
    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);

    // Move the loop's remaining blocks in front of `endBlock`, then drop the
    // loop op.
    rewriter.inlineRegionBefore(loopOp.body(), endBlock);
    rewriter.replaceOp(loopOp, endBlock->getArguments());
    return success();
  }
};
1081 
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
/// All blocks within the selection should be reachable for conversion to
/// succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
public:
  using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // There is no support for `Flatten` or `DontFlatten` selection control at
    // the moment. These are just compiler hints and can be handled by the
    // optimization passes instead.
    if (op.selection_control() != spirv::SelectionControl::None)
      return failure();

    // `spv.selection` should have at least two blocks: one selection header
    // block and one merge block. If there are at most two blocks (none, or
    // only the header branching straight to the merge block), the op is
    // redundant and is simply erased.
    if (op.body().getBlocks().size() <= 2) {
      rewriter.eraseOp(op);
      return success();
    }

    Location loc = op.getLoc();

    // Split the current block after `spv.selection`. The ops that follow it
    // move into `continueBlock`; the selection op stays in `currentBlock`.
    auto *currentBlock = rewriter.getInsertionBlock();
    rewriter.setInsertionPointAfter(op);
    auto position = rewriter.getInsertionPoint();
    auto *continueBlock = rewriter.splitBlock(currentBlock, position);

    // Extract the conditional branch information from the header block. By
    // the SPIR-V dialect spec, it should contain a `spv.BranchConditional` or
    // `spv.Switch` op. `spv.Switch` is not yet supported in the SPIR-V
    // dialect, so only `spv.BranchConditional` is handled. The header block
    // is then erased.
    // NOTE(review): `condBrOp` lives in the erased header block and is still
    // read below; this relies on the conversion rewriter deferring block
    // erasure — confirm.
    auto *headerBlock = op.getHeaderBlock();
    assert(headerBlock->getOperations().size() == 1);
    auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
        headerBlock->getOperations().front());
    if (!condBrOp)
      return failure();
    rewriter.eraseBlock(headerBlock);

    // Branch from the merge block to the continue block, forwarding the
    // operands of the merge block's terminator.
    auto *mergeBlock = op.getMergeBlock();
    Operation *terminator = mergeBlock->getTerminator();
    ValueRange terminatorOperands = terminator->getOperands();
    rewriter.setInsertionPointToEnd(mergeBlock);
    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);

    // Link the current block to the `true` and `false` blocks within the
    // selection, reusing the condition and target operands of the header's
    // conditional branch.
    Block *trueBlock = condBrOp.getTrueBlock();
    Block *falseBlock = condBrOp.getFalseBlock();
    rewriter.setInsertionPointToEnd(currentBlock);
    rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
                                    condBrOp.trueTargetOperands(), falseBlock,
                                    condBrOp.falseTargetOperands());

    // Move the selection's remaining blocks in front of `continueBlock`, then
    // drop the selection op.
    rewriter.inlineRegionBefore(op.body(), continueBlock);
    rewriter.replaceOp(op, continueBlock->getArguments());
    return success();
  }
};
1147 
1148 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1149 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1150 /// `Shift` is zero or sign extended to match this specification. Cases when
1151 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1152 template <typename SPIRVOp, typename LLVMOp>
1153 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1154 public:
1155   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1156 
1157   LogicalResult
1158   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1159                   ConversionPatternRewriter &rewriter) const override {
1160 
1161     auto dstType = this->typeConverter.convertType(operation.getType());
1162     if (!dstType)
1163       return failure();
1164 
1165     Type op1Type = operation.operand1().getType();
1166     Type op2Type = operation.operand2().getType();
1167 
1168     if (op1Type == op2Type) {
1169       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1170                                                    operands);
1171       return success();
1172     }
1173 
1174     Location loc = operation.getLoc();
1175     Value extended;
1176     if (isUnsignedIntegerOrVector(op2Type)) {
1177       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1178                                                         operation.operand2());
1179     } else {
1180       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1181                                                         operation.operand2());
1182     }
1183     Value result = rewriter.template create<LLVMOp>(
1184         loc, dstType, operation.operand1(), extended);
1185     rewriter.replaceOp(operation, result);
1186     return success();
1187   }
1188 };
1189 
1190 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1191 public:
1192   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1193 
1194   LogicalResult
1195   matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1196                   ConversionPatternRewriter &rewriter) const override {
1197     auto dstType = typeConverter.convertType(tanOp.getType());
1198     if (!dstType)
1199       return failure();
1200 
1201     Location loc = tanOp.getLoc();
1202     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1203     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1204     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1205     return success();
1206   }
1207 };
1208 
1209 /// Convert `spv.Tanh` to
1210 ///
1211 ///   exp(2x) - 1
1212 ///   -----------
1213 ///   exp(2x) + 1
1214 ///
1215 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1216 public:
1217   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1218 
1219   LogicalResult
1220   matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1221                   ConversionPatternRewriter &rewriter) const override {
1222     auto srcType = tanhOp.getType();
1223     auto dstType = typeConverter.convertType(srcType);
1224     if (!dstType)
1225       return failure();
1226 
1227     Location loc = tanhOp.getLoc();
1228     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1229     Value multiplied =
1230         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1231     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1232     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1233     Value numerator =
1234         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1235     Value denominator =
1236         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1237     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1238                                               denominator);
1239     return success();
1240   }
1241 };
1242 
1243 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1244 public:
1245   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1246 
1247   LogicalResult
1248   matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1249                   ConversionPatternRewriter &rewriter) const override {
1250     auto srcType = varOp.getType();
1251     // Initialization is supported for scalars and vectors only.
1252     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1253     auto init = varOp.initializer();
1254     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1255       return failure();
1256 
1257     auto dstType = typeConverter.convertType(srcType);
1258     if (!dstType)
1259       return failure();
1260 
1261     Location loc = varOp.getLoc();
1262     Value size = createI32ConstantOf(loc, rewriter, 1);
1263     if (!init) {
1264       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1265       return success();
1266     }
1267     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1268     rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1269     rewriter.replaceOp(varOp, allocated);
1270     return success();
1271   }
1272 };
1273 
1274 //===----------------------------------------------------------------------===//
1275 // FuncOp conversion
1276 //===----------------------------------------------------------------------===//
1277 
1278 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1279 public:
1280   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1281 
1282   LogicalResult
1283   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1284                   ConversionPatternRewriter &rewriter) const override {
1285 
1286     // Convert function signature. At the moment LLVMType converter is enough
1287     // for currently supported types.
1288     auto funcType = funcOp.getType();
1289     TypeConverter::SignatureConversion signatureConverter(
1290         funcType.getNumInputs());
1291     auto llvmType = typeConverter.convertFunctionSignature(
1292         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1293     if (!llvmType)
1294       return failure();
1295 
1296     // Create a new `LLVMFuncOp`
1297     Location loc = funcOp.getLoc();
1298     StringRef name = funcOp.getName();
1299     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1300 
1301     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1302     MLIRContext *context = funcOp.getContext();
1303     switch (funcOp.function_control()) {
1304 #define DISPATCH(functionControl, llvmAttr)                                    \
1305   case functionControl:                                                        \
1306     newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));    \
1307     break;
1308 
1309       DISPATCH(spirv::FunctionControl::Inline,
1310                StringAttr::get("alwaysinline", context));
1311       DISPATCH(spirv::FunctionControl::DontInline,
1312                StringAttr::get("noinline", context));
1313       DISPATCH(spirv::FunctionControl::Pure,
1314                StringAttr::get("readonly", context));
1315       DISPATCH(spirv::FunctionControl::Const,
1316                StringAttr::get("readnone", context));
1317 
1318 #undef DISPATCH
1319 
1320     // Default: if `spirv::FunctionControl::None`, then no attributes are
1321     // needed.
1322     default:
1323       break;
1324     }
1325 
1326     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1327                                 newFuncOp.end());
1328     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1329                                            &signatureConverter))) {
1330       return failure();
1331     }
1332     rewriter.eraseOp(funcOp);
1333     return success();
1334   }
1335 };
1336 
1337 //===----------------------------------------------------------------------===//
1338 // ModuleOp conversion
1339 //===----------------------------------------------------------------------===//
1340 
1341 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1342 public:
1343   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1344 
1345   LogicalResult
1346   matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1347                   ConversionPatternRewriter &rewriter) const override {
1348 
1349     auto newModuleOp =
1350         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1351     rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1352 
1353     // Remove the terminator block that was automatically added by builder
1354     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1355     rewriter.eraseOp(spvModuleOp);
1356     return success();
1357   }
1358 };
1359 
1360 class ModuleEndConversionPattern
1361     : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1362 public:
1363   using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1364 
1365   LogicalResult
1366   matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1367                   ConversionPatternRewriter &rewriter) const override {
1368 
1369     rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1370     return success();
1371   }
1372 };
1373 
1374 } // namespace
1375 
1376 //===----------------------------------------------------------------------===//
1377 // Pattern population
1378 //===----------------------------------------------------------------------===//
1379 
1380 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1381   typeConverter.addConversion([&](spirv::ArrayType type) {
1382     return convertArrayType(type, typeConverter);
1383   });
1384   typeConverter.addConversion([&](spirv::PointerType type) {
1385     return convertPointerType(type, typeConverter);
1386   });
1387   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1388     return convertRuntimeArrayType(type, typeConverter);
1389   });
1390   typeConverter.addConversion([&](spirv::StructType type) {
1391     return convertStructType(type, typeConverter);
1392   });
1393 }
1394 
1395 void mlir::populateSPIRVToLLVMConversionPatterns(
1396     MLIRContext *context, LLVMTypeConverter &typeConverter,
1397     OwningRewritePatternList &patterns) {
1398   patterns.insert<
1399       // Arithmetic ops
1400       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1401       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1402       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1403       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1404       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1405       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1406       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1407       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1408       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1409       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1410       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1411       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1412       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1413 
1414       // Bitwise ops
1415       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1416       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1417       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1418       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1419       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1420       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1421       NotPattern<spirv::NotOp>,
1422 
1423       // Cast ops
1424       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1425       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1426       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1427       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1428       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1429       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1430       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1431       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1432 
1433       // Comparison ops
1434       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1435       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1436       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1437       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1438       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1439       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1440       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1441       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1442       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1443       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1444       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1445                       LLVM::FCmpPredicate::uge>,
1446       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1447       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1448       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1449       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1450       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1451       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1452       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1453       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1454       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1455       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1456       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1457 
1458       // Constant op
1459       ConstantScalarAndVectorPattern,
1460 
1461       // Control Flow ops
1462       BranchConversionPattern, BranchConditionalConversionPattern,
1463       FunctionCallPattern, LoopPattern, SelectionPattern,
1464       ErasePattern<spirv::MergeOp>,
1465 
1466       // Entry points and execution mode are handled separately.
1467       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1468 
1469       // GLSL extended instruction set ops
1470       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1471       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1472       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1473       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1474       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1475       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1476       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1477       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1478       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1479       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1480       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1481       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1482       InverseSqrtPattern, TanPattern, TanhPattern,
1483 
1484       // Logical ops
1485       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1486       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1487       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1488       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1489       NotPattern<spirv::LogicalNotOp>,
1490 
1491       // Memory ops
1492       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1493       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1494       VariablePattern,
1495 
1496       // Miscellaneous ops
1497       CompositeExtractPattern, CompositeInsertPattern,
1498       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1499       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1500 
1501       // Shift ops
1502       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1503       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1504       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1505 
1506       // Return ops
1507       ReturnPattern, ReturnValuePattern>(context, typeConverter);
1508 }
1509 
1510 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1511     MLIRContext *context, LLVMTypeConverter &typeConverter,
1512     OwningRewritePatternList &patterns) {
1513   patterns.insert<FuncConversionPattern>(context, typeConverter);
1514 }
1515 
1516 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1517     MLIRContext *context, LLVMTypeConverter &typeConverter,
1518     OwningRewritePatternList &patterns) {
1519   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1520       context, typeConverter);
1521 }
1522 
1523 //===----------------------------------------------------------------------===//
1524 // Pre-conversion hooks
1525 //===----------------------------------------------------------------------===//
1526 
1527 /// Hook for descriptor set and binding number encoding.
1528 static constexpr StringRef kBinding = "binding";
1529 static constexpr StringRef kDescriptorSet = "descriptor_set";
1530 void mlir::encodeBindAttribute(ModuleOp module) {
1531   auto spvModules = module.getOps<spirv::ModuleOp>();
1532   for (auto spvModule : spvModules) {
1533     spvModule.walk([&](spirv::GlobalVariableOp op) {
1534       IntegerAttr descriptorSet =
1535           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1536       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1537       // For every global variable in the module, get the ones with descriptor
1538       // set and binding numbers.
1539       if (descriptorSet && binding) {
1540         // Encode these numbers into the variable's symbolic name. If the
1541         // SPIR-V module has a name, add it at the beginning.
1542         auto moduleAndName = spvModule.getName().hasValue()
1543                                  ? spvModule.getName().getValue().str() + "_" +
1544                                        op.sym_name().str()
1545                                  : op.sym_name().str();
1546         std::string name =
1547             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1548                           std::to_string(descriptorSet.getInt()),
1549                           std::to_string(binding.getInt()));
1550 
1551         // Replace all symbol uses and set the new symbol name. Finally, remove
1552         // descriptor set and binding attributes.
1553         if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1554           op.emitError("unable to replace all symbol uses for ") << name;
1555         SymbolTable::setSymbolName(op, name);
1556         op.removeAttr(kDescriptorSet);
1557         op.removeAttr(kBinding);
1558       }
1559     });
1560   }
1561 }
1562