1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "spirv-to-llvm-pattern"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Returns true if the given type is a signed integer or vector type.
36 static bool isSignedIntegerOrVector(Type type) {
37   if (type.isSignedInteger())
38     return true;
39   if (auto vecType = type.dyn_cast<VectorType>())
40     return vecType.getElementType().isSignedInteger();
41   return false;
42 }
43 
44 /// Returns true if the given type is an unsigned integer or vector type
45 static bool isUnsignedIntegerOrVector(Type type) {
46   if (type.isUnsignedInteger())
47     return true;
48   if (auto vecType = type.dyn_cast<VectorType>())
49     return vecType.getElementType().isUnsignedInteger();
50   return false;
51 }
52 
53 /// Returns the bit width of integer, float or vector of float or integer values
54 static unsigned getBitWidth(Type type) {
55   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
56          "bitwidth is not supported for this type");
57   if (type.isIntOrFloat())
58     return type.getIntOrFloatBitWidth();
59   auto vecType = type.dyn_cast<VectorType>();
60   auto elementType = vecType.getElementType();
61   assert(elementType.isIntOrFloat() &&
62          "only integers and floats have a bitwidth");
63   return elementType.getIntOrFloatBitWidth();
64 }
65 
66 /// Returns the bit width of LLVMType integer or vector.
67 static unsigned getLLVMTypeBitWidth(Type type) {
68   return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
69                                              : type)
70       .cast<IntegerType>()
71       .getWidth();
72 }
73 
74 /// Creates `IntegerAttribute` with all bits set for given type
75 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
76   if (auto vecType = type.dyn_cast<VectorType>()) {
77     auto integerType = vecType.getElementType().cast<IntegerType>();
78     return builder.getIntegerAttr(integerType, -1);
79   }
80   auto integerType = type.cast<IntegerType>();
81   return builder.getIntegerAttr(integerType, -1);
82 }
83 
84 /// Creates `llvm.mlir.constant` with all bits set for the given type.
85 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
86                                       PatternRewriter &rewriter) {
87   if (srcType.isa<VectorType>()) {
88     return rewriter.create<LLVM::ConstantOp>(
89         loc, dstType,
90         SplatElementsAttr::get(srcType.cast<ShapedType>(),
91                                minusOneIntegerAttribute(srcType, rewriter)));
92   }
93   return rewriter.create<LLVM::ConstantOp>(
94       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
95 }
96 
97 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
98 static Value createFPConstant(Location loc, Type srcType, Type dstType,
99                               PatternRewriter &rewriter, double value) {
100   if (auto vecType = srcType.dyn_cast<VectorType>()) {
101     auto floatType = vecType.getElementType().cast<FloatType>();
102     return rewriter.create<LLVM::ConstantOp>(
103         loc, dstType,
104         SplatElementsAttr::get(vecType,
105                                rewriter.getFloatAttr(floatType, value)));
106   }
107   auto floatType = srcType.cast<FloatType>();
108   return rewriter.create<LLVM::ConstantOp>(
109       loc, dstType, rewriter.getFloatAttr(floatType, value));
110 }
111 
112 /// Utility function for bitfield ops:
113 ///   - `BitFieldInsert`
114 ///   - `BitFieldSExtract`
115 ///   - `BitFieldUExtract`
116 /// Truncates or extends the value. If the bitwidth of the value is the same as
117 /// `llvmType` bitwidth, the value remains unchanged.
118 static Value optionallyTruncateOrExtend(Location loc, Value value,
119                                         Type llvmType,
120                                         PatternRewriter &rewriter) {
121   auto srcType = value.getType();
122   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
123   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
124                                ? getLLVMTypeBitWidth(srcType)
125                                : getBitWidth(srcType);
126 
127   if (valueBitWidth < targetBitWidth)
128     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129   // If the bit widths of `Count` and `Offset` are greater than the bit width
130   // of the target type, they are truncated. Truncation is safe since `Count`
131   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132   // both values can be expressed in 8 bits.
133   if (valueBitWidth > targetBitWidth)
134     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135   return value;
136 }
137 
138 /// Broadcasts the value to vector with `numElements` number of elements.
139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140                        LLVMTypeConverter &typeConverter,
141                        ConversionPatternRewriter &rewriter) {
142   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143   auto llvmVectorType = typeConverter.convertType(vectorType);
144   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146   for (unsigned i = 0; i < numElements; ++i) {
147     auto index = rewriter.create<LLVM::ConstantOp>(
148         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149     broadcasted = rewriter.create<LLVM::InsertElementOp>(
150         loc, llvmVectorType, broadcasted, toBroadcast, index);
151   }
152   return broadcasted;
153 }
154 
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
156 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
157                                  LLVMTypeConverter &typeConverter,
158                                  ConversionPatternRewriter &rewriter) {
159   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160     unsigned numElements = vectorType.getNumElements();
161     return broadcast(loc, value, numElements, typeConverter, rewriter);
162   }
163   return value;
164 }
165 
166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 ///  - same number of elements as `Base`
171 ///  - each element has the type that is the same as the type of `Offset` or
172 ///    `Count`
173 ///  - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
176 static Value processCountOrOffset(Location loc, Value value, Type srcType,
177                                   Type dstType, LLVMTypeConverter &converter,
178                                   ConversionPatternRewriter &rewriter) {
179   Value broadcasted =
180       optionallyBroadcast(loc, value, srcType, converter, rewriter);
181   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183 
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
187 convertStructTypeWithOffset(spirv::StructType type,
188                             LLVMTypeConverter &converter) {
189   if (type != VulkanLayoutUtils::decorateType(type))
190     return llvm::None;
191 
192   auto elementsVector = llvm::to_vector<8>(
193       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194         return converter.convertType(elementType);
195       }));
196   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
197                                           /*isPacked=*/false);
198 }
199 
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
201 static Type convertStructTypePacked(spirv::StructType type,
202                                     LLVMTypeConverter &converter) {
203   auto elementsVector = llvm::to_vector<8>(
204       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205         return converter.convertType(elementType);
206       }));
207   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208                                           /*isPacked=*/true);
209 }
210 
211 /// Creates LLVM dialect constant with the given value.
212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
213                                  unsigned value) {
214   return rewriter.create<LLVM::ConstantOp>(
215       loc, IntegerType::get(rewriter.getContext(), 32),
216       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218 
219 /// Utility for `spv.Load` and `spv.Store` conversion.
220 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
221                                             ConversionPatternRewriter &rewriter,
222                                             LLVMTypeConverter &typeConverter,
223                                             unsigned alignment, bool isVolatile,
224                                             bool isNonTemporal) {
225   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226     auto dstType = typeConverter.convertType(loadOp.getType());
227     if (!dstType)
228       return failure();
229     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230         loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
231         isVolatile, isNonTemporal);
232     return success();
233   }
234   auto storeOp = cast<spirv::StoreOp>(op);
235   spirv::StoreOpAdaptor adaptor(operands);
236   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
237                                              adaptor.ptr(), alignment,
238                                              isVolatile, isNonTemporal);
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Type conversion
244 //===----------------------------------------------------------------------===//
245 
246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
248 /// when converting ops that manipulate array types.
249 static Optional<Type> convertArrayType(spirv::ArrayType type,
250                                        TypeConverter &converter) {
251   unsigned stride = type.getArrayStride();
252   Type elementType = type.getElementType();
253   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
254   if (stride != 0 &&
255       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
256     return llvm::None;
257 
258   auto llvmElementType = converter.convertType(elementType);
259   unsigned numElements = type.getNumElements();
260   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
261 }
262 
263 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
264 /// modelled at the moment.
265 static Type convertPointerType(spirv::PointerType type,
266                                TypeConverter &converter) {
267   auto pointeeType = converter.convertType(type.getPointeeType());
268   return LLVM::LLVMPointerType::get(pointeeType);
269 }
270 
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
273 /// no modelling of array stride at the moment.
274 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
275                                               TypeConverter &converter) {
276   if (type.getArrayStride() != 0)
277     return llvm::None;
278   auto elementType = converter.convertType(type.getElementType());
279   return LLVM::LLVMArrayType::get(elementType, 0);
280 }
281 
282 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
283 /// member decorations. Also, only natural offset is supported.
284 static Optional<Type> convertStructType(spirv::StructType type,
285                                         LLVMTypeConverter &converter) {
286   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
287   type.getMemberDecorations(memberDecorations);
288   if (!memberDecorations.empty())
289     return llvm::None;
290   if (type.hasOffset())
291     return convertStructTypeWithOffset(type, converter);
292   return convertStructTypePacked(type, converter);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Operation conversion
297 //===----------------------------------------------------------------------===//
298 
299 namespace {
300 
301 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
302 public:
303   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
304 
305   LogicalResult
306   matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
307                   ConversionPatternRewriter &rewriter) const override {
308     auto dstType = typeConverter.convertType(op.component_ptr().getType());
309     if (!dstType)
310       return failure();
311     // To use GEP we need to add a first 0 index to go through the pointer.
312     auto indices = llvm::to_vector<4>(adaptor.indices());
313     Type indexType = op.indices().front().getType();
314     auto llvmIndexType = typeConverter.convertType(indexType);
315     if (!llvmIndexType)
316       return failure();
317     Value zero = rewriter.create<LLVM::ConstantOp>(
318         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
319     indices.insert(indices.begin(), zero);
320     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
321                                              indices);
322     return success();
323   }
324 };
325 
326 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
327 public:
328   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
329 
330   LogicalResult
331   matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
332                   ConversionPatternRewriter &rewriter) const override {
333     auto dstType = typeConverter.convertType(op.pointer().getType());
334     if (!dstType)
335       return failure();
336     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
337     return success();
338   }
339 };
340 
341 class BitFieldInsertPattern
342     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
343 public:
344   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
345 
346   LogicalResult
347   matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
348                   ConversionPatternRewriter &rewriter) const override {
349     auto srcType = op.getType();
350     auto dstType = typeConverter.convertType(srcType);
351     if (!dstType)
352       return failure();
353     Location loc = op.getLoc();
354 
355     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
356     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
357                                         typeConverter, rewriter);
358     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
359                                        typeConverter, rewriter);
360 
361     // Create a mask with bits set outside [Offset, Offset + Count - 1].
362     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
363     Value maskShiftedByCount =
364         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
365     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
366                                                  maskShiftedByCount, minusOne);
367     Value maskShiftedByCountAndOffset =
368         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
369     Value mask = rewriter.create<LLVM::XOrOp>(
370         loc, dstType, maskShiftedByCountAndOffset, minusOne);
371 
372     // Extract unchanged bits from the `Base`  that are outside of
373     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
374     Value baseAndMask =
375         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
376     Value insertShiftedByOffset =
377         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
378     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
379                                             insertShiftedByOffset);
380     return success();
381   }
382 };
383 
384 /// Converts SPIR-V ConstantOp with scalar or vector type.
385 class ConstantScalarAndVectorPattern
386     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
387 public:
388   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
389 
390   LogicalResult
391   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
392                   ConversionPatternRewriter &rewriter) const override {
393     auto srcType = constOp.getType();
394     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
395       return failure();
396 
397     auto dstType = typeConverter.convertType(srcType);
398     if (!dstType)
399       return failure();
400 
401     // SPIR-V constant can be a signed/unsigned integer, which has to be
402     // casted to signless integer when converting to LLVM dialect. Removing the
403     // sign bit may have unexpected behaviour. However, it is better to handle
404     // it case-by-case, given that the purpose of the conversion is not to
405     // cover all possible corner cases.
406     if (isSignedIntegerOrVector(srcType) ||
407         isUnsignedIntegerOrVector(srcType)) {
408       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
409 
410       if (srcType.isa<VectorType>()) {
411         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
412         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
413             constOp, dstType,
414             dstElementsAttr.mapValues(
415                 signlessType, [&](const APInt &value) { return value; }));
416         return success();
417       }
418       auto srcAttr = constOp.value().cast<IntegerAttr>();
419       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
420       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
421       return success();
422     }
423     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
424         constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
425     return success();
426   }
427 };
428 
429 class BitFieldSExtractPattern
430     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
431 public:
432   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
433 
434   LogicalResult
435   matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
436                   ConversionPatternRewriter &rewriter) const override {
437     auto srcType = op.getType();
438     auto dstType = typeConverter.convertType(srcType);
439     if (!dstType)
440       return failure();
441     Location loc = op.getLoc();
442 
443     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
444     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
445                                         typeConverter, rewriter);
446     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
447                                        typeConverter, rewriter);
448 
449     // Create a constant that holds the size of the `Base`.
450     IntegerType integerType;
451     if (auto vecType = srcType.dyn_cast<VectorType>())
452       integerType = vecType.getElementType().cast<IntegerType>();
453     else
454       integerType = srcType.cast<IntegerType>();
455 
456     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
457     Value size =
458         srcType.isa<VectorType>()
459             ? rewriter.create<LLVM::ConstantOp>(
460                   loc, dstType,
461                   SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
462             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
463 
464     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
465     // at Offset + Count - 1 is the most significant bit now.
466     Value countPlusOffset =
467         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
468     Value amountToShiftLeft =
469         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
470     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
471         loc, dstType, op.base(), amountToShiftLeft);
472 
473     // Shift the result right, filling the bits with the sign bit.
474     Value amountToShiftRight =
475         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
476     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
477                                               amountToShiftRight);
478     return success();
479   }
480 };
481 
482 class BitFieldUExtractPattern
483     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
484 public:
485   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
486 
487   LogicalResult
488   matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
489                   ConversionPatternRewriter &rewriter) const override {
490     auto srcType = op.getType();
491     auto dstType = typeConverter.convertType(srcType);
492     if (!dstType)
493       return failure();
494     Location loc = op.getLoc();
495 
496     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
497     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
498                                         typeConverter, rewriter);
499     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
500                                        typeConverter, rewriter);
501 
502     // Create a mask with bits set at [0, Count - 1].
503     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
504     Value maskShiftedByCount =
505         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
506     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
507                                               minusOne);
508 
509     // Shift `Base` by `Offset` and apply the mask on it.
510     Value shiftedBase =
511         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
512     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
513     return success();
514   }
515 };
516 
517 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
518 public:
519   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
520 
521   LogicalResult
522   matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
523                   ConversionPatternRewriter &rewriter) const override {
524     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
525                                             branchOp.getTarget());
526     return success();
527   }
528 };
529 
530 class BranchConditionalConversionPattern
531     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
532 public:
533   using SPIRVToLLVMConversion<
534       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
535 
536   LogicalResult
537   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
538                   ConversionPatternRewriter &rewriter) const override {
539     // If branch weights exist, map them to 32-bit integer vector.
540     ElementsAttr branchWeights = nullptr;
541     if (auto weights = op.branch_weights()) {
542       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
543       branchWeights =
544           DenseElementsAttr::get(weightType, weights.getValue().getValue());
545     }
546 
547     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
548         op, op.condition(), op.getTrueBlockArguments(),
549         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
550         op.getFalseBlock());
551     return success();
552   }
553 };
554 
555 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
556 /// is an aggregate type (struct or array). Otherwise, converts to
557 /// `llvm.extractelement` that operates on vectors.
558 class CompositeExtractPattern
559     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
560 public:
561   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
562 
563   LogicalResult
564   matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
565                   ConversionPatternRewriter &rewriter) const override {
566     auto dstType = this->typeConverter.convertType(op.getType());
567     if (!dstType)
568       return failure();
569 
570     Type containerType = op.composite().getType();
571     if (containerType.isa<VectorType>()) {
572       Location loc = op.getLoc();
573       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
574       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
575       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
576           op, dstType, adaptor.composite(), index);
577       return success();
578     }
579     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
580         op, dstType, adaptor.composite(), op.indices());
581     return success();
582   }
583 };
584 
585 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
586 /// is an aggregate type (struct or array). Otherwise, converts to
587 /// `llvm.insertelement` that operates on vectors.
588 class CompositeInsertPattern
589     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
590 public:
591   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
592 
593   LogicalResult
594   matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
595                   ConversionPatternRewriter &rewriter) const override {
596     auto dstType = this->typeConverter.convertType(op.getType());
597     if (!dstType)
598       return failure();
599 
600     Type containerType = op.composite().getType();
601     if (containerType.isa<VectorType>()) {
602       Location loc = op.getLoc();
603       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
604       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
605       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
606           op, dstType, adaptor.composite(), adaptor.object(), index);
607       return success();
608     }
609     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
610         op, dstType, adaptor.composite(), adaptor.object(), op.indices());
611     return success();
612   }
613 };
614 
615 /// Converts SPIR-V operations that have straightforward LLVM equivalent
616 /// into LLVM dialect operations.
617 template <typename SPIRVOp, typename LLVMOp>
618 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
619 public:
620   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
621 
622   LogicalResult
623   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
624                   ConversionPatternRewriter &rewriter) const override {
625     auto dstType = this->typeConverter.convertType(operation.getType());
626     if (!dstType)
627       return failure();
628     rewriter.template replaceOpWithNewOp<LLVMOp>(
629         operation, dstType, adaptor.getOperands(), operation->getAttrs());
630     return success();
631   }
632 };
633 
634 /// Converts `spv.ExecutionMode` into a global struct constant that holds
635 /// execution mode information.
636 class ExecutionModePattern
637     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
638 public:
639   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
640 
641   LogicalResult
642   matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
643                   ConversionPatternRewriter &rewriter) const override {
644     // First, create the global struct's name that would be associated with
645     // this entry point's execution mode. We set it to be:
646     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
647     ModuleOp module = op->getParentOfType<ModuleOp>();
648     IntegerAttr executionModeAttr = op.execution_modeAttr();
649     std::string moduleName;
650     if (module.getName().hasValue())
651       moduleName = "_" + module.getName().getValue().str();
652     else
653       moduleName = "";
654     std::string executionModeInfoName =
655         llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
656                       op.fn().str(), executionModeAttr.getValue());
657 
658     MLIRContext *context = rewriter.getContext();
659     OpBuilder::InsertionGuard guard(rewriter);
660     rewriter.setInsertionPointToStart(module.getBody());
661 
662     // Create a struct type, corresponding to the C struct below.
663     // struct {
664     //   int32_t executionMode;
665     //   int32_t values[];          // optional values
666     // };
667     auto llvmI32Type = IntegerType::get(context, 32);
668     SmallVector<Type, 2> fields;
669     fields.push_back(llvmI32Type);
670     ArrayAttr values = op.values();
671     if (!values.empty()) {
672       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
673       fields.push_back(arrayType);
674     }
675     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
676 
677     // Create `llvm.mlir.global` with initializer region containing one block.
678     auto global = rewriter.create<LLVM::GlobalOp>(
679         UnknownLoc::get(context), structType, /*isConstant=*/true,
680         LLVM::Linkage::External, executionModeInfoName, Attribute(),
681         /*alignment=*/0);
682     Location loc = global.getLoc();
683     Region &region = global.getInitializerRegion();
684     Block *block = rewriter.createBlock(&region);
685 
686     // Initialize the struct and set the execution mode value.
687     rewriter.setInsertionPoint(block, block->begin());
688     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
689     Value executionMode =
690         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
691     structValue = rewriter.create<LLVM::InsertValueOp>(
692         loc, structType, structValue, executionMode,
693         ArrayAttr::get(context,
694                        {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
695 
696     // Insert extra operands if they exist into execution mode info struct.
697     for (unsigned i = 0, e = values.size(); i < e; ++i) {
698       auto attr = values.getValue()[i];
699       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
700       structValue = rewriter.create<LLVM::InsertValueOp>(
701           loc, structType, structValue, entry,
702           ArrayAttr::get(context,
703                          {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
704                           rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
705     }
706     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
707     rewriter.eraseOp(op);
708     return success();
709   }
710 };
711 
712 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global
713 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
714 /// This difference is handled by `spv.mlir.addressof` and
715 /// `llvm.mlir.addressof`ops that both return a pointer.
716 class GlobalVariablePattern
717     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
718 public:
719   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
720 
721   LogicalResult
722   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
723                   ConversionPatternRewriter &rewriter) const override {
724     // Currently, there is no support of initialization with a constant value in
725     // SPIR-V dialect. Specialization constants are not considered as well.
726     if (op.initializer())
727       return failure();
728 
729     auto srcType = op.type().cast<spirv::PointerType>();
730     auto dstType = typeConverter.convertType(srcType.getPointeeType());
731     if (!dstType)
732       return failure();
733 
734     // Limit conversion to the current invocation only or `StorageBuffer`
735     // required by SPIR-V runner.
736     // This is okay because multiple invocations are not supported yet.
737     auto storageClass = srcType.getStorageClass();
738     switch (storageClass) {
739     case spirv::StorageClass::Input:
740     case spirv::StorageClass::Private:
741     case spirv::StorageClass::Output:
742     case spirv::StorageClass::StorageBuffer:
743     case spirv::StorageClass::UniformConstant:
744       break;
745     default:
746       return failure();
747     }
748 
749     // LLVM dialect spec: "If the global value is a constant, storing into it is
750     // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
751     // storage class that is read-only.
752     bool isConstant = (storageClass == spirv::StorageClass::Input) ||
753                       (storageClass == spirv::StorageClass::UniformConstant);
754     // SPIR-V spec: "By default, functions and global variables are private to a
755     // module and cannot be accessed by other modules. However, a module may be
756     // written to export or import functions and global (module scope)
757     // variables.". Therefore, map 'Private' storage class to private linkage,
758     // 'Input' and 'Output' to external linkage.
759     auto linkage = storageClass == spirv::StorageClass::Private
760                        ? LLVM::Linkage::Private
761                        : LLVM::Linkage::External;
762     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
763         op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
764         /*alignment=*/0);
765 
766     // Attach location attribute if applicable
767     if (op.locationAttr())
768       newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
769 
770     return success();
771   }
772 };
773 
774 /// Converts SPIR-V cast ops that do not have straightforward LLVM
775 /// equivalent in LLVM dialect.
776 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
777 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
778 public:
779   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
780 
781   LogicalResult
782   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
783                   ConversionPatternRewriter &rewriter) const override {
784 
785     Type fromType = operation.operand().getType();
786     Type toType = operation.getType();
787 
788     auto dstType = this->typeConverter.convertType(toType);
789     if (!dstType)
790       return failure();
791 
792     if (getBitWidth(fromType) < getBitWidth(toType)) {
793       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
794                                                       adaptor.getOperands());
795       return success();
796     }
797     if (getBitWidth(fromType) > getBitWidth(toType)) {
798       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
799                                                         adaptor.getOperands());
800       return success();
801     }
802     return failure();
803   }
804 };
805 
806 class FunctionCallPattern
807     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
808 public:
809   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
810 
811   LogicalResult
812   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
813                   ConversionPatternRewriter &rewriter) const override {
814     if (callOp.getNumResults() == 0) {
815       rewriter.replaceOpWithNewOp<LLVM::CallOp>(
816           callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
817       return success();
818     }
819 
820     // Function returns a single result.
821     auto dstType = typeConverter.convertType(callOp.getType(0));
822     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
823         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
824     return success();
825   }
826 };
827 
828 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
829 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
830 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
831 public:
832   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
833 
834   LogicalResult
835   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
836                   ConversionPatternRewriter &rewriter) const override {
837 
838     auto dstType = this->typeConverter.convertType(operation.getType());
839     if (!dstType)
840       return failure();
841 
842     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
843         operation, dstType, predicate, operation.operand1(),
844         operation.operand2());
845     return success();
846   }
847 };
848 
849 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
850 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
851 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
852 public:
853   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
854 
855   LogicalResult
856   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
857                   ConversionPatternRewriter &rewriter) const override {
858 
859     auto dstType = this->typeConverter.convertType(operation.getType());
860     if (!dstType)
861       return failure();
862 
863     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
864         operation, dstType, predicate, operation.operand1(),
865         operation.operand2());
866     return success();
867   }
868 };
869 
870 class InverseSqrtPattern
871     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
872 public:
873   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
874 
875   LogicalResult
876   matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor,
877                   ConversionPatternRewriter &rewriter) const override {
878     auto srcType = op.getType();
879     auto dstType = typeConverter.convertType(srcType);
880     if (!dstType)
881       return failure();
882 
883     Location loc = op.getLoc();
884     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
885     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
886     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
887     return success();
888   }
889 };
890 
891 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
892 template <typename SPIRVOp>
893 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
894 public:
895   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
896 
897   LogicalResult
898   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
899                   ConversionPatternRewriter &rewriter) const override {
900     if (!op.memory_access().hasValue()) {
901       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
902                                     this->typeConverter, /*alignment=*/0,
903                                     /*isVolatile=*/false,
904                                     /*isNonTemporal=*/false);
905     }
906     auto memoryAccess = op.memory_access().getValue();
907     switch (memoryAccess) {
908     case spirv::MemoryAccess::Aligned:
909     case spirv::MemoryAccess::None:
910     case spirv::MemoryAccess::Nontemporal:
911     case spirv::MemoryAccess::Volatile: {
912       unsigned alignment =
913           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
914       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
915       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
916       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
917                                     this->typeConverter, alignment, isVolatile,
918                                     isNonTemporal);
919     }
920     default:
921       // There is no support of other memory access attributes.
922       return failure();
923     }
924   }
925 };
926 
927 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
928 template <typename SPIRVOp>
929 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
930 public:
931   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
932 
933   LogicalResult
934   matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
935                   ConversionPatternRewriter &rewriter) const override {
936     auto srcType = notOp.getType();
937     auto dstType = this->typeConverter.convertType(srcType);
938     if (!dstType)
939       return failure();
940 
941     Location loc = notOp.getLoc();
942     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
943     auto mask = srcType.template isa<VectorType>()
944                     ? rewriter.create<LLVM::ConstantOp>(
945                           loc, dstType,
946                           SplatElementsAttr::get(
947                               srcType.template cast<VectorType>(), minusOne))
948                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
949     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
950                                                       notOp.operand(), mask);
951     return success();
952   }
953 };
954 
955 /// A template pattern that erases the given `SPIRVOp`.
956 template <typename SPIRVOp>
957 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
958 public:
959   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
960 
961   LogicalResult
962   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
963                   ConversionPatternRewriter &rewriter) const override {
964     rewriter.eraseOp(op);
965     return success();
966   }
967 };
968 
969 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
970 public:
971   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
972 
973   LogicalResult
974   matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
975                   ConversionPatternRewriter &rewriter) const override {
976     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
977                                                 ArrayRef<Value>());
978     return success();
979   }
980 };
981 
982 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
983 public:
984   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
985 
986   LogicalResult
987   matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
988                   ConversionPatternRewriter &rewriter) const override {
989     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
990                                                 adaptor.getOperands());
991     return success();
992   }
993 };
994 
995 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should
996 /// be reachable for conversion to succeed. The structure of the loop in LLVM
997 /// dialect will be the following:
998 ///
999 ///      +------------------------------------+
1000 ///      | <code before spv.mlir.loop>        |
1001 ///      | llvm.br ^header                    |
1002 ///      +------------------------------------+
1003 ///                           |
1004 ///   +----------------+      |
1005 ///   |                |      |
1006 ///   |                V      V
1007 ///   |  +------------------------------------+
1008 ///   |  | ^header:                           |
1009 ///   |  |   <header code>                    |
1010 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1011 ///   |  +------------------------------------+
1012 ///   |                    |
1013 ///   |                    |----------------------+
1014 ///   |                    |                      |
1015 ///   |                    V                      |
1016 ///   |  +------------------------------------+   |
1017 ///   |  | ^body:                             |   |
1018 ///   |  |   <body code>                      |   |
1019 ///   |  |   llvm.br ^continue                |   |
1020 ///   |  +------------------------------------+   |
1021 ///   |                    |                      |
1022 ///   |                    V                      |
1023 ///   |  +------------------------------------+   |
1024 ///   |  | ^continue:                         |   |
1025 ///   |  |   <continue code>                  |   |
1026 ///   |  |   llvm.br ^header                  |   |
1027 ///   |  +------------------------------------+   |
1028 ///   |               |                           |
1029 ///   +---------------+    +----------------------+
1030 ///                        |
1031 ///                        V
1032 ///      +------------------------------------+
1033 ///      | ^exit:                             |
1034 ///      |   llvm.br ^remaining               |
1035 ///      +------------------------------------+
1036 ///                        |
1037 ///                        V
1038 ///      +------------------------------------+
1039 ///      | ^remaining:                        |
1040 ///      |   <code after spv.mlir.loop>       |
1041 ///      +------------------------------------+
1042 ///
1043 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1044 public:
1045   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1046 
1047   LogicalResult
1048   matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1049                   ConversionPatternRewriter &rewriter) const override {
1050     // There is no support of loop control at the moment.
1051     if (loopOp.loop_control() != spirv::LoopControl::None)
1052       return failure();
1053 
1054     Location loc = loopOp.getLoc();
1055 
1056     // Split the current block after `spv.mlir.loop`. The remaining ops will be
1057     // used in `endBlock`.
1058     Block *currentBlock = rewriter.getBlock();
1059     auto position = Block::iterator(loopOp);
1060     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1061 
1062     // Remove entry block and create a branch in the current block going to the
1063     // header block.
1064     Block *entryBlock = loopOp.getEntryBlock();
1065     assert(entryBlock->getOperations().size() == 1);
1066     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1067     if (!brOp)
1068       return failure();
1069     Block *headerBlock = loopOp.getHeaderBlock();
1070     rewriter.setInsertionPointToEnd(currentBlock);
1071     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1072     rewriter.eraseBlock(entryBlock);
1073 
1074     // Branch from merge block to end block.
1075     Block *mergeBlock = loopOp.getMergeBlock();
1076     Operation *terminator = mergeBlock->getTerminator();
1077     ValueRange terminatorOperands = terminator->getOperands();
1078     rewriter.setInsertionPointToEnd(mergeBlock);
1079     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1080 
1081     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1082     rewriter.replaceOp(loopOp, endBlock->getArguments());
1083     return success();
1084   }
1085 };
1086 
1087 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header
1088 /// block. All blocks within selection should be reachable for conversion to
1089 /// succeed.
1090 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1091 public:
1092   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1093 
1094   LogicalResult
1095   matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1096                   ConversionPatternRewriter &rewriter) const override {
1097     // There is no support for `Flatten` or `DontFlatten` selection control at
1098     // the moment. This are just compiler hints and can be performed during the
1099     // optimization passes.
1100     if (op.selection_control() != spirv::SelectionControl::None)
1101       return failure();
1102 
1103     // `spv.mlir.selection` should have at least two blocks: one selection
1104     // header block and one merge block. If no blocks are present, or control
1105     // flow branches straight to merge block (two blocks are present), the op is
1106     // redundant and it is erased.
1107     if (op.body().getBlocks().size() <= 2) {
1108       rewriter.eraseOp(op);
1109       return success();
1110     }
1111 
1112     Location loc = op.getLoc();
1113 
1114     // Split the current block after `spv.mlir.selection`. The remaining ops
1115     // will be used in `continueBlock`.
1116     auto *currentBlock = rewriter.getInsertionBlock();
1117     rewriter.setInsertionPointAfter(op);
1118     auto position = rewriter.getInsertionPoint();
1119     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1120 
1121     // Extract conditional branch information from the header block. By SPIR-V
1122     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1123     // op. Note that `spv.Switch op` is not supported at the moment in the
1124     // SPIR-V dialect. Remove this block when finished.
1125     auto *headerBlock = op.getHeaderBlock();
1126     assert(headerBlock->getOperations().size() == 1);
1127     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1128         headerBlock->getOperations().front());
1129     if (!condBrOp)
1130       return failure();
1131     rewriter.eraseBlock(headerBlock);
1132 
1133     // Branch from merge block to continue block.
1134     auto *mergeBlock = op.getMergeBlock();
1135     Operation *terminator = mergeBlock->getTerminator();
1136     ValueRange terminatorOperands = terminator->getOperands();
1137     rewriter.setInsertionPointToEnd(mergeBlock);
1138     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1139 
1140     // Link current block to `true` and `false` blocks within the selection.
1141     Block *trueBlock = condBrOp.getTrueBlock();
1142     Block *falseBlock = condBrOp.getFalseBlock();
1143     rewriter.setInsertionPointToEnd(currentBlock);
1144     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1145                                     condBrOp.trueTargetOperands(), falseBlock,
1146                                     condBrOp.falseTargetOperands());
1147 
1148     rewriter.inlineRegionBefore(op.body(), continueBlock);
1149     rewriter.replaceOp(op, continueBlock->getArguments());
1150     return success();
1151   }
1152 };
1153 
1154 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1155 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1156 /// `Shift` is zero or sign extended to match this specification. Cases when
1157 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1158 template <typename SPIRVOp, typename LLVMOp>
1159 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1160 public:
1161   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1162 
1163   LogicalResult
1164   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1165                   ConversionPatternRewriter &rewriter) const override {
1166 
1167     auto dstType = this->typeConverter.convertType(operation.getType());
1168     if (!dstType)
1169       return failure();
1170 
1171     Type op1Type = operation.operand1().getType();
1172     Type op2Type = operation.operand2().getType();
1173 
1174     if (op1Type == op2Type) {
1175       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1176                                                    adaptor.getOperands());
1177       return success();
1178     }
1179 
1180     Location loc = operation.getLoc();
1181     Value extended;
1182     if (isUnsignedIntegerOrVector(op2Type)) {
1183       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1184                                                         adaptor.operand2());
1185     } else {
1186       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1187                                                         adaptor.operand2());
1188     }
1189     Value result = rewriter.template create<LLVMOp>(
1190         loc, dstType, adaptor.operand1(), extended);
1191     rewriter.replaceOp(operation, result);
1192     return success();
1193   }
1194 };
1195 
1196 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1197 public:
1198   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1199 
1200   LogicalResult
1201   matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor,
1202                   ConversionPatternRewriter &rewriter) const override {
1203     auto dstType = typeConverter.convertType(tanOp.getType());
1204     if (!dstType)
1205       return failure();
1206 
1207     Location loc = tanOp.getLoc();
1208     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1209     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1210     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1211     return success();
1212   }
1213 };
1214 
1215 /// Convert `spv.Tanh` to
1216 ///
1217 ///   exp(2x) - 1
1218 ///   -----------
1219 ///   exp(2x) + 1
1220 ///
1221 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1222 public:
1223   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1224 
1225   LogicalResult
1226   matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor,
1227                   ConversionPatternRewriter &rewriter) const override {
1228     auto srcType = tanhOp.getType();
1229     auto dstType = typeConverter.convertType(srcType);
1230     if (!dstType)
1231       return failure();
1232 
1233     Location loc = tanhOp.getLoc();
1234     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1235     Value multiplied =
1236         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1237     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1238     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1239     Value numerator =
1240         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1241     Value denominator =
1242         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1243     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1244                                               denominator);
1245     return success();
1246   }
1247 };
1248 
1249 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1250 public:
1251   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1252 
1253   LogicalResult
1254   matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1255                   ConversionPatternRewriter &rewriter) const override {
1256     auto srcType = varOp.getType();
1257     // Initialization is supported for scalars and vectors only.
1258     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1259     auto init = varOp.initializer();
1260     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1261       return failure();
1262 
1263     auto dstType = typeConverter.convertType(srcType);
1264     if (!dstType)
1265       return failure();
1266 
1267     Location loc = varOp.getLoc();
1268     Value size = createI32ConstantOf(loc, rewriter, 1);
1269     if (!init) {
1270       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1271       return success();
1272     }
1273     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1274     rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
1275     rewriter.replaceOp(varOp, allocated);
1276     return success();
1277   }
1278 };
1279 
1280 //===----------------------------------------------------------------------===//
1281 // FuncOp conversion
1282 //===----------------------------------------------------------------------===//
1283 
1284 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1285 public:
1286   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1287 
1288   LogicalResult
1289   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1290                   ConversionPatternRewriter &rewriter) const override {
1291 
1292     // Convert function signature. At the moment LLVMType converter is enough
1293     // for currently supported types.
1294     auto funcType = funcOp.getFunctionType();
1295     TypeConverter::SignatureConversion signatureConverter(
1296         funcType.getNumInputs());
1297     auto llvmType = typeConverter.convertFunctionSignature(
1298         funcType, /*isVariadic=*/false, signatureConverter);
1299     if (!llvmType)
1300       return failure();
1301 
1302     // Create a new `LLVMFuncOp`
1303     Location loc = funcOp.getLoc();
1304     StringRef name = funcOp.getName();
1305     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1306 
1307     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1308     MLIRContext *context = funcOp.getContext();
1309     switch (funcOp.function_control()) {
1310 #define DISPATCH(functionControl, llvmAttr)                                    \
1311   case functionControl:                                                        \
1312     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1313     break;
1314 
1315       DISPATCH(spirv::FunctionControl::Inline,
1316                StringAttr::get(context, "alwaysinline"));
1317       DISPATCH(spirv::FunctionControl::DontInline,
1318                StringAttr::get(context, "noinline"));
1319       DISPATCH(spirv::FunctionControl::Pure,
1320                StringAttr::get(context, "readonly"));
1321       DISPATCH(spirv::FunctionControl::Const,
1322                StringAttr::get(context, "readnone"));
1323 
1324 #undef DISPATCH
1325 
1326     // Default: if `spirv::FunctionControl::None`, then no attributes are
1327     // needed.
1328     default:
1329       break;
1330     }
1331 
1332     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1333                                 newFuncOp.end());
1334     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1335                                            &signatureConverter))) {
1336       return failure();
1337     }
1338     rewriter.eraseOp(funcOp);
1339     return success();
1340   }
1341 };
1342 
1343 //===----------------------------------------------------------------------===//
1344 // ModuleOp conversion
1345 //===----------------------------------------------------------------------===//
1346 
1347 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1348 public:
1349   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1350 
1351   LogicalResult
1352   matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1353                   ConversionPatternRewriter &rewriter) const override {
1354 
1355     auto newModuleOp =
1356         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1357     rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1358 
1359     // Remove the terminator block that was automatically added by builder
1360     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1361     rewriter.eraseOp(spvModuleOp);
1362     return success();
1363   }
1364 };
1365 
1366 //===----------------------------------------------------------------------===//
1367 // VectorShuffleOp conversion
1368 //===----------------------------------------------------------------------===//
1369 
1370 class VectorShufflePattern
1371     : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1372 public:
1373   using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1374   LogicalResult
1375   matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1376                   ConversionPatternRewriter &rewriter) const override {
1377     Location loc = op.getLoc();
1378     auto components = adaptor.components();
1379     auto vector1 = adaptor.vector1();
1380     auto vector2 = adaptor.vector2();
1381     int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1382     int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1383     if (vector1Size == vector2Size) {
1384       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
1385                                                          components);
1386       return success();
1387     }
1388 
1389     auto dstType = typeConverter.convertType(op.getType());
1390     auto scalarType = dstType.cast<VectorType>().getElementType();
1391     auto componentsArray = components.getValue();
1392     auto *context = rewriter.getContext();
1393     auto llvmI32Type = IntegerType::get(context, 32);
1394     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1395     for (unsigned i = 0; i < componentsArray.size(); i++) {
1396       if (componentsArray[i].isa<IntegerAttr>())
1397         op.emitError("unable to support non-constant component");
1398 
1399       int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1400       if (indexVal == -1)
1401         continue;
1402 
1403       int offsetVal = 0;
1404       Value baseVector = vector1;
1405       if (indexVal >= vector1Size) {
1406         offsetVal = vector1Size;
1407         baseVector = vector2;
1408       }
1409 
1410       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1411           loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1412       Value index = rewriter.create<LLVM::ConstantOp>(
1413           loc, llvmI32Type,
1414           rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1415 
1416       auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1417           loc, scalarType, baseVector, index);
1418       targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1419                                                         extractOp, dstIndex);
1420     }
1421     rewriter.replaceOp(op, targetOp);
1422     return success();
1423   }
1424 };
1425 } // namespace
1426 
1427 //===----------------------------------------------------------------------===//
1428 // Pattern population
1429 //===----------------------------------------------------------------------===//
1430 
1431 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1432   typeConverter.addConversion([&](spirv::ArrayType type) {
1433     return convertArrayType(type, typeConverter);
1434   });
1435   typeConverter.addConversion([&](spirv::PointerType type) {
1436     return convertPointerType(type, typeConverter);
1437   });
1438   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1439     return convertRuntimeArrayType(type, typeConverter);
1440   });
1441   typeConverter.addConversion([&](spirv::StructType type) {
1442     return convertStructType(type, typeConverter);
1443   });
1444 }
1445 
1446 void mlir::populateSPIRVToLLVMConversionPatterns(
1447     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1448   patterns.add<
1449       // Arithmetic ops
1450       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1451       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1452       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1453       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1454       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1455       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1456       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1457       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1458       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1459       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1460       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1461       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1462       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1463 
1464       // Bitwise ops
1465       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1466       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1467       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1468       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1469       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1470       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1471       NotPattern<spirv::NotOp>,
1472 
1473       // Cast ops
1474       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1475       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1476       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1477       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1478       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1479       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1480       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1481       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1482 
1483       // Comparison ops
1484       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1485       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1486       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1487       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1488       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1489       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1490       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1491       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1492       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1493       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1494       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1495                       LLVM::FCmpPredicate::uge>,
1496       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1497       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1498       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1499       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1500       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1501       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1502       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1503       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1504       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1505       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1506       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1507 
1508       // Constant op
1509       ConstantScalarAndVectorPattern,
1510 
1511       // Control Flow ops
1512       BranchConversionPattern, BranchConditionalConversionPattern,
1513       FunctionCallPattern, LoopPattern, SelectionPattern,
1514       ErasePattern<spirv::MergeOp>,
1515 
1516       // Entry points and execution mode are handled separately.
1517       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1518 
1519       // GLSL extended instruction set ops
1520       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1521       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1522       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1523       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1524       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1525       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1526       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1527       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1528       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1529       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1530       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1531       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1532       InverseSqrtPattern, TanPattern, TanhPattern,
1533 
1534       // Logical ops
1535       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1536       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1537       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1538       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1539       NotPattern<spirv::LogicalNotOp>,
1540 
1541       // Memory ops
1542       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1543       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1544       VariablePattern,
1545 
1546       // Miscellaneous ops
1547       CompositeExtractPattern, CompositeInsertPattern,
1548       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1549       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1550       VectorShufflePattern,
1551 
1552       // Shift ops
1553       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1554       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1555       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1556 
1557       // Return ops
1558       ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1559 }
1560 
1561 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1562     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1563   patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1564 }
1565 
1566 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1567     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1568   patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // Pre-conversion hooks
1573 //===----------------------------------------------------------------------===//
1574 
1575 /// Hook for descriptor set and binding number encoding.
1576 static constexpr StringRef kBinding = "binding";
1577 static constexpr StringRef kDescriptorSet = "descriptor_set";
1578 void mlir::encodeBindAttribute(ModuleOp module) {
1579   auto spvModules = module.getOps<spirv::ModuleOp>();
1580   for (auto spvModule : spvModules) {
1581     spvModule.walk([&](spirv::GlobalVariableOp op) {
1582       IntegerAttr descriptorSet =
1583           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1584       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1585       // For every global variable in the module, get the ones with descriptor
1586       // set and binding numbers.
1587       if (descriptorSet && binding) {
1588         // Encode these numbers into the variable's symbolic name. If the
1589         // SPIR-V module has a name, add it at the beginning.
1590         auto moduleAndName = spvModule.getName().hasValue()
1591                                  ? spvModule.getName().getValue().str() + "_" +
1592                                        op.sym_name().str()
1593                                  : op.sym_name().str();
1594         std::string name =
1595             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1596                           std::to_string(descriptorSet.getInt()),
1597                           std::to_string(binding.getInt()));
1598         auto nameAttr = StringAttr::get(op->getContext(), name);
1599 
1600         // Replace all symbol uses and set the new symbol name. Finally, remove
1601         // descriptor set and binding attributes.
1602         if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1603           op.emitError("unable to replace all symbol uses for ") << name;
1604         SymbolTable::setSymbolName(op, nameAttr);
1605         op->removeAttr(kDescriptorSet);
1606         op->removeAttr(kBinding);
1607       }
1608     });
1609   }
1610 }
1611