1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38   if (type.isSignedInteger())
39     return true;
40   if (auto vecType = type.dyn_cast<VectorType>())
41     return vecType.getElementType().isSignedInteger();
42   return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47   if (type.isUnsignedInteger())
48     return true;
49   if (auto vecType = type.dyn_cast<VectorType>())
50     return vecType.getElementType().isUnsignedInteger();
51   return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57          "bitwidth is not supported for this type");
58   if (type.isIntOrFloat())
59     return type.getIntOrFloatBitWidth();
60   auto vecType = type.dyn_cast<VectorType>();
61   auto elementType = vecType.getElementType();
62   assert(elementType.isIntOrFloat() &&
63          "only integers and floats have a bitwidth");
64   return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of LLVMType integer or vector.
68 static unsigned getLLVMTypeBitWidth(Type type) {
69   return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
70                                              : type)
71       .cast<IntegerType>()
72       .getWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77   if (auto vecType = type.dyn_cast<VectorType>()) {
78     auto integerType = vecType.getElementType().cast<IntegerType>();
79     return builder.getIntegerAttr(integerType, -1);
80   }
81   auto integerType = type.cast<IntegerType>();
82   return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87                                       PatternRewriter &rewriter) {
88   if (srcType.isa<VectorType>()) {
89     return rewriter.create<LLVM::ConstantOp>(
90         loc, dstType,
91         SplatElementsAttr::get(srcType.cast<ShapedType>(),
92                                minusOneIntegerAttribute(srcType, rewriter)));
93   }
94   return rewriter.create<LLVM::ConstantOp>(
95       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100                               PatternRewriter &rewriter, double value) {
101   if (auto vecType = srcType.dyn_cast<VectorType>()) {
102     auto floatType = vecType.getElementType().cast<FloatType>();
103     return rewriter.create<LLVM::ConstantOp>(
104         loc, dstType,
105         SplatElementsAttr::get(vecType,
106                                rewriter.getFloatAttr(floatType, value)));
107   }
108   auto floatType = srcType.cast<FloatType>();
109   return rewriter.create<LLVM::ConstantOp>(
110       loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 ///   - `BitFieldInsert`
115 ///   - `BitFieldSExtract`
116 ///   - `BitFieldUExtract`
117 /// Truncates or extends the value. If the bitwidth of the value is the same as
118 /// `llvmType` bitwidth, the value remains unchanged.
119 static Value optionallyTruncateOrExtend(Location loc, Value value,
120                                         Type llvmType,
121                                         PatternRewriter &rewriter) {
122   auto srcType = value.getType();
123   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125                                ? getLLVMTypeBitWidth(srcType)
126                                : getBitWidth(srcType);
127 
128   if (valueBitWidth < targetBitWidth)
129     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130   // If the bit widths of `Count` and `Offset` are greater than the bit width
131   // of the target type, they are truncated. Truncation is safe since `Count`
132   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133   // both values can be expressed in 8 bits.
134   if (valueBitWidth > targetBitWidth)
135     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136   return value;
137 }
138 
139 /// Broadcasts the value to vector with `numElements` number of elements.
140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141                        LLVMTypeConverter &typeConverter,
142                        ConversionPatternRewriter &rewriter) {
143   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144   auto llvmVectorType = typeConverter.convertType(vectorType);
145   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147   for (unsigned i = 0; i < numElements; ++i) {
148     auto index = rewriter.create<LLVM::ConstantOp>(
149         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150     broadcasted = rewriter.create<LLVM::InsertElementOp>(
151         loc, llvmVectorType, broadcasted, toBroadcast, index);
152   }
153   return broadcasted;
154 }
155 
156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
157 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
158                                  LLVMTypeConverter &typeConverter,
159                                  ConversionPatternRewriter &rewriter) {
160   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161     unsigned numElements = vectorType.getNumElements();
162     return broadcast(loc, value, numElements, typeConverter, rewriter);
163   }
164   return value;
165 }
166 
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, construct a vector that has:
171 ///  - same number of elements as `Base`
172 ///  - each element has the type that is the same as the type of `Offset` or
173 ///    `Count`
174 ///  - each element has the same value as `Offset` or `Count`
175 /// Then cast `Offset` and `Count` if their bit width is different
176 /// from `Base` bit width.
177 static Value processCountOrOffset(Location loc, Value value, Type srcType,
178                                   Type dstType, LLVMTypeConverter &converter,
179                                   ConversionPatternRewriter &rewriter) {
180   Value broadcasted =
181       optionallyBroadcast(loc, value, srcType, converter, rewriter);
182   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184 
185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
186 /// offset to LLVM struct. Otherwise, the conversion is not supported.
187 static Optional<Type>
188 convertStructTypeWithOffset(spirv::StructType type,
189                             LLVMTypeConverter &converter) {
190   if (type != VulkanLayoutUtils::decorateType(type))
191     return llvm::None;
192 
193   auto elementsVector = llvm::to_vector<8>(
194       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195         return converter.convertType(elementType);
196       }));
197   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198                                           /*isPacked=*/false);
199 }
200 
201 /// Converts SPIR-V struct with no offset to packed LLVM struct.
202 static Type convertStructTypePacked(spirv::StructType type,
203                                     LLVMTypeConverter &converter) {
204   auto elementsVector = llvm::to_vector<8>(
205       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206         return converter.convertType(elementType);
207       }));
208   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209                                           /*isPacked=*/true);
210 }
211 
212 /// Creates LLVM dialect constant with the given value.
213 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
214                                  unsigned value) {
215   return rewriter.create<LLVM::ConstantOp>(
216       loc, IntegerType::get(rewriter.getContext(), 32),
217       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219 
220 /// Utility for `spv.Load` and `spv.Store` conversion.
221 static LogicalResult replaceWithLoadOrStore(Operation *op,
222                                             ConversionPatternRewriter &rewriter,
223                                             LLVMTypeConverter &typeConverter,
224                                             unsigned alignment, bool isVolatile,
225                                             bool isNonTemporal) {
226   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227     auto dstType = typeConverter.convertType(loadOp.getType());
228     if (!dstType)
229       return failure();
230     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231         loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
232     return success();
233   }
234   auto storeOp = cast<spirv::StoreOp>(op);
235   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
236                                              storeOp.ptr(), alignment,
237                                              isVolatile, isNonTemporal);
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Type conversion
243 //===----------------------------------------------------------------------===//
244 
245 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
246 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
247 /// when converting ops that manipulate array types.
248 static Optional<Type> convertArrayType(spirv::ArrayType type,
249                                        TypeConverter &converter) {
250   unsigned stride = type.getArrayStride();
251   Type elementType = type.getElementType();
252   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
253   if (stride != 0 &&
254       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
255     return llvm::None;
256 
257   auto llvmElementType = converter.convertType(elementType);
258   unsigned numElements = type.getNumElements();
259   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
260 }
261 
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
264 static Type convertPointerType(spirv::PointerType type,
265                                TypeConverter &converter) {
266   auto pointeeType = converter.convertType(type.getPointeeType());
267   return LLVM::LLVMPointerType::get(pointeeType);
268 }
269 
270 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
271 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
272 /// no modelling of array stride at the moment.
273 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
274                                               TypeConverter &converter) {
275   if (type.getArrayStride() != 0)
276     return llvm::None;
277   auto elementType = converter.convertType(type.getElementType());
278   return LLVM::LLVMArrayType::get(elementType, 0);
279 }
280 
281 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
282 /// member decorations. Also, only natural offset is supported.
283 static Optional<Type> convertStructType(spirv::StructType type,
284                                         LLVMTypeConverter &converter) {
285   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
286   type.getMemberDecorations(memberDecorations);
287   if (!memberDecorations.empty())
288     return llvm::None;
289   if (type.hasOffset())
290     return convertStructTypeWithOffset(type, converter);
291   return convertStructTypePacked(type, converter);
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Operation conversion
296 //===----------------------------------------------------------------------===//
297 
298 namespace {
299 
300 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
301 public:
302   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
303 
304   LogicalResult
305   matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
306                   ConversionPatternRewriter &rewriter) const override {
307     auto dstType = typeConverter.convertType(op.component_ptr().getType());
308     if (!dstType)
309       return failure();
310     // To use GEP we need to add a first 0 index to go through the pointer.
311     auto indices = llvm::to_vector<4>(op.indices());
312     Type indexType = op.indices().front().getType();
313     auto llvmIndexType = typeConverter.convertType(indexType);
314     if (!llvmIndexType)
315       return failure();
316     Value zero = rewriter.create<LLVM::ConstantOp>(
317         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
318     indices.insert(indices.begin(), zero);
319     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
320                                              indices);
321     return success();
322   }
323 };
324 
325 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
326 public:
327   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
328 
329   LogicalResult
330   matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto dstType = typeConverter.convertType(op.pointer().getType());
333     if (!dstType)
334       return failure();
335     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
336     return success();
337   }
338 };
339 
340 class BitFieldInsertPattern
341     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
342 public:
343   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
344 
345   LogicalResult
346   matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
347                   ConversionPatternRewriter &rewriter) const override {
348     auto srcType = op.getType();
349     auto dstType = typeConverter.convertType(srcType);
350     if (!dstType)
351       return failure();
352     Location loc = op.getLoc();
353 
354     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
355     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
356                                         typeConverter, rewriter);
357     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
358                                        typeConverter, rewriter);
359 
360     // Create a mask with bits set outside [Offset, Offset + Count - 1].
361     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
362     Value maskShiftedByCount =
363         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
364     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
365                                                  maskShiftedByCount, minusOne);
366     Value maskShiftedByCountAndOffset =
367         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
368     Value mask = rewriter.create<LLVM::XOrOp>(
369         loc, dstType, maskShiftedByCountAndOffset, minusOne);
370 
371     // Extract unchanged bits from the `Base`  that are outside of
372     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
373     Value baseAndMask =
374         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
375     Value insertShiftedByOffset =
376         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
377     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
378                                             insertShiftedByOffset);
379     return success();
380   }
381 };
382 
383 /// Converts SPIR-V ConstantOp with scalar or vector type.
384 class ConstantScalarAndVectorPattern
385     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
386 public:
387   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
388 
389   LogicalResult
390   matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
391                   ConversionPatternRewriter &rewriter) const override {
392     auto srcType = constOp.getType();
393     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
394       return failure();
395 
396     auto dstType = typeConverter.convertType(srcType);
397     if (!dstType)
398       return failure();
399 
400     // SPIR-V constant can be a signed/unsigned integer, which has to be
401     // casted to signless integer when converting to LLVM dialect. Removing the
402     // sign bit may have unexpected behaviour. However, it is better to handle
403     // it case-by-case, given that the purpose of the conversion is not to
404     // cover all possible corner cases.
405     if (isSignedIntegerOrVector(srcType) ||
406         isUnsignedIntegerOrVector(srcType)) {
407       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
408 
409       if (srcType.isa<VectorType>()) {
410         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
411         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
412             constOp, dstType,
413             dstElementsAttr.mapValues(
414                 signlessType, [&](const APInt &value) { return value; }));
415         return success();
416       }
417       auto srcAttr = constOp.value().cast<IntegerAttr>();
418       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
419       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
420       return success();
421     }
422     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
423                                                   constOp.getAttrs());
424     return success();
425   }
426 };
427 
/// Converts `spv.BitFieldSExtract` to LLVM dialect bit operations. `Base` is
/// first shifted left so that bit `Offset + Count - 1` becomes the most
/// significant bit. It is then arithmetically shifted right so that the
/// extracted field lands at bit 0 and is sign-extended.
///
/// NOTE(review): when `Count` is 0, the final `ashr` shifts by the full bit
/// width (Offset + (size - (Count + Offset)) == size). LLVM IR defines that as
/// poison, whereas SPIR-V specifies a result of 0. Confirm whether this edge
/// case needs explicit handling.
class BitFieldSExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();

    // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
    Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
                                        typeConverter, rewriter);
    Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
                                       typeConverter, rewriter);

    // Create a constant that holds the size of the `Base`.
    IntegerType integerType;
    if (auto vecType = srcType.dyn_cast<VectorType>())
      integerType = vecType.getElementType().cast<IntegerType>();
    else
      integerType = srcType.cast<IntegerType>();

    auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
    // For a vector `Base`, the element bit width is splatted across all lanes.
    Value size =
        srcType.isa<VectorType>()
            ? rewriter.create<LLVM::ConstantOp>(
                  loc, dstType,
                  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
            : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);

    // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
    // at Offset + Count - 1 is the most significant bit now.
    Value countPlusOffset =
        rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
    Value amountToShiftLeft =
        rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
    Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
        loc, dstType, op.base(), amountToShiftLeft);

    // Shift the result right, filling the bits with the sign bit. The shift
    // amount Offset + (size - (Count + Offset)) simplifies to size - Count, so
    // the field ends up in the low `Count` bits.
    Value amountToShiftRight =
        rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
    rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
                                              amountToShiftRight);
    return success();
  }
};
480 
/// Converts `spv.BitFieldUExtract` to LLVM dialect bit operations. `Base` is
/// logically shifted right by `Offset` and then AND-ed with a mask whose low
/// `Count` bits are set, so the higher bits of the result are zero.
///
/// NOTE(review): when `Count` equals the bit width of `Base`,
/// `minusOne << Count` shifts by the full width, which LLVM IR defines as
/// poison; confirm whether that case needs explicit handling.
class BitFieldUExtractPattern
    : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
public:
  using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;

  LogicalResult
  matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = op.getType();
    auto dstType = typeConverter.convertType(srcType);
    if (!dstType)
      return failure();
    Location loc = op.getLoc();

    // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
    Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
                                        typeConverter, rewriter);
    Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
                                       typeConverter, rewriter);

    // Create a mask with bits set at [0, Count - 1]:
    // (-1 << Count) ^ -1 clears the upper bits and sets the low `Count` bits.
    Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
    Value maskShiftedByCount =
        rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
    Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
                                              minusOne);

    // Shift `Base` by `Offset` and apply the mask on it.
    Value shiftedBase =
        rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
    rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
    return success();
  }
};
515 
516 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
517 public:
518   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
519 
520   LogicalResult
521   matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
522                   ConversionPatternRewriter &rewriter) const override {
523     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
524                                             branchOp.getTarget());
525     return success();
526   }
527 };
528 
529 class BranchConditionalConversionPattern
530     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
531 public:
532   using SPIRVToLLVMConversion<
533       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
534 
535   LogicalResult
536   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
537                   ConversionPatternRewriter &rewriter) const override {
538     // If branch weights exist, map them to 32-bit integer vector.
539     ElementsAttr branchWeights = nullptr;
540     if (auto weights = op.branch_weights()) {
541       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
542       branchWeights =
543           DenseElementsAttr::get(weightType, weights.getValue().getValue());
544     }
545 
546     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
547         op, op.condition(), op.getTrueBlockArguments(),
548         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
549         op.getFalseBlock());
550     return success();
551   }
552 };
553 
554 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
555 /// is an aggregate type (struct or array). Otherwise, converts to
556 /// `llvm.extractelement` that operates on vectors.
557 class CompositeExtractPattern
558     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
559 public:
560   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
561 
562   LogicalResult
563   matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
564                   ConversionPatternRewriter &rewriter) const override {
565     auto dstType = this->typeConverter.convertType(op.getType());
566     if (!dstType)
567       return failure();
568 
569     Type containerType = op.composite().getType();
570     if (containerType.isa<VectorType>()) {
571       Location loc = op.getLoc();
572       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
573       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
574       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
575           op, dstType, op.composite(), index);
576       return success();
577     }
578     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
579         op, dstType, op.composite(), op.indices());
580     return success();
581   }
582 };
583 
584 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
585 /// is an aggregate type (struct or array). Otherwise, converts to
586 /// `llvm.insertelement` that operates on vectors.
587 class CompositeInsertPattern
588     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
589 public:
590   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
591 
592   LogicalResult
593   matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
594                   ConversionPatternRewriter &rewriter) const override {
595     auto dstType = this->typeConverter.convertType(op.getType());
596     if (!dstType)
597       return failure();
598 
599     Type containerType = op.composite().getType();
600     if (containerType.isa<VectorType>()) {
601       Location loc = op.getLoc();
602       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
603       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
604       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
605           op, dstType, op.composite(), op.object(), index);
606       return success();
607     }
608     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
609         op, dstType, op.composite(), op.object(), op.indices());
610     return success();
611   }
612 };
613 
614 /// Converts SPIR-V operations that have straightforward LLVM equivalent
615 /// into LLVM dialect operations.
616 template <typename SPIRVOp, typename LLVMOp>
617 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
618 public:
619   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
620 
621   LogicalResult
622   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
623                   ConversionPatternRewriter &rewriter) const override {
624     auto dstType = this->typeConverter.convertType(operation.getType());
625     if (!dstType)
626       return failure();
627     rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
628                                                  operation.getAttrs());
629     return success();
630   }
631 };
632 
633 /// Converts `spv.ExecutionMode` into a global struct constant that holds
634 /// execution mode information.
635 class ExecutionModePattern
636     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
637 public:
638   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
639 
640   LogicalResult
641   matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
642                   ConversionPatternRewriter &rewriter) const override {
643     // First, create the global struct's name that would be associated with
644     // this entry point's execution mode. We set it to be:
645     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info
646     ModuleOp module = op->getParentOfType<ModuleOp>();
647     std::string moduleName;
648     if (module.getName().hasValue())
649       moduleName = "_" + module.getName().getValue().str();
650     else
651       moduleName = "";
652     std::string executionModeInfoName = llvm::formatv(
653         "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str());
654 
655     MLIRContext *context = rewriter.getContext();
656     OpBuilder::InsertionGuard guard(rewriter);
657     rewriter.setInsertionPointToStart(module.getBody());
658 
659     // Create a struct type, corresponding to the C struct below.
660     // struct {
661     //   int32_t executionMode;
662     //   int32_t values[];          // optional values
663     // };
664     auto llvmI32Type = IntegerType::get(context, 32);
665     SmallVector<Type, 2> fields;
666     fields.push_back(llvmI32Type);
667     ArrayAttr values = op.values();
668     if (!values.empty()) {
669       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
670       fields.push_back(arrayType);
671     }
672     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
673 
674     // Create `llvm.mlir.global` with initializer region containing one block.
675     auto global = rewriter.create<LLVM::GlobalOp>(
676         UnknownLoc::get(context), structType, /*isConstant=*/true,
677         LLVM::Linkage::External, executionModeInfoName, Attribute());
678     Location loc = global.getLoc();
679     Region &region = global.getInitializerRegion();
680     Block *block = rewriter.createBlock(&region);
681 
682     // Initialize the struct and set the execution mode value.
683     rewriter.setInsertionPoint(block, block->begin());
684     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
685     IntegerAttr executionModeAttr = op.execution_modeAttr();
686     Value executionMode =
687         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
688     structValue = rewriter.create<LLVM::InsertValueOp>(
689         loc, structType, structValue, executionMode,
690         ArrayAttr::get(context,
691                        {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
692 
693     // Insert extra operands if they exist into execution mode info struct.
694     for (unsigned i = 0, e = values.size(); i < e; ++i) {
695       auto attr = values.getValue()[i];
696       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
697       structValue = rewriter.create<LLVM::InsertValueOp>(
698           loc, structType, structValue, entry,
699           ArrayAttr::get(context,
700                          {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
701                           rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
702     }
703     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
704     rewriter.eraseOp(op);
705     return success();
706   }
707 };
708 
709 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
710 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
711 /// This difference is handled by `spv.mlir.addressof` and
712 /// `llvm.mlir.addressof`ops that both return a pointer.
713 class GlobalVariablePattern
714     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
715 public:
716   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
717 
718   LogicalResult
719   matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
720                   ConversionPatternRewriter &rewriter) const override {
721     // Currently, there is no support of initialization with a constant value in
722     // SPIR-V dialect. Specialization constants are not considered as well.
723     if (op.initializer())
724       return failure();
725 
726     auto srcType = op.type().cast<spirv::PointerType>();
727     auto dstType = typeConverter.convertType(srcType.getPointeeType());
728     if (!dstType)
729       return failure();
730 
731     // Limit conversion to the current invocation only or `StorageBuffer`
732     // required by SPIR-V runner.
733     // This is okay because multiple invocations are not supported yet.
734     auto storageClass = srcType.getStorageClass();
735     if (storageClass != spirv::StorageClass::Input &&
736         storageClass != spirv::StorageClass::Private &&
737         storageClass != spirv::StorageClass::Output &&
738         storageClass != spirv::StorageClass::StorageBuffer) {
739       return failure();
740     }
741 
742     // LLVM dialect spec: "If the global value is a constant, storing into it is
743     // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
744     // read-only.
745     bool isConstant = storageClass == spirv::StorageClass::Input;
746     // SPIR-V spec: "By default, functions and global variables are private to a
747     // module and cannot be accessed by other modules. However, a module may be
748     // written to export or import functions and global (module scope)
749     // variables.". Therefore, map 'Private' storage class to private linkage,
750     // 'Input' and 'Output' to external linkage.
751     auto linkage = storageClass == spirv::StorageClass::Private
752                        ? LLVM::Linkage::Private
753                        : LLVM::Linkage::External;
754     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
755         op, dstType, isConstant, linkage, op.sym_name(), Attribute());
756     return success();
757   }
758 };
759 
760 /// Converts SPIR-V cast ops that do not have straightforward LLVM
761 /// equivalent in LLVM dialect.
762 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
763 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
764 public:
765   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
766 
767   LogicalResult
768   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
769                   ConversionPatternRewriter &rewriter) const override {
770 
771     Type fromType = operation.operand().getType();
772     Type toType = operation.getType();
773 
774     auto dstType = this->typeConverter.convertType(toType);
775     if (!dstType)
776       return failure();
777 
778     if (getBitWidth(fromType) < getBitWidth(toType)) {
779       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
780                                                       operands);
781       return success();
782     }
783     if (getBitWidth(fromType) > getBitWidth(toType)) {
784       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
785                                                         operands);
786       return success();
787     }
788     return failure();
789   }
790 };
791 
792 class FunctionCallPattern
793     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
794 public:
795   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
796 
797   LogicalResult
798   matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
799                   ConversionPatternRewriter &rewriter) const override {
800     if (callOp.getNumResults() == 0) {
801       rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
802                                                 callOp.getAttrs());
803       return success();
804     }
805 
806     // Function returns a single result.
807     auto dstType = typeConverter.convertType(callOp.getType(0));
808     rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
809                                               callOp.getAttrs());
810     return success();
811   }
812 };
813 
814 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
815 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
816 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
817 public:
818   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
819 
820   LogicalResult
821   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
822                   ConversionPatternRewriter &rewriter) const override {
823 
824     auto dstType = this->typeConverter.convertType(operation.getType());
825     if (!dstType)
826       return failure();
827 
828     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
829         operation, dstType,
830         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
831         operation.operand1(), operation.operand2(),
832         LLVM::FMFAttr::get({}, operation.getContext()));
833     return success();
834   }
835 };
836 
837 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
838 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
839 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
840 public:
841   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
842 
843   LogicalResult
844   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
845                   ConversionPatternRewriter &rewriter) const override {
846 
847     auto dstType = this->typeConverter.convertType(operation.getType());
848     if (!dstType)
849       return failure();
850 
851     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
852         operation, dstType,
853         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
854         operation.operand1(), operation.operand2());
855     return success();
856   }
857 };
858 
859 class InverseSqrtPattern
860     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
861 public:
862   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
863 
864   LogicalResult
865   matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
866                   ConversionPatternRewriter &rewriter) const override {
867     auto srcType = op.getType();
868     auto dstType = typeConverter.convertType(srcType);
869     if (!dstType)
870       return failure();
871 
872     Location loc = op.getLoc();
873     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
874     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
875     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
876     return success();
877   }
878 };
879 
880 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
881 template <typename SPIRVop>
882 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
883 public:
884   using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
885 
886   LogicalResult
887   matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
888                   ConversionPatternRewriter &rewriter) const override {
889 
890     if (!op.memory_access().hasValue()) {
891       return replaceWithLoadOrStore(
892           op, rewriter, this->typeConverter, /*alignment=*/0,
893           /*isVolatile=*/false, /*isNonTemporal=*/false);
894     }
895     auto memoryAccess = op.memory_access().getValue();
896     switch (memoryAccess) {
897     case spirv::MemoryAccess::Aligned:
898     case spirv::MemoryAccess::None:
899     case spirv::MemoryAccess::Nontemporal:
900     case spirv::MemoryAccess::Volatile: {
901       unsigned alignment =
902           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
903       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
904       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
905       return replaceWithLoadOrStore(op, rewriter, this->typeConverter,
906                                     alignment, isVolatile, isNonTemporal);
907     }
908     default:
909       // There is no support of other memory access attributes.
910       return failure();
911     }
912   }
913 };
914 
915 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
916 template <typename SPIRVOp>
917 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
918 public:
919   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
920 
921   LogicalResult
922   matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
923                   ConversionPatternRewriter &rewriter) const override {
924 
925     auto srcType = notOp.getType();
926     auto dstType = this->typeConverter.convertType(srcType);
927     if (!dstType)
928       return failure();
929 
930     Location loc = notOp.getLoc();
931     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
932     auto mask = srcType.template isa<VectorType>()
933                     ? rewriter.create<LLVM::ConstantOp>(
934                           loc, dstType,
935                           SplatElementsAttr::get(
936                               srcType.template cast<VectorType>(), minusOne))
937                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
938     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
939                                                       notOp.operand(), mask);
940     return success();
941   }
942 };
943 
944 /// A template pattern that erases the given `SPIRVOp`.
945 template <typename SPIRVOp>
946 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
947 public:
948   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
949 
950   LogicalResult
951   matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
952                   ConversionPatternRewriter &rewriter) const override {
953     rewriter.eraseOp(op);
954     return success();
955   }
956 };
957 
958 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
959 public:
960   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
961 
962   LogicalResult
963   matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
964                   ConversionPatternRewriter &rewriter) const override {
965     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
966                                                 ArrayRef<Value>());
967     return success();
968   }
969 };
970 
971 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
972 public:
973   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
974 
975   LogicalResult
976   matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
977                   ConversionPatternRewriter &rewriter) const override {
978     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
979                                                 operands);
980     return success();
981   }
982 };
983 
984 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be
985 /// reachable for conversion to succeed.
986 /// The structure of the loop in LLVM dialect will be the following:
987 ///
988 ///      +------------------------------------+
989 ///      | <code before spv.loop>             |
990 ///      | llvm.br ^header                    |
991 ///      +------------------------------------+
992 ///                           |
993 ///   +----------------+      |
994 ///   |                |      |
995 ///   |                V      V
996 ///   |  +------------------------------------+
997 ///   |  | ^header:                           |
998 ///   |  |   <header code>                    |
999 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1000 ///   |  +------------------------------------+
1001 ///   |                    |
1002 ///   |                    |----------------------+
1003 ///   |                    |                      |
1004 ///   |                    V                      |
1005 ///   |  +------------------------------------+   |
1006 ///   |  | ^body:                             |   |
1007 ///   |  |   <body code>                      |   |
1008 ///   |  |   llvm.br ^continue                |   |
1009 ///   |  +------------------------------------+   |
1010 ///   |                    |                      |
1011 ///   |                    V                      |
1012 ///   |  +------------------------------------+   |
1013 ///   |  | ^continue:                         |   |
1014 ///   |  |   <continue code>                  |   |
1015 ///   |  |   llvm.br ^header                  |   |
1016 ///   |  +------------------------------------+   |
1017 ///   |               |                           |
1018 ///   +---------------+    +----------------------+
1019 ///                        |
1020 ///                        V
1021 ///      +------------------------------------+
1022 ///      | ^exit:                             |
1023 ///      |   llvm.br ^remaining               |
1024 ///      +------------------------------------+
1025 ///                        |
1026 ///                        V
1027 ///      +------------------------------------+
1028 ///      | ^remaining:                        |
1029 ///      |   <code after spv.loop>            |
1030 ///      +------------------------------------+
1031 ///
1032 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1033 public:
1034   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1035 
1036   LogicalResult
1037   matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
1038                   ConversionPatternRewriter &rewriter) const override {
1039     // There is no support of loop control at the moment.
1040     if (loopOp.loop_control() != spirv::LoopControl::None)
1041       return failure();
1042 
1043     Location loc = loopOp.getLoc();
1044 
1045     // Split the current block after `spv.loop`. The remaining ops will be used
1046     // in `endBlock`.
1047     Block *currentBlock = rewriter.getBlock();
1048     auto position = Block::iterator(loopOp);
1049     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1050 
1051     // Remove entry block and create a branch in the current block going to the
1052     // header block.
1053     Block *entryBlock = loopOp.getEntryBlock();
1054     assert(entryBlock->getOperations().size() == 1);
1055     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1056     if (!brOp)
1057       return failure();
1058     Block *headerBlock = loopOp.getHeaderBlock();
1059     rewriter.setInsertionPointToEnd(currentBlock);
1060     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1061     rewriter.eraseBlock(entryBlock);
1062 
1063     // Branch from merge block to end block.
1064     Block *mergeBlock = loopOp.getMergeBlock();
1065     Operation *terminator = mergeBlock->getTerminator();
1066     ValueRange terminatorOperands = terminator->getOperands();
1067     rewriter.setInsertionPointToEnd(mergeBlock);
1068     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1069 
1070     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1071     rewriter.replaceOp(loopOp, endBlock->getArguments());
1072     return success();
1073   }
1074 };
1075 
1076 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
1077 /// All blocks within selection should be reachable for conversion to succeed.
1078 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1079 public:
1080   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1081 
1082   LogicalResult
1083   matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
1084                   ConversionPatternRewriter &rewriter) const override {
1085     // There is no support for `Flatten` or `DontFlatten` selection control at
1086     // the moment. This are just compiler hints and can be performed during the
1087     // optimization passes.
1088     if (op.selection_control() != spirv::SelectionControl::None)
1089       return failure();
1090 
1091     // `spv.selection` should have at least two blocks: one selection header
1092     // block and one merge block. If no blocks are present, or control flow
1093     // branches straight to merge block (two blocks are present), the op is
1094     // redundant and it is erased.
1095     if (op.body().getBlocks().size() <= 2) {
1096       rewriter.eraseOp(op);
1097       return success();
1098     }
1099 
1100     Location loc = op.getLoc();
1101 
1102     // Split the current block after `spv.selection`. The remaining ops will be
1103     // used in `continueBlock`.
1104     auto *currentBlock = rewriter.getInsertionBlock();
1105     rewriter.setInsertionPointAfter(op);
1106     auto position = rewriter.getInsertionPoint();
1107     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1108 
1109     // Extract conditional branch information from the header block. By SPIR-V
1110     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1111     // op. Note that `spv.Switch op` is not supported at the moment in the
1112     // SPIR-V dialect. Remove this block when finished.
1113     auto *headerBlock = op.getHeaderBlock();
1114     assert(headerBlock->getOperations().size() == 1);
1115     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1116         headerBlock->getOperations().front());
1117     if (!condBrOp)
1118       return failure();
1119     rewriter.eraseBlock(headerBlock);
1120 
1121     // Branch from merge block to continue block.
1122     auto *mergeBlock = op.getMergeBlock();
1123     Operation *terminator = mergeBlock->getTerminator();
1124     ValueRange terminatorOperands = terminator->getOperands();
1125     rewriter.setInsertionPointToEnd(mergeBlock);
1126     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1127 
1128     // Link current block to `true` and `false` blocks within the selection.
1129     Block *trueBlock = condBrOp.getTrueBlock();
1130     Block *falseBlock = condBrOp.getFalseBlock();
1131     rewriter.setInsertionPointToEnd(currentBlock);
1132     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1133                                     condBrOp.trueTargetOperands(), falseBlock,
1134                                     condBrOp.falseTargetOperands());
1135 
1136     rewriter.inlineRegionBefore(op.body(), continueBlock);
1137     rewriter.replaceOp(op, continueBlock->getArguments());
1138     return success();
1139   }
1140 };
1141 
1142 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1143 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1144 /// `Shift` is zero or sign extended to match this specification. Cases when
1145 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1146 template <typename SPIRVOp, typename LLVMOp>
1147 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1148 public:
1149   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1150 
1151   LogicalResult
1152   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1153                   ConversionPatternRewriter &rewriter) const override {
1154 
1155     auto dstType = this->typeConverter.convertType(operation.getType());
1156     if (!dstType)
1157       return failure();
1158 
1159     Type op1Type = operation.operand1().getType();
1160     Type op2Type = operation.operand2().getType();
1161 
1162     if (op1Type == op2Type) {
1163       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1164                                                    operands);
1165       return success();
1166     }
1167 
1168     Location loc = operation.getLoc();
1169     Value extended;
1170     if (isUnsignedIntegerOrVector(op2Type)) {
1171       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1172                                                         operation.operand2());
1173     } else {
1174       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1175                                                         operation.operand2());
1176     }
1177     Value result = rewriter.template create<LLVMOp>(
1178         loc, dstType, operation.operand1(), extended);
1179     rewriter.replaceOp(operation, result);
1180     return success();
1181   }
1182 };
1183 
1184 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1185 public:
1186   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1187 
1188   LogicalResult
1189   matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1190                   ConversionPatternRewriter &rewriter) const override {
1191     auto dstType = typeConverter.convertType(tanOp.getType());
1192     if (!dstType)
1193       return failure();
1194 
1195     Location loc = tanOp.getLoc();
1196     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1197     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1198     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1199     return success();
1200   }
1201 };
1202 
1203 /// Convert `spv.Tanh` to
1204 ///
1205 ///   exp(2x) - 1
1206 ///   -----------
1207 ///   exp(2x) + 1
1208 ///
1209 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1210 public:
1211   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1212 
1213   LogicalResult
1214   matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1215                   ConversionPatternRewriter &rewriter) const override {
1216     auto srcType = tanhOp.getType();
1217     auto dstType = typeConverter.convertType(srcType);
1218     if (!dstType)
1219       return failure();
1220 
1221     Location loc = tanhOp.getLoc();
1222     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1223     Value multiplied =
1224         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1225     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1226     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1227     Value numerator =
1228         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1229     Value denominator =
1230         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1231     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1232                                               denominator);
1233     return success();
1234   }
1235 };
1236 
1237 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1238 public:
1239   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1240 
1241   LogicalResult
1242   matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1243                   ConversionPatternRewriter &rewriter) const override {
1244     auto srcType = varOp.getType();
1245     // Initialization is supported for scalars and vectors only.
1246     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1247     auto init = varOp.initializer();
1248     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1249       return failure();
1250 
1251     auto dstType = typeConverter.convertType(srcType);
1252     if (!dstType)
1253       return failure();
1254 
1255     Location loc = varOp.getLoc();
1256     Value size = createI32ConstantOf(loc, rewriter, 1);
1257     if (!init) {
1258       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1259       return success();
1260     }
1261     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1262     rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1263     rewriter.replaceOp(varOp, allocated);
1264     return success();
1265   }
1266 };
1267 
1268 //===----------------------------------------------------------------------===//
1269 // FuncOp conversion
1270 //===----------------------------------------------------------------------===//
1271 
1272 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1273 public:
1274   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1275 
1276   LogicalResult
1277   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1278                   ConversionPatternRewriter &rewriter) const override {
1279 
1280     // Convert function signature. At the moment LLVMType converter is enough
1281     // for currently supported types.
1282     auto funcType = funcOp.getType();
1283     TypeConverter::SignatureConversion signatureConverter(
1284         funcType.getNumInputs());
1285     auto llvmType = typeConverter.convertFunctionSignature(
1286         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1287     if (!llvmType)
1288       return failure();
1289 
1290     // Create a new `LLVMFuncOp`
1291     Location loc = funcOp.getLoc();
1292     StringRef name = funcOp.getName();
1293     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1294 
1295     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1296     MLIRContext *context = funcOp.getContext();
1297     switch (funcOp.function_control()) {
1298 #define DISPATCH(functionControl, llvmAttr)                                    \
1299   case functionControl:                                                        \
1300     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1301     break;
1302 
1303       DISPATCH(spirv::FunctionControl::Inline,
1304                StringAttr::get(context, "alwaysinline"));
1305       DISPATCH(spirv::FunctionControl::DontInline,
1306                StringAttr::get(context, "noinline"));
1307       DISPATCH(spirv::FunctionControl::Pure,
1308                StringAttr::get(context, "readonly"));
1309       DISPATCH(spirv::FunctionControl::Const,
1310                StringAttr::get(context, "readnone"));
1311 
1312 #undef DISPATCH
1313 
1314     // Default: if `spirv::FunctionControl::None`, then no attributes are
1315     // needed.
1316     default:
1317       break;
1318     }
1319 
1320     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1321                                 newFuncOp.end());
1322     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1323                                            &signatureConverter))) {
1324       return failure();
1325     }
1326     rewriter.eraseOp(funcOp);
1327     return success();
1328   }
1329 };
1330 
1331 //===----------------------------------------------------------------------===//
1332 // ModuleOp conversion
1333 //===----------------------------------------------------------------------===//
1334 
1335 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1336 public:
1337   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1338 
1339   LogicalResult
1340   matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1341                   ConversionPatternRewriter &rewriter) const override {
1342 
1343     auto newModuleOp =
1344         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1345     rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1346 
1347     // Remove the terminator block that was automatically added by builder
1348     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1349     rewriter.eraseOp(spvModuleOp);
1350     return success();
1351   }
1352 };
1353 
1354 class ModuleEndConversionPattern
1355     : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1356 public:
1357   using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1358 
1359   LogicalResult
1360   matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1361                   ConversionPatternRewriter &rewriter) const override {
1362 
1363     rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1364     return success();
1365   }
1366 };
1367 
1368 } // namespace
1369 
1370 //===----------------------------------------------------------------------===//
1371 // Pattern population
1372 //===----------------------------------------------------------------------===//
1373 
1374 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1375   typeConverter.addConversion([&](spirv::ArrayType type) {
1376     return convertArrayType(type, typeConverter);
1377   });
1378   typeConverter.addConversion([&](spirv::PointerType type) {
1379     return convertPointerType(type, typeConverter);
1380   });
1381   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1382     return convertRuntimeArrayType(type, typeConverter);
1383   });
1384   typeConverter.addConversion([&](spirv::StructType type) {
1385     return convertStructType(type, typeConverter);
1386   });
1387 }
1388 
1389 void mlir::populateSPIRVToLLVMConversionPatterns(
1390     MLIRContext *context, LLVMTypeConverter &typeConverter,
1391     OwningRewritePatternList &patterns) {
1392   patterns.insert<
1393       // Arithmetic ops
1394       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1395       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1396       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1397       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1398       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1399       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1400       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1401       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1402       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1403       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1404       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1405       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1406       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1407 
1408       // Bitwise ops
1409       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1410       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1411       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1412       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1413       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1414       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1415       NotPattern<spirv::NotOp>,
1416 
1417       // Cast ops
1418       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1419       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1420       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1421       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1422       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1423       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1424       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1425       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1426 
1427       // Comparison ops
1428       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1429       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1430       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1431       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1432       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1433       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1434       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1435       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1436       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1437       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1438       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1439                       LLVM::FCmpPredicate::uge>,
1440       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1441       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1442       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1443       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1444       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1445       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1446       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1447       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1448       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1449       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1450       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1451 
1452       // Constant op
1453       ConstantScalarAndVectorPattern,
1454 
1455       // Control Flow ops
1456       BranchConversionPattern, BranchConditionalConversionPattern,
1457       FunctionCallPattern, LoopPattern, SelectionPattern,
1458       ErasePattern<spirv::MergeOp>,
1459 
1460       // Entry points and execution mode are handled separately.
1461       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1462 
1463       // GLSL extended instruction set ops
1464       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1465       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1466       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1467       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1468       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1469       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1470       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1471       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1472       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1473       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1474       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1475       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1476       InverseSqrtPattern, TanPattern, TanhPattern,
1477 
1478       // Logical ops
1479       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1480       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1481       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1482       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1483       NotPattern<spirv::LogicalNotOp>,
1484 
1485       // Memory ops
1486       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1487       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1488       VariablePattern,
1489 
1490       // Miscellaneous ops
1491       CompositeExtractPattern, CompositeInsertPattern,
1492       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1493       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1494 
1495       // Shift ops
1496       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1497       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1498       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1499 
1500       // Return ops
1501       ReturnPattern, ReturnValuePattern>(context, typeConverter);
1502 }
1503 
1504 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1505     MLIRContext *context, LLVMTypeConverter &typeConverter,
1506     OwningRewritePatternList &patterns) {
1507   patterns.insert<FuncConversionPattern>(context, typeConverter);
1508 }
1509 
1510 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1511     MLIRContext *context, LLVMTypeConverter &typeConverter,
1512     OwningRewritePatternList &patterns) {
1513   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1514       context, typeConverter);
1515 }
1516 
1517 //===----------------------------------------------------------------------===//
1518 // Pre-conversion hooks
1519 //===----------------------------------------------------------------------===//
1520 
1521 /// Hook for descriptor set and binding number encoding.
1522 static constexpr StringRef kBinding = "binding";
1523 static constexpr StringRef kDescriptorSet = "descriptor_set";
1524 void mlir::encodeBindAttribute(ModuleOp module) {
1525   auto spvModules = module.getOps<spirv::ModuleOp>();
1526   for (auto spvModule : spvModules) {
1527     spvModule.walk([&](spirv::GlobalVariableOp op) {
1528       IntegerAttr descriptorSet =
1529           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1530       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1531       // For every global variable in the module, get the ones with descriptor
1532       // set and binding numbers.
1533       if (descriptorSet && binding) {
1534         // Encode these numbers into the variable's symbolic name. If the
1535         // SPIR-V module has a name, add it at the beginning.
1536         auto moduleAndName = spvModule.getName().hasValue()
1537                                  ? spvModule.getName().getValue().str() + "_" +
1538                                        op.sym_name().str()
1539                                  : op.sym_name().str();
1540         std::string name =
1541             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1542                           std::to_string(descriptorSet.getInt()),
1543                           std::to_string(binding.getInt()));
1544 
1545         // Replace all symbol uses and set the new symbol name. Finally, remove
1546         // descriptor set and binding attributes.
1547         if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1548           op.emitError("unable to replace all symbol uses for ") << name;
1549         SymbolTable::setSymbolName(op, name);
1550         op.removeAttr(kDescriptorSet);
1551         op.removeAttr(kBinding);
1552       }
1553     });
1554   }
1555 }
1556