1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38   if (type.isSignedInteger())
39     return true;
40   if (auto vecType = type.dyn_cast<VectorType>())
41     return vecType.getElementType().isSignedInteger();
42   return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47   if (type.isUnsignedInteger())
48     return true;
49   if (auto vecType = type.dyn_cast<VectorType>())
50     return vecType.getElementType().isUnsignedInteger();
51   return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57          "bitwidth is not supported for this type");
58   if (type.isIntOrFloat())
59     return type.getIntOrFloatBitWidth();
60   auto vecType = type.dyn_cast<VectorType>();
61   auto elementType = vecType.getElementType();
62   assert(elementType.isIntOrFloat() &&
63          "only integers and floats have a bitwidth");
64   return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of LLVMType integer or vector.
68 static unsigned getLLVMTypeBitWidth(Type type) {
69   return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
70                                              : type)
71       .cast<IntegerType>()
72       .getWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77   if (auto vecType = type.dyn_cast<VectorType>()) {
78     auto integerType = vecType.getElementType().cast<IntegerType>();
79     return builder.getIntegerAttr(integerType, -1);
80   }
81   auto integerType = type.cast<IntegerType>();
82   return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87                                       PatternRewriter &rewriter) {
88   if (srcType.isa<VectorType>()) {
89     return rewriter.create<LLVM::ConstantOp>(
90         loc, dstType,
91         SplatElementsAttr::get(srcType.cast<ShapedType>(),
92                                minusOneIntegerAttribute(srcType, rewriter)));
93   }
94   return rewriter.create<LLVM::ConstantOp>(
95       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100                               PatternRewriter &rewriter, double value) {
101   if (auto vecType = srcType.dyn_cast<VectorType>()) {
102     auto floatType = vecType.getElementType().cast<FloatType>();
103     return rewriter.create<LLVM::ConstantOp>(
104         loc, dstType,
105         SplatElementsAttr::get(vecType,
106                                rewriter.getFloatAttr(floatType, value)));
107   }
108   auto floatType = srcType.cast<FloatType>();
109   return rewriter.create<LLVM::ConstantOp>(
110       loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 ///   - `BitFieldInsert`
115 ///   - `BitFieldSExtract`
116 ///   - `BitFieldUExtract`
117 /// Truncates or extends the value. If the bitwidth of the value is the same as
118 /// `llvmType` bitwidth, the value remains unchanged.
119 static Value optionallyTruncateOrExtend(Location loc, Value value,
120                                         Type llvmType,
121                                         PatternRewriter &rewriter) {
122   auto srcType = value.getType();
123   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125                                ? getLLVMTypeBitWidth(srcType)
126                                : getBitWidth(srcType);
127 
128   if (valueBitWidth < targetBitWidth)
129     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130   // If the bit widths of `Count` and `Offset` are greater than the bit width
131   // of the target type, they are truncated. Truncation is safe since `Count`
132   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133   // both values can be expressed in 8 bits.
134   if (valueBitWidth > targetBitWidth)
135     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136   return value;
137 }
138 
139 /// Broadcasts the value to vector with `numElements` number of elements.
140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141                        LLVMTypeConverter &typeConverter,
142                        ConversionPatternRewriter &rewriter) {
143   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144   auto llvmVectorType = typeConverter.convertType(vectorType);
145   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147   for (unsigned i = 0; i < numElements; ++i) {
148     auto index = rewriter.create<LLVM::ConstantOp>(
149         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150     broadcasted = rewriter.create<LLVM::InsertElementOp>(
151         loc, llvmVectorType, broadcasted, toBroadcast, index);
152   }
153   return broadcasted;
154 }
155 
156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
157 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
158                                  LLVMTypeConverter &typeConverter,
159                                  ConversionPatternRewriter &rewriter) {
160   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161     unsigned numElements = vectorType.getNumElements();
162     return broadcast(loc, value, numElements, typeConverter, rewriter);
163   }
164   return value;
165 }
166 
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, construct a vector that has:
171 ///  - same number of elements as `Base`
172 ///  - each element has the type that is the same as the type of `Offset` or
173 ///    `Count`
174 ///  - each element has the same value as `Offset` or `Count`
175 /// Then cast `Offset` and `Count` if their bit width is different
176 /// from `Base` bit width.
177 static Value processCountOrOffset(Location loc, Value value, Type srcType,
178                                   Type dstType, LLVMTypeConverter &converter,
179                                   ConversionPatternRewriter &rewriter) {
180   Value broadcasted =
181       optionallyBroadcast(loc, value, srcType, converter, rewriter);
182   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184 
185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
186 /// offset to LLVM struct. Otherwise, the conversion is not supported.
187 static Optional<Type>
188 convertStructTypeWithOffset(spirv::StructType type,
189                             LLVMTypeConverter &converter) {
190   if (type != VulkanLayoutUtils::decorateType(type))
191     return llvm::None;
192 
193   auto elementsVector = llvm::to_vector<8>(
194       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195         return converter.convertType(elementType);
196       }));
197   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198                                           /*isPacked=*/false);
199 }
200 
201 /// Converts SPIR-V struct with no offset to packed LLVM struct.
202 static Type convertStructTypePacked(spirv::StructType type,
203                                     LLVMTypeConverter &converter) {
204   auto elementsVector = llvm::to_vector<8>(
205       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206         return converter.convertType(elementType);
207       }));
208   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209                                           /*isPacked=*/true);
210 }
211 
212 /// Creates LLVM dialect constant with the given value.
213 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
214                                  unsigned value) {
215   return rewriter.create<LLVM::ConstantOp>(
216       loc, IntegerType::get(rewriter.getContext(), 32),
217       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219 
220 /// Utility for `spv.Load` and `spv.Store` conversion.
221 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
222                                             ConversionPatternRewriter &rewriter,
223                                             LLVMTypeConverter &typeConverter,
224                                             unsigned alignment, bool isVolatile,
225                                             bool isNonTemporal) {
226   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227     auto dstType = typeConverter.convertType(loadOp.getType());
228     if (!dstType)
229       return failure();
230     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231         loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
232         isVolatile, isNonTemporal);
233     return success();
234   }
235   auto storeOp = cast<spirv::StoreOp>(op);
236   spirv::StoreOpAdaptor adaptor(operands);
237   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
238                                              adaptor.ptr(), alignment,
239                                              isVolatile, isNonTemporal);
240   return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Type conversion
245 //===----------------------------------------------------------------------===//
246 
247 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
248 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
249 /// when converting ops that manipulate array types.
250 static Optional<Type> convertArrayType(spirv::ArrayType type,
251                                        TypeConverter &converter) {
252   unsigned stride = type.getArrayStride();
253   Type elementType = type.getElementType();
254   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
255   if (stride != 0 &&
256       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
257     return llvm::None;
258 
259   auto llvmElementType = converter.convertType(elementType);
260   unsigned numElements = type.getNumElements();
261   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
262 }
263 
264 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
265 /// modelled at the moment.
266 static Type convertPointerType(spirv::PointerType type,
267                                TypeConverter &converter) {
268   auto pointeeType = converter.convertType(type.getPointeeType());
269   return LLVM::LLVMPointerType::get(pointeeType);
270 }
271 
272 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
273 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
274 /// no modelling of array stride at the moment.
275 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
276                                               TypeConverter &converter) {
277   if (type.getArrayStride() != 0)
278     return llvm::None;
279   auto elementType = converter.convertType(type.getElementType());
280   return LLVM::LLVMArrayType::get(elementType, 0);
281 }
282 
283 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
284 /// member decorations. Also, only natural offset is supported.
285 static Optional<Type> convertStructType(spirv::StructType type,
286                                         LLVMTypeConverter &converter) {
287   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
288   type.getMemberDecorations(memberDecorations);
289   if (!memberDecorations.empty())
290     return llvm::None;
291   if (type.hasOffset())
292     return convertStructTypeWithOffset(type, converter);
293   return convertStructTypePacked(type, converter);
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // Operation conversion
298 //===----------------------------------------------------------------------===//
299 
300 namespace {
301 
302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
303 public:
304   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
305 
306   LogicalResult
307   matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto dstType = typeConverter.convertType(op.component_ptr().getType());
310     if (!dstType)
311       return failure();
312     // To use GEP we need to add a first 0 index to go through the pointer.
313     auto indices = llvm::to_vector<4>(adaptor.indices());
314     Type indexType = op.indices().front().getType();
315     auto llvmIndexType = typeConverter.convertType(indexType);
316     if (!llvmIndexType)
317       return failure();
318     Value zero = rewriter.create<LLVM::ConstantOp>(
319         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
320     indices.insert(indices.begin(), zero);
321     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
322                                              indices);
323     return success();
324   }
325 };
326 
327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
328 public:
329   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
330 
331   LogicalResult
332   matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
333                   ConversionPatternRewriter &rewriter) const override {
334     auto dstType = typeConverter.convertType(op.pointer().getType());
335     if (!dstType)
336       return failure();
337     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
338     return success();
339   }
340 };
341 
342 class BitFieldInsertPattern
343     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
344 public:
345   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
346 
347   LogicalResult
348   matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
349                   ConversionPatternRewriter &rewriter) const override {
350     auto srcType = op.getType();
351     auto dstType = typeConverter.convertType(srcType);
352     if (!dstType)
353       return failure();
354     Location loc = op.getLoc();
355 
356     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
357     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
358                                         typeConverter, rewriter);
359     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
360                                        typeConverter, rewriter);
361 
362     // Create a mask with bits set outside [Offset, Offset + Count - 1].
363     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
364     Value maskShiftedByCount =
365         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
366     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
367                                                  maskShiftedByCount, minusOne);
368     Value maskShiftedByCountAndOffset =
369         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
370     Value mask = rewriter.create<LLVM::XOrOp>(
371         loc, dstType, maskShiftedByCountAndOffset, minusOne);
372 
373     // Extract unchanged bits from the `Base`  that are outside of
374     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
375     Value baseAndMask =
376         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
377     Value insertShiftedByOffset =
378         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
379     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
380                                             insertShiftedByOffset);
381     return success();
382   }
383 };
384 
385 /// Converts SPIR-V ConstantOp with scalar or vector type.
386 class ConstantScalarAndVectorPattern
387     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
388 public:
389   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
390 
391   LogicalResult
392   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
393                   ConversionPatternRewriter &rewriter) const override {
394     auto srcType = constOp.getType();
395     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
396       return failure();
397 
398     auto dstType = typeConverter.convertType(srcType);
399     if (!dstType)
400       return failure();
401 
402     // SPIR-V constant can be a signed/unsigned integer, which has to be
403     // casted to signless integer when converting to LLVM dialect. Removing the
404     // sign bit may have unexpected behaviour. However, it is better to handle
405     // it case-by-case, given that the purpose of the conversion is not to
406     // cover all possible corner cases.
407     if (isSignedIntegerOrVector(srcType) ||
408         isUnsignedIntegerOrVector(srcType)) {
409       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
410 
411       if (srcType.isa<VectorType>()) {
412         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
413         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
414             constOp, dstType,
415             dstElementsAttr.mapValues(
416                 signlessType, [&](const APInt &value) { return value; }));
417         return success();
418       }
419       auto srcAttr = constOp.value().cast<IntegerAttr>();
420       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
421       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
422       return success();
423     }
424     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
425         constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
426     return success();
427   }
428 };
429 
430 class BitFieldSExtractPattern
431     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
432 public:
433   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
434 
435   LogicalResult
436   matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
437                   ConversionPatternRewriter &rewriter) const override {
438     auto srcType = op.getType();
439     auto dstType = typeConverter.convertType(srcType);
440     if (!dstType)
441       return failure();
442     Location loc = op.getLoc();
443 
444     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
445     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
446                                         typeConverter, rewriter);
447     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
448                                        typeConverter, rewriter);
449 
450     // Create a constant that holds the size of the `Base`.
451     IntegerType integerType;
452     if (auto vecType = srcType.dyn_cast<VectorType>())
453       integerType = vecType.getElementType().cast<IntegerType>();
454     else
455       integerType = srcType.cast<IntegerType>();
456 
457     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
458     Value size =
459         srcType.isa<VectorType>()
460             ? rewriter.create<LLVM::ConstantOp>(
461                   loc, dstType,
462                   SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
463             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
464 
465     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
466     // at Offset + Count - 1 is the most significant bit now.
467     Value countPlusOffset =
468         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
469     Value amountToShiftLeft =
470         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
471     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
472         loc, dstType, op.base(), amountToShiftLeft);
473 
474     // Shift the result right, filling the bits with the sign bit.
475     Value amountToShiftRight =
476         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
477     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
478                                               amountToShiftRight);
479     return success();
480   }
481 };
482 
483 class BitFieldUExtractPattern
484     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
485 public:
486   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
487 
488   LogicalResult
489   matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
490                   ConversionPatternRewriter &rewriter) const override {
491     auto srcType = op.getType();
492     auto dstType = typeConverter.convertType(srcType);
493     if (!dstType)
494       return failure();
495     Location loc = op.getLoc();
496 
497     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
498     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
499                                         typeConverter, rewriter);
500     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
501                                        typeConverter, rewriter);
502 
503     // Create a mask with bits set at [0, Count - 1].
504     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
505     Value maskShiftedByCount =
506         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
507     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
508                                               minusOne);
509 
510     // Shift `Base` by `Offset` and apply the mask on it.
511     Value shiftedBase =
512         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
513     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
514     return success();
515   }
516 };
517 
518 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
519 public:
520   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
521 
522   LogicalResult
523   matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
524                   ConversionPatternRewriter &rewriter) const override {
525     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
526                                             branchOp.getTarget());
527     return success();
528   }
529 };
530 
531 class BranchConditionalConversionPattern
532     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
533 public:
534   using SPIRVToLLVMConversion<
535       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
536 
537   LogicalResult
538   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
539                   ConversionPatternRewriter &rewriter) const override {
540     // If branch weights exist, map them to 32-bit integer vector.
541     ElementsAttr branchWeights = nullptr;
542     if (auto weights = op.branch_weights()) {
543       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
544       branchWeights =
545           DenseElementsAttr::get(weightType, weights.getValue().getValue());
546     }
547 
548     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
549         op, op.condition(), op.getTrueBlockArguments(),
550         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
551         op.getFalseBlock());
552     return success();
553   }
554 };
555 
556 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
557 /// is an aggregate type (struct or array). Otherwise, converts to
558 /// `llvm.extractelement` that operates on vectors.
559 class CompositeExtractPattern
560     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
561 public:
562   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
563 
564   LogicalResult
565   matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
566                   ConversionPatternRewriter &rewriter) const override {
567     auto dstType = this->typeConverter.convertType(op.getType());
568     if (!dstType)
569       return failure();
570 
571     Type containerType = op.composite().getType();
572     if (containerType.isa<VectorType>()) {
573       Location loc = op.getLoc();
574       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
575       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
576       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
577           op, dstType, adaptor.composite(), index);
578       return success();
579     }
580     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
581         op, dstType, adaptor.composite(), op.indices());
582     return success();
583   }
584 };
585 
586 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
587 /// is an aggregate type (struct or array). Otherwise, converts to
588 /// `llvm.insertelement` that operates on vectors.
589 class CompositeInsertPattern
590     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
591 public:
592   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
593 
594   LogicalResult
595   matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
596                   ConversionPatternRewriter &rewriter) const override {
597     auto dstType = this->typeConverter.convertType(op.getType());
598     if (!dstType)
599       return failure();
600 
601     Type containerType = op.composite().getType();
602     if (containerType.isa<VectorType>()) {
603       Location loc = op.getLoc();
604       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
605       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
606       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
607           op, dstType, adaptor.composite(), adaptor.object(), index);
608       return success();
609     }
610     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
611         op, dstType, adaptor.composite(), adaptor.object(), op.indices());
612     return success();
613   }
614 };
615 
616 /// Converts SPIR-V operations that have straightforward LLVM equivalent
617 /// into LLVM dialect operations.
618 template <typename SPIRVOp, typename LLVMOp>
619 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
620 public:
621   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
622 
623   LogicalResult
624   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
625                   ConversionPatternRewriter &rewriter) const override {
626     auto dstType = this->typeConverter.convertType(operation.getType());
627     if (!dstType)
628       return failure();
629     rewriter.template replaceOpWithNewOp<LLVMOp>(
630         operation, dstType, adaptor.getOperands(), operation->getAttrs());
631     return success();
632   }
633 };
634 
635 /// Converts `spv.ExecutionMode` into a global struct constant that holds
636 /// execution mode information.
637 class ExecutionModePattern
638     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
639 public:
640   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
641 
642   LogicalResult
643   matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
644                   ConversionPatternRewriter &rewriter) const override {
645     // First, create the global struct's name that would be associated with
646     // this entry point's execution mode. We set it to be:
647     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
648     ModuleOp module = op->getParentOfType<ModuleOp>();
649     IntegerAttr executionModeAttr = op.execution_modeAttr();
650     std::string moduleName;
651     if (module.getName().hasValue())
652       moduleName = "_" + module.getName().getValue().str();
653     else
654       moduleName = "";
655     std::string executionModeInfoName =
656         llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
657                       op.fn().str(), executionModeAttr.getValue());
658 
659     MLIRContext *context = rewriter.getContext();
660     OpBuilder::InsertionGuard guard(rewriter);
661     rewriter.setInsertionPointToStart(module.getBody());
662 
663     // Create a struct type, corresponding to the C struct below.
664     // struct {
665     //   int32_t executionMode;
666     //   int32_t values[];          // optional values
667     // };
668     auto llvmI32Type = IntegerType::get(context, 32);
669     SmallVector<Type, 2> fields;
670     fields.push_back(llvmI32Type);
671     ArrayAttr values = op.values();
672     if (!values.empty()) {
673       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
674       fields.push_back(arrayType);
675     }
676     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
677 
678     // Create `llvm.mlir.global` with initializer region containing one block.
679     auto global = rewriter.create<LLVM::GlobalOp>(
680         UnknownLoc::get(context), structType, /*isConstant=*/true,
681         LLVM::Linkage::External, executionModeInfoName, Attribute(),
682         /*alignment=*/0);
683     Location loc = global.getLoc();
684     Region &region = global.getInitializerRegion();
685     Block *block = rewriter.createBlock(&region);
686 
687     // Initialize the struct and set the execution mode value.
688     rewriter.setInsertionPoint(block, block->begin());
689     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
690     Value executionMode =
691         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
692     structValue = rewriter.create<LLVM::InsertValueOp>(
693         loc, structType, structValue, executionMode,
694         ArrayAttr::get(context,
695                        {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
696 
697     // Insert extra operands if they exist into execution mode info struct.
698     for (unsigned i = 0, e = values.size(); i < e; ++i) {
699       auto attr = values.getValue()[i];
700       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
701       structValue = rewriter.create<LLVM::InsertValueOp>(
702           loc, structType, structValue, entry,
703           ArrayAttr::get(context,
704                          {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
705                           rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
706     }
707     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
708     rewriter.eraseOp(op);
709     return success();
710   }
711 };
712 
713 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global
714 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
715 /// This difference is handled by `spv.mlir.addressof` and
716 /// `llvm.mlir.addressof`ops that both return a pointer.
717 class GlobalVariablePattern
718     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
719 public:
720   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
721 
722   LogicalResult
723   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
724                   ConversionPatternRewriter &rewriter) const override {
725     // Currently, there is no support of initialization with a constant value in
726     // SPIR-V dialect. Specialization constants are not considered as well.
727     if (op.initializer())
728       return failure();
729 
730     auto srcType = op.type().cast<spirv::PointerType>();
731     auto dstType = typeConverter.convertType(srcType.getPointeeType());
732     if (!dstType)
733       return failure();
734 
735     // Limit conversion to the current invocation only or `StorageBuffer`
736     // required by SPIR-V runner.
737     // This is okay because multiple invocations are not supported yet.
738     auto storageClass = srcType.getStorageClass();
739     switch (storageClass) {
740     case spirv::StorageClass::Input:
741     case spirv::StorageClass::Private:
742     case spirv::StorageClass::Output:
743     case spirv::StorageClass::StorageBuffer:
744     case spirv::StorageClass::UniformConstant:
745       break;
746     default:
747       return failure();
748     }
749 
750     // LLVM dialect spec: "If the global value is a constant, storing into it is
751     // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
752     // storage class that is read-only.
753     bool isConstant = (storageClass == spirv::StorageClass::Input) ||
754                       (storageClass == spirv::StorageClass::UniformConstant);
755     // SPIR-V spec: "By default, functions and global variables are private to a
756     // module and cannot be accessed by other modules. However, a module may be
757     // written to export or import functions and global (module scope)
758     // variables.". Therefore, map 'Private' storage class to private linkage,
759     // 'Input' and 'Output' to external linkage.
760     auto linkage = storageClass == spirv::StorageClass::Private
761                        ? LLVM::Linkage::Private
762                        : LLVM::Linkage::External;
763     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
764         op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
765         /*alignment=*/0);
766 
767     // Attach location attribute if applicable
768     if (op.locationAttr())
769       newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
770 
771     return success();
772   }
773 };
774 
775 /// Converts SPIR-V cast ops that do not have straightforward LLVM
776 /// equivalent in LLVM dialect.
777 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
778 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
779 public:
780   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
781 
782   LogicalResult
783   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
784                   ConversionPatternRewriter &rewriter) const override {
785 
786     Type fromType = operation.operand().getType();
787     Type toType = operation.getType();
788 
789     auto dstType = this->typeConverter.convertType(toType);
790     if (!dstType)
791       return failure();
792 
793     if (getBitWidth(fromType) < getBitWidth(toType)) {
794       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
795                                                       adaptor.getOperands());
796       return success();
797     }
798     if (getBitWidth(fromType) > getBitWidth(toType)) {
799       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
800                                                         adaptor.getOperands());
801       return success();
802     }
803     return failure();
804   }
805 };
806 
807 class FunctionCallPattern
808     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
809 public:
810   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
811 
812   LogicalResult
813   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
814                   ConversionPatternRewriter &rewriter) const override {
815     if (callOp.getNumResults() == 0) {
816       rewriter.replaceOpWithNewOp<LLVM::CallOp>(
817           callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
818       return success();
819     }
820 
821     // Function returns a single result.
822     auto dstType = typeConverter.convertType(callOp.getType(0));
823     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
824         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
825     return success();
826   }
827 };
828 
829 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
830 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
831 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
832 public:
833   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
834 
835   LogicalResult
836   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
837                   ConversionPatternRewriter &rewriter) const override {
838 
839     auto dstType = this->typeConverter.convertType(operation.getType());
840     if (!dstType)
841       return failure();
842 
843     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
844         operation, dstType, predicate, operation.operand1(),
845         operation.operand2());
846     return success();
847   }
848 };
849 
850 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
851 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
852 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
853 public:
854   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
855 
856   LogicalResult
857   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
858                   ConversionPatternRewriter &rewriter) const override {
859 
860     auto dstType = this->typeConverter.convertType(operation.getType());
861     if (!dstType)
862       return failure();
863 
864     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
865         operation, dstType, predicate, operation.operand1(),
866         operation.operand2());
867     return success();
868   }
869 };
870 
871 class InverseSqrtPattern
872     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
873 public:
874   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
875 
876   LogicalResult
877   matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor,
878                   ConversionPatternRewriter &rewriter) const override {
879     auto srcType = op.getType();
880     auto dstType = typeConverter.convertType(srcType);
881     if (!dstType)
882       return failure();
883 
884     Location loc = op.getLoc();
885     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
886     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
887     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
888     return success();
889   }
890 };
891 
892 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
893 template <typename SPIRVOp>
894 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
895 public:
896   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
897 
898   LogicalResult
899   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
900                   ConversionPatternRewriter &rewriter) const override {
901     if (!op.memory_access().hasValue()) {
902       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
903                                     this->typeConverter, /*alignment=*/0,
904                                     /*isVolatile=*/false,
905                                     /*isNonTemporal=*/false);
906     }
907     auto memoryAccess = op.memory_access().getValue();
908     switch (memoryAccess) {
909     case spirv::MemoryAccess::Aligned:
910     case spirv::MemoryAccess::None:
911     case spirv::MemoryAccess::Nontemporal:
912     case spirv::MemoryAccess::Volatile: {
913       unsigned alignment =
914           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
915       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
916       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
917       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
918                                     this->typeConverter, alignment, isVolatile,
919                                     isNonTemporal);
920     }
921     default:
922       // There is no support of other memory access attributes.
923       return failure();
924     }
925   }
926 };
927 
928 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
929 template <typename SPIRVOp>
930 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
931 public:
932   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
933 
934   LogicalResult
935   matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
936                   ConversionPatternRewriter &rewriter) const override {
937     auto srcType = notOp.getType();
938     auto dstType = this->typeConverter.convertType(srcType);
939     if (!dstType)
940       return failure();
941 
942     Location loc = notOp.getLoc();
943     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
944     auto mask = srcType.template isa<VectorType>()
945                     ? rewriter.create<LLVM::ConstantOp>(
946                           loc, dstType,
947                           SplatElementsAttr::get(
948                               srcType.template cast<VectorType>(), minusOne))
949                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
950     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
951                                                       notOp.operand(), mask);
952     return success();
953   }
954 };
955 
956 /// A template pattern that erases the given `SPIRVOp`.
957 template <typename SPIRVOp>
958 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
959 public:
960   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
961 
962   LogicalResult
963   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
964                   ConversionPatternRewriter &rewriter) const override {
965     rewriter.eraseOp(op);
966     return success();
967   }
968 };
969 
970 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
971 public:
972   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
973 
974   LogicalResult
975   matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
976                   ConversionPatternRewriter &rewriter) const override {
977     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
978                                                 ArrayRef<Value>());
979     return success();
980   }
981 };
982 
983 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
984 public:
985   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
986 
987   LogicalResult
988   matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
989                   ConversionPatternRewriter &rewriter) const override {
990     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
991                                                 adaptor.getOperands());
992     return success();
993   }
994 };
995 
996 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should
997 /// be reachable for conversion to succeed. The structure of the loop in LLVM
998 /// dialect will be the following:
999 ///
1000 ///      +------------------------------------+
1001 ///      | <code before spv.mlir.loop>        |
1002 ///      | llvm.br ^header                    |
1003 ///      +------------------------------------+
1004 ///                           |
1005 ///   +----------------+      |
1006 ///   |                |      |
1007 ///   |                V      V
1008 ///   |  +------------------------------------+
1009 ///   |  | ^header:                           |
1010 ///   |  |   <header code>                    |
1011 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1012 ///   |  +------------------------------------+
1013 ///   |                    |
1014 ///   |                    |----------------------+
1015 ///   |                    |                      |
1016 ///   |                    V                      |
1017 ///   |  +------------------------------------+   |
1018 ///   |  | ^body:                             |   |
1019 ///   |  |   <body code>                      |   |
1020 ///   |  |   llvm.br ^continue                |   |
1021 ///   |  +------------------------------------+   |
1022 ///   |                    |                      |
1023 ///   |                    V                      |
1024 ///   |  +------------------------------------+   |
1025 ///   |  | ^continue:                         |   |
1026 ///   |  |   <continue code>                  |   |
1027 ///   |  |   llvm.br ^header                  |   |
1028 ///   |  +------------------------------------+   |
1029 ///   |               |                           |
1030 ///   +---------------+    +----------------------+
1031 ///                        |
1032 ///                        V
1033 ///      +------------------------------------+
1034 ///      | ^exit:                             |
1035 ///      |   llvm.br ^remaining               |
1036 ///      +------------------------------------+
1037 ///                        |
1038 ///                        V
1039 ///      +------------------------------------+
1040 ///      | ^remaining:                        |
1041 ///      |   <code after spv.mlir.loop>       |
1042 ///      +------------------------------------+
1043 ///
1044 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1045 public:
1046   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1047 
1048   LogicalResult
1049   matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1050                   ConversionPatternRewriter &rewriter) const override {
1051     // There is no support of loop control at the moment.
1052     if (loopOp.loop_control() != spirv::LoopControl::None)
1053       return failure();
1054 
1055     Location loc = loopOp.getLoc();
1056 
1057     // Split the current block after `spv.mlir.loop`. The remaining ops will be
1058     // used in `endBlock`.
1059     Block *currentBlock = rewriter.getBlock();
1060     auto position = Block::iterator(loopOp);
1061     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1062 
1063     // Remove entry block and create a branch in the current block going to the
1064     // header block.
1065     Block *entryBlock = loopOp.getEntryBlock();
1066     assert(entryBlock->getOperations().size() == 1);
1067     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1068     if (!brOp)
1069       return failure();
1070     Block *headerBlock = loopOp.getHeaderBlock();
1071     rewriter.setInsertionPointToEnd(currentBlock);
1072     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1073     rewriter.eraseBlock(entryBlock);
1074 
1075     // Branch from merge block to end block.
1076     Block *mergeBlock = loopOp.getMergeBlock();
1077     Operation *terminator = mergeBlock->getTerminator();
1078     ValueRange terminatorOperands = terminator->getOperands();
1079     rewriter.setInsertionPointToEnd(mergeBlock);
1080     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1081 
1082     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1083     rewriter.replaceOp(loopOp, endBlock->getArguments());
1084     return success();
1085   }
1086 };
1087 
1088 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header
1089 /// block. All blocks within selection should be reachable for conversion to
1090 /// succeed.
1091 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1092 public:
1093   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1094 
1095   LogicalResult
1096   matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1097                   ConversionPatternRewriter &rewriter) const override {
1098     // There is no support for `Flatten` or `DontFlatten` selection control at
1099     // the moment. This are just compiler hints and can be performed during the
1100     // optimization passes.
1101     if (op.selection_control() != spirv::SelectionControl::None)
1102       return failure();
1103 
1104     // `spv.mlir.selection` should have at least two blocks: one selection
1105     // header block and one merge block. If no blocks are present, or control
1106     // flow branches straight to merge block (two blocks are present), the op is
1107     // redundant and it is erased.
1108     if (op.body().getBlocks().size() <= 2) {
1109       rewriter.eraseOp(op);
1110       return success();
1111     }
1112 
1113     Location loc = op.getLoc();
1114 
1115     // Split the current block after `spv.mlir.selection`. The remaining ops
1116     // will be used in `continueBlock`.
1117     auto *currentBlock = rewriter.getInsertionBlock();
1118     rewriter.setInsertionPointAfter(op);
1119     auto position = rewriter.getInsertionPoint();
1120     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1121 
1122     // Extract conditional branch information from the header block. By SPIR-V
1123     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1124     // op. Note that `spv.Switch op` is not supported at the moment in the
1125     // SPIR-V dialect. Remove this block when finished.
1126     auto *headerBlock = op.getHeaderBlock();
1127     assert(headerBlock->getOperations().size() == 1);
1128     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1129         headerBlock->getOperations().front());
1130     if (!condBrOp)
1131       return failure();
1132     rewriter.eraseBlock(headerBlock);
1133 
1134     // Branch from merge block to continue block.
1135     auto *mergeBlock = op.getMergeBlock();
1136     Operation *terminator = mergeBlock->getTerminator();
1137     ValueRange terminatorOperands = terminator->getOperands();
1138     rewriter.setInsertionPointToEnd(mergeBlock);
1139     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1140 
1141     // Link current block to `true` and `false` blocks within the selection.
1142     Block *trueBlock = condBrOp.getTrueBlock();
1143     Block *falseBlock = condBrOp.getFalseBlock();
1144     rewriter.setInsertionPointToEnd(currentBlock);
1145     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1146                                     condBrOp.trueTargetOperands(), falseBlock,
1147                                     condBrOp.falseTargetOperands());
1148 
1149     rewriter.inlineRegionBefore(op.body(), continueBlock);
1150     rewriter.replaceOp(op, continueBlock->getArguments());
1151     return success();
1152   }
1153 };
1154 
1155 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1156 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1157 /// `Shift` is zero or sign extended to match this specification. Cases when
1158 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1159 template <typename SPIRVOp, typename LLVMOp>
1160 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1161 public:
1162   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1163 
1164   LogicalResult
1165   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1166                   ConversionPatternRewriter &rewriter) const override {
1167 
1168     auto dstType = this->typeConverter.convertType(operation.getType());
1169     if (!dstType)
1170       return failure();
1171 
1172     Type op1Type = operation.operand1().getType();
1173     Type op2Type = operation.operand2().getType();
1174 
1175     if (op1Type == op2Type) {
1176       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1177                                                    adaptor.getOperands());
1178       return success();
1179     }
1180 
1181     Location loc = operation.getLoc();
1182     Value extended;
1183     if (isUnsignedIntegerOrVector(op2Type)) {
1184       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1185                                                         adaptor.operand2());
1186     } else {
1187       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1188                                                         adaptor.operand2());
1189     }
1190     Value result = rewriter.template create<LLVMOp>(
1191         loc, dstType, adaptor.operand1(), extended);
1192     rewriter.replaceOp(operation, result);
1193     return success();
1194   }
1195 };
1196 
1197 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1198 public:
1199   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1200 
1201   LogicalResult
1202   matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor,
1203                   ConversionPatternRewriter &rewriter) const override {
1204     auto dstType = typeConverter.convertType(tanOp.getType());
1205     if (!dstType)
1206       return failure();
1207 
1208     Location loc = tanOp.getLoc();
1209     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1210     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1211     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1212     return success();
1213   }
1214 };
1215 
1216 /// Convert `spv.Tanh` to
1217 ///
1218 ///   exp(2x) - 1
1219 ///   -----------
1220 ///   exp(2x) + 1
1221 ///
1222 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1223 public:
1224   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1225 
1226   LogicalResult
1227   matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor,
1228                   ConversionPatternRewriter &rewriter) const override {
1229     auto srcType = tanhOp.getType();
1230     auto dstType = typeConverter.convertType(srcType);
1231     if (!dstType)
1232       return failure();
1233 
1234     Location loc = tanhOp.getLoc();
1235     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1236     Value multiplied =
1237         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1238     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1239     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1240     Value numerator =
1241         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1242     Value denominator =
1243         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1244     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1245                                               denominator);
1246     return success();
1247   }
1248 };
1249 
1250 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1251 public:
1252   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1253 
1254   LogicalResult
1255   matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1256                   ConversionPatternRewriter &rewriter) const override {
1257     auto srcType = varOp.getType();
1258     // Initialization is supported for scalars and vectors only.
1259     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1260     auto init = varOp.initializer();
1261     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1262       return failure();
1263 
1264     auto dstType = typeConverter.convertType(srcType);
1265     if (!dstType)
1266       return failure();
1267 
1268     Location loc = varOp.getLoc();
1269     Value size = createI32ConstantOf(loc, rewriter, 1);
1270     if (!init) {
1271       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1272       return success();
1273     }
1274     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1275     rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
1276     rewriter.replaceOp(varOp, allocated);
1277     return success();
1278   }
1279 };
1280 
1281 //===----------------------------------------------------------------------===//
1282 // FuncOp conversion
1283 //===----------------------------------------------------------------------===//
1284 
1285 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1286 public:
1287   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1288 
1289   LogicalResult
1290   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1291                   ConversionPatternRewriter &rewriter) const override {
1292 
1293     // Convert function signature. At the moment LLVMType converter is enough
1294     // for currently supported types.
1295     auto funcType = funcOp.getType();
1296     TypeConverter::SignatureConversion signatureConverter(
1297         funcType.getNumInputs());
1298     auto llvmType = typeConverter.convertFunctionSignature(
1299         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1300     if (!llvmType)
1301       return failure();
1302 
1303     // Create a new `LLVMFuncOp`
1304     Location loc = funcOp.getLoc();
1305     StringRef name = funcOp.getName();
1306     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1307 
1308     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1309     MLIRContext *context = funcOp.getContext();
1310     switch (funcOp.function_control()) {
1311 #define DISPATCH(functionControl, llvmAttr)                                    \
1312   case functionControl:                                                        \
1313     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1314     break;
1315 
1316       DISPATCH(spirv::FunctionControl::Inline,
1317                StringAttr::get(context, "alwaysinline"));
1318       DISPATCH(spirv::FunctionControl::DontInline,
1319                StringAttr::get(context, "noinline"));
1320       DISPATCH(spirv::FunctionControl::Pure,
1321                StringAttr::get(context, "readonly"));
1322       DISPATCH(spirv::FunctionControl::Const,
1323                StringAttr::get(context, "readnone"));
1324 
1325 #undef DISPATCH
1326 
1327     // Default: if `spirv::FunctionControl::None`, then no attributes are
1328     // needed.
1329     default:
1330       break;
1331     }
1332 
1333     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1334                                 newFuncOp.end());
1335     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1336                                            &signatureConverter))) {
1337       return failure();
1338     }
1339     rewriter.eraseOp(funcOp);
1340     return success();
1341   }
1342 };
1343 
1344 //===----------------------------------------------------------------------===//
1345 // ModuleOp conversion
1346 //===----------------------------------------------------------------------===//
1347 
1348 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1349 public:
1350   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1351 
1352   LogicalResult
1353   matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1354                   ConversionPatternRewriter &rewriter) const override {
1355 
1356     auto newModuleOp =
1357         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1358     rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1359 
1360     // Remove the terminator block that was automatically added by builder
1361     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1362     rewriter.eraseOp(spvModuleOp);
1363     return success();
1364   }
1365 };
1366 
1367 //===----------------------------------------------------------------------===//
1368 // VectorShuffleOp conversion
1369 //===----------------------------------------------------------------------===//
1370 
1371 class VectorShufflePattern
1372     : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1373 public:
1374   using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1375   LogicalResult
1376   matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1377                   ConversionPatternRewriter &rewriter) const override {
1378     Location loc = op.getLoc();
1379     auto components = adaptor.components();
1380     auto vector1 = adaptor.vector1();
1381     auto vector2 = adaptor.vector2();
1382     int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1383     int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1384     if (vector1Size == vector2Size) {
1385       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
1386                                                          components);
1387       return success();
1388     }
1389 
1390     auto dstType = typeConverter.convertType(op.getType());
1391     auto scalarType = dstType.cast<VectorType>().getElementType();
1392     auto componentsArray = components.getValue();
1393     auto context = rewriter.getContext();
1394     auto llvmI32Type = IntegerType::get(context, 32);
1395     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1396     for (unsigned i = 0; i < componentsArray.size(); i++) {
1397       if (componentsArray[i].isa<IntegerAttr>())
1398         op.emitError("unable to support non-constant component");
1399 
1400       int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1401       if (indexVal == -1)
1402         continue;
1403 
1404       int offsetVal = 0;
1405       Value baseVector = vector1;
1406       if (indexVal >= vector1Size) {
1407         offsetVal = vector1Size;
1408         baseVector = vector2;
1409       }
1410 
1411       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1412           loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1413       Value index = rewriter.create<LLVM::ConstantOp>(
1414           loc, llvmI32Type,
1415           rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1416 
1417       auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1418           loc, scalarType, baseVector, index);
1419       targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1420                                                         extractOp, dstIndex);
1421     }
1422     rewriter.replaceOp(op, targetOp);
1423     return success();
1424   }
1425 };
1426 } // namespace
1427 
1428 //===----------------------------------------------------------------------===//
1429 // Pattern population
1430 //===----------------------------------------------------------------------===//
1431 
1432 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1433   typeConverter.addConversion([&](spirv::ArrayType type) {
1434     return convertArrayType(type, typeConverter);
1435   });
1436   typeConverter.addConversion([&](spirv::PointerType type) {
1437     return convertPointerType(type, typeConverter);
1438   });
1439   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1440     return convertRuntimeArrayType(type, typeConverter);
1441   });
1442   typeConverter.addConversion([&](spirv::StructType type) {
1443     return convertStructType(type, typeConverter);
1444   });
1445 }
1446 
1447 void mlir::populateSPIRVToLLVMConversionPatterns(
1448     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1449   patterns.add<
1450       // Arithmetic ops
1451       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1452       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1453       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1454       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1455       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1456       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1457       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1458       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1459       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1460       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1461       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1462       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1463       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1464 
1465       // Bitwise ops
1466       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1467       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1468       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1469       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1470       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1471       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1472       NotPattern<spirv::NotOp>,
1473 
1474       // Cast ops
1475       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1476       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1477       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1478       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1479       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1480       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1481       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1482       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1483 
1484       // Comparison ops
1485       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1486       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1487       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1488       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1489       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1490       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1491       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1492       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1493       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1494       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1495       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1496                       LLVM::FCmpPredicate::uge>,
1497       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1498       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1499       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1500       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1501       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1502       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1503       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1504       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1505       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1506       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1507       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1508 
1509       // Constant op
1510       ConstantScalarAndVectorPattern,
1511 
1512       // Control Flow ops
1513       BranchConversionPattern, BranchConditionalConversionPattern,
1514       FunctionCallPattern, LoopPattern, SelectionPattern,
1515       ErasePattern<spirv::MergeOp>,
1516 
1517       // Entry points and execution mode are handled separately.
1518       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1519 
1520       // GLSL extended instruction set ops
1521       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1522       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1523       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1524       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1525       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1526       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1527       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1528       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1529       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1530       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1531       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1532       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1533       InverseSqrtPattern, TanPattern, TanhPattern,
1534 
1535       // Logical ops
1536       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1537       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1538       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1539       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1540       NotPattern<spirv::LogicalNotOp>,
1541 
1542       // Memory ops
1543       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1544       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1545       VariablePattern,
1546 
1547       // Miscellaneous ops
1548       CompositeExtractPattern, CompositeInsertPattern,
1549       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1550       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1551       VectorShufflePattern,
1552 
1553       // Shift ops
1554       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1555       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1556       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1557 
1558       // Return ops
1559       ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1560 }
1561 
1562 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1563     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1564   patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1565 }
1566 
1567 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1568     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1569   patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1570 }
1571 
1572 //===----------------------------------------------------------------------===//
1573 // Pre-conversion hooks
1574 //===----------------------------------------------------------------------===//
1575 
1576 /// Hook for descriptor set and binding number encoding.
1577 static constexpr StringRef kBinding = "binding";
1578 static constexpr StringRef kDescriptorSet = "descriptor_set";
1579 void mlir::encodeBindAttribute(ModuleOp module) {
1580   auto spvModules = module.getOps<spirv::ModuleOp>();
1581   for (auto spvModule : spvModules) {
1582     spvModule.walk([&](spirv::GlobalVariableOp op) {
1583       IntegerAttr descriptorSet =
1584           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1585       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1586       // For every global variable in the module, get the ones with descriptor
1587       // set and binding numbers.
1588       if (descriptorSet && binding) {
1589         // Encode these numbers into the variable's symbolic name. If the
1590         // SPIR-V module has a name, add it at the beginning.
1591         auto moduleAndName = spvModule.getName().hasValue()
1592                                  ? spvModule.getName().getValue().str() + "_" +
1593                                        op.sym_name().str()
1594                                  : op.sym_name().str();
1595         std::string name =
1596             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1597                           std::to_string(descriptorSet.getInt()),
1598                           std::to_string(binding.getInt()));
1599         auto nameAttr = StringAttr::get(op->getContext(), name);
1600 
1601         // Replace all symbol uses and set the new symbol name. Finally, remove
1602         // descriptor set and binding attributes.
1603         if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1604           op.emitError("unable to replace all symbol uses for ") << name;
1605         SymbolTable::setSymbolName(op, nameAttr);
1606         op->removeAttr(kDescriptorSet);
1607         op->removeAttr(kBinding);
1608       }
1609     });
1610   }
1611 }
1612