1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "spirv-to-llvm-pattern"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
35 /// Returns true if the given type is a signed integer or vector type.
isSignedIntegerOrVector(Type type)36 static bool isSignedIntegerOrVector(Type type) {
37   if (type.isSignedInteger())
38     return true;
39   if (auto vecType = type.dyn_cast<VectorType>())
40     return vecType.getElementType().isSignedInteger();
41   return false;
42 }
43 
44 /// Returns true if the given type is an unsigned integer or vector type
isUnsignedIntegerOrVector(Type type)45 static bool isUnsignedIntegerOrVector(Type type) {
46   if (type.isUnsignedInteger())
47     return true;
48   if (auto vecType = type.dyn_cast<VectorType>())
49     return vecType.getElementType().isUnsignedInteger();
50   return false;
51 }
52 
53 /// Returns the bit width of integer, float or vector of float or integer values
getBitWidth(Type type)54 static unsigned getBitWidth(Type type) {
55   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
56          "bitwidth is not supported for this type");
57   if (type.isIntOrFloat())
58     return type.getIntOrFloatBitWidth();
59   auto vecType = type.dyn_cast<VectorType>();
60   auto elementType = vecType.getElementType();
61   assert(elementType.isIntOrFloat() &&
62          "only integers and floats have a bitwidth");
63   return elementType.getIntOrFloatBitWidth();
64 }
65 
66 /// Returns the bit width of LLVMType integer or vector.
getLLVMTypeBitWidth(Type type)67 static unsigned getLLVMTypeBitWidth(Type type) {
68   return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
69                                              : type)
70       .cast<IntegerType>()
71       .getWidth();
72 }
73 
74 /// Creates `IntegerAttribute` with all bits set for given type
minusOneIntegerAttribute(Type type,Builder builder)75 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
76   if (auto vecType = type.dyn_cast<VectorType>()) {
77     auto integerType = vecType.getElementType().cast<IntegerType>();
78     return builder.getIntegerAttr(integerType, -1);
79   }
80   auto integerType = type.cast<IntegerType>();
81   return builder.getIntegerAttr(integerType, -1);
82 }
83 
84 /// Creates `llvm.mlir.constant` with all bits set for the given type.
createConstantAllBitsSet(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter)85 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
86                                       PatternRewriter &rewriter) {
87   if (srcType.isa<VectorType>()) {
88     return rewriter.create<LLVM::ConstantOp>(
89         loc, dstType,
90         SplatElementsAttr::get(srcType.cast<ShapedType>(),
91                                minusOneIntegerAttribute(srcType, rewriter)));
92   }
93   return rewriter.create<LLVM::ConstantOp>(
94       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
95 }
96 
97 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
createFPConstant(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter,double value)98 static Value createFPConstant(Location loc, Type srcType, Type dstType,
99                               PatternRewriter &rewriter, double value) {
100   if (auto vecType = srcType.dyn_cast<VectorType>()) {
101     auto floatType = vecType.getElementType().cast<FloatType>();
102     return rewriter.create<LLVM::ConstantOp>(
103         loc, dstType,
104         SplatElementsAttr::get(vecType,
105                                rewriter.getFloatAttr(floatType, value)));
106   }
107   auto floatType = srcType.cast<FloatType>();
108   return rewriter.create<LLVM::ConstantOp>(
109       loc, dstType, rewriter.getFloatAttr(floatType, value));
110 }
111 
112 /// Utility function for bitfield ops:
113 ///   - `BitFieldInsert`
114 ///   - `BitFieldSExtract`
115 ///   - `BitFieldUExtract`
116 /// Truncates or extends the value. If the bitwidth of the value is the same as
117 /// `llvmType` bitwidth, the value remains unchanged.
optionallyTruncateOrExtend(Location loc,Value value,Type llvmType,PatternRewriter & rewriter)118 static Value optionallyTruncateOrExtend(Location loc, Value value,
119                                         Type llvmType,
120                                         PatternRewriter &rewriter) {
121   auto srcType = value.getType();
122   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
123   unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
124                                ? getLLVMTypeBitWidth(srcType)
125                                : getBitWidth(srcType);
126 
127   if (valueBitWidth < targetBitWidth)
128     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129   // If the bit widths of `Count` and `Offset` are greater than the bit width
130   // of the target type, they are truncated. Truncation is safe since `Count`
131   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132   // both values can be expressed in 8 bits.
133   if (valueBitWidth > targetBitWidth)
134     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135   return value;
136 }
137 
138 /// Broadcasts the value to vector with `numElements` number of elements.
broadcast(Location loc,Value toBroadcast,unsigned numElements,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140                        LLVMTypeConverter &typeConverter,
141                        ConversionPatternRewriter &rewriter) {
142   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143   auto llvmVectorType = typeConverter.convertType(vectorType);
144   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146   for (unsigned i = 0; i < numElements; ++i) {
147     auto index = rewriter.create<LLVM::ConstantOp>(
148         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149     broadcasted = rewriter.create<LLVM::InsertElementOp>(
150         loc, llvmVectorType, broadcasted, toBroadcast, index);
151   }
152   return broadcasted;
153 }
154 
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
optionallyBroadcast(Location loc,Value value,Type srcType,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)156 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
157                                  LLVMTypeConverter &typeConverter,
158                                  ConversionPatternRewriter &rewriter) {
159   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160     unsigned numElements = vectorType.getNumElements();
161     return broadcast(loc, value, numElements, typeConverter, rewriter);
162   }
163   return value;
164 }
165 
166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 ///  - same number of elements as `Base`
171 ///  - each element has the type that is the same as the type of `Offset` or
172 ///    `Count`
173 ///  - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
processCountOrOffset(Location loc,Value value,Type srcType,Type dstType,LLVMTypeConverter & converter,ConversionPatternRewriter & rewriter)176 static Value processCountOrOffset(Location loc, Value value, Type srcType,
177                                   Type dstType, LLVMTypeConverter &converter,
178                                   ConversionPatternRewriter &rewriter) {
179   Value broadcasted =
180       optionallyBroadcast(loc, value, srcType, converter, rewriter);
181   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183 
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
convertStructTypeWithOffset(spirv::StructType type,LLVMTypeConverter & converter)187 convertStructTypeWithOffset(spirv::StructType type,
188                             LLVMTypeConverter &converter) {
189   if (type != VulkanLayoutUtils::decorateType(type))
190     return llvm::None;
191 
192   auto elementsVector = llvm::to_vector<8>(
193       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194         return converter.convertType(elementType);
195       }));
196   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
197                                           /*isPacked=*/false);
198 }
199 
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
convertStructTypePacked(spirv::StructType type,LLVMTypeConverter & converter)201 static Type convertStructTypePacked(spirv::StructType type,
202                                     LLVMTypeConverter &converter) {
203   auto elementsVector = llvm::to_vector<8>(
204       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205         return converter.convertType(elementType);
206       }));
207   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
208                                           /*isPacked=*/true);
209 }
210 
211 /// Creates LLVM dialect constant with the given value.
createI32ConstantOf(Location loc,PatternRewriter & rewriter,unsigned value)212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
213                                  unsigned value) {
214   return rewriter.create<LLVM::ConstantOp>(
215       loc, IntegerType::get(rewriter.getContext(), 32),
216       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218 
219 /// Utility for `spv.Load` and `spv.Store` conversion.
replaceWithLoadOrStore(Operation * op,ValueRange operands,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,unsigned alignment,bool isVolatile,bool isNonTemporal)220 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
221                                             ConversionPatternRewriter &rewriter,
222                                             LLVMTypeConverter &typeConverter,
223                                             unsigned alignment, bool isVolatile,
224                                             bool isNonTemporal) {
225   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226     auto dstType = typeConverter.convertType(loadOp.getType());
227     if (!dstType)
228       return failure();
229     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230         loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
231         isVolatile, isNonTemporal);
232     return success();
233   }
234   auto storeOp = cast<spirv::StoreOp>(op);
235   spirv::StoreOpAdaptor adaptor(operands);
236   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
237                                              adaptor.ptr(), alignment,
238                                              isVolatile, isNonTemporal);
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Type conversion
244 //===----------------------------------------------------------------------===//
245 
246 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
247 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
248 /// when converting ops that manipulate array types.
convertArrayType(spirv::ArrayType type,TypeConverter & converter)249 static Optional<Type> convertArrayType(spirv::ArrayType type,
250                                        TypeConverter &converter) {
251   unsigned stride = type.getArrayStride();
252   Type elementType = type.getElementType();
253   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
254   if (stride != 0 && !(sizeInBytes && *sizeInBytes == stride))
255     return llvm::None;
256 
257   auto llvmElementType = converter.convertType(elementType);
258   unsigned numElements = type.getNumElements();
259   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
260 }
261 
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
convertPointerType(spirv::PointerType type,TypeConverter & converter)264 static Type convertPointerType(spirv::PointerType type,
265                                TypeConverter &converter) {
266   auto pointeeType = converter.convertType(type.getPointeeType());
267   return LLVM::LLVMPointerType::get(pointeeType);
268 }
269 
270 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
271 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
272 /// no modelling of array stride at the moment.
convertRuntimeArrayType(spirv::RuntimeArrayType type,TypeConverter & converter)273 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
274                                               TypeConverter &converter) {
275   if (type.getArrayStride() != 0)
276     return llvm::None;
277   auto elementType = converter.convertType(type.getElementType());
278   return LLVM::LLVMArrayType::get(elementType, 0);
279 }
280 
281 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
282 /// member decorations. Also, only natural offset is supported.
convertStructType(spirv::StructType type,LLVMTypeConverter & converter)283 static Optional<Type> convertStructType(spirv::StructType type,
284                                         LLVMTypeConverter &converter) {
285   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
286   type.getMemberDecorations(memberDecorations);
287   if (!memberDecorations.empty())
288     return llvm::None;
289   if (type.hasOffset())
290     return convertStructTypeWithOffset(type, converter);
291   return convertStructTypePacked(type, converter);
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Operation conversion
296 //===----------------------------------------------------------------------===//
297 
298 namespace {
299 
300 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
301 public:
302   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
303 
304   LogicalResult
matchAndRewrite(spirv::AccessChainOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const305   matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
306                   ConversionPatternRewriter &rewriter) const override {
307     auto dstType = typeConverter.convertType(op.component_ptr().getType());
308     if (!dstType)
309       return failure();
310     // To use GEP we need to add a first 0 index to go through the pointer.
311     auto indices = llvm::to_vector<4>(adaptor.indices());
312     Type indexType = op.indices().front().getType();
313     auto llvmIndexType = typeConverter.convertType(indexType);
314     if (!llvmIndexType)
315       return failure();
316     Value zero = rewriter.create<LLVM::ConstantOp>(
317         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
318     indices.insert(indices.begin(), zero);
319     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
320                                              indices);
321     return success();
322   }
323 };
324 
325 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
326 public:
327   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
328 
329   LogicalResult
matchAndRewrite(spirv::AddressOfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const330   matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto dstType = typeConverter.convertType(op.pointer().getType());
333     if (!dstType)
334       return failure();
335     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
336     return success();
337   }
338 };
339 
340 class BitFieldInsertPattern
341     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
342 public:
343   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
344 
345   LogicalResult
matchAndRewrite(spirv::BitFieldInsertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const346   matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
347                   ConversionPatternRewriter &rewriter) const override {
348     auto srcType = op.getType();
349     auto dstType = typeConverter.convertType(srcType);
350     if (!dstType)
351       return failure();
352     Location loc = op.getLoc();
353 
354     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
355     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
356                                         typeConverter, rewriter);
357     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
358                                        typeConverter, rewriter);
359 
360     // Create a mask with bits set outside [Offset, Offset + Count - 1].
361     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
362     Value maskShiftedByCount =
363         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
364     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
365                                                  maskShiftedByCount, minusOne);
366     Value maskShiftedByCountAndOffset =
367         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
368     Value mask = rewriter.create<LLVM::XOrOp>(
369         loc, dstType, maskShiftedByCountAndOffset, minusOne);
370 
371     // Extract unchanged bits from the `Base`  that are outside of
372     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
373     Value baseAndMask =
374         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
375     Value insertShiftedByOffset =
376         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
377     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
378                                             insertShiftedByOffset);
379     return success();
380   }
381 };
382 
383 /// Converts SPIR-V ConstantOp with scalar or vector type.
384 class ConstantScalarAndVectorPattern
385     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
386 public:
387   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
388 
389   LogicalResult
matchAndRewrite(spirv::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const390   matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
391                   ConversionPatternRewriter &rewriter) const override {
392     auto srcType = constOp.getType();
393     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
394       return failure();
395 
396     auto dstType = typeConverter.convertType(srcType);
397     if (!dstType)
398       return failure();
399 
400     // SPIR-V constant can be a signed/unsigned integer, which has to be
401     // casted to signless integer when converting to LLVM dialect. Removing the
402     // sign bit may have unexpected behaviour. However, it is better to handle
403     // it case-by-case, given that the purpose of the conversion is not to
404     // cover all possible corner cases.
405     if (isSignedIntegerOrVector(srcType) ||
406         isUnsignedIntegerOrVector(srcType)) {
407       auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
408 
409       if (srcType.isa<VectorType>()) {
410         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
411         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
412             constOp, dstType,
413             dstElementsAttr.mapValues(
414                 signlessType, [&](const APInt &value) { return value; }));
415         return success();
416       }
417       auto srcAttr = constOp.value().cast<IntegerAttr>();
418       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
419       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
420       return success();
421     }
422     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
423         constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
424     return success();
425   }
426 };
427 
428 class BitFieldSExtractPattern
429     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
430 public:
431   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
432 
433   LogicalResult
matchAndRewrite(spirv::BitFieldSExtractOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const434   matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
435                   ConversionPatternRewriter &rewriter) const override {
436     auto srcType = op.getType();
437     auto dstType = typeConverter.convertType(srcType);
438     if (!dstType)
439       return failure();
440     Location loc = op.getLoc();
441 
442     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
443     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
444                                         typeConverter, rewriter);
445     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
446                                        typeConverter, rewriter);
447 
448     // Create a constant that holds the size of the `Base`.
449     IntegerType integerType;
450     if (auto vecType = srcType.dyn_cast<VectorType>())
451       integerType = vecType.getElementType().cast<IntegerType>();
452     else
453       integerType = srcType.cast<IntegerType>();
454 
455     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
456     Value size =
457         srcType.isa<VectorType>()
458             ? rewriter.create<LLVM::ConstantOp>(
459                   loc, dstType,
460                   SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
461             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
462 
463     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
464     // at Offset + Count - 1 is the most significant bit now.
465     Value countPlusOffset =
466         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
467     Value amountToShiftLeft =
468         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
469     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
470         loc, dstType, op.base(), amountToShiftLeft);
471 
472     // Shift the result right, filling the bits with the sign bit.
473     Value amountToShiftRight =
474         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
475     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
476                                               amountToShiftRight);
477     return success();
478   }
479 };
480 
481 class BitFieldUExtractPattern
482     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
483 public:
484   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
485 
486   LogicalResult
matchAndRewrite(spirv::BitFieldUExtractOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const487   matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
488                   ConversionPatternRewriter &rewriter) const override {
489     auto srcType = op.getType();
490     auto dstType = typeConverter.convertType(srcType);
491     if (!dstType)
492       return failure();
493     Location loc = op.getLoc();
494 
495     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
496     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
497                                         typeConverter, rewriter);
498     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
499                                        typeConverter, rewriter);
500 
501     // Create a mask with bits set at [0, Count - 1].
502     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
503     Value maskShiftedByCount =
504         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
505     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
506                                               minusOne);
507 
508     // Shift `Base` by `Offset` and apply the mask on it.
509     Value shiftedBase =
510         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
511     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
512     return success();
513   }
514 };
515 
516 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
517 public:
518   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
519 
520   LogicalResult
matchAndRewrite(spirv::BranchOp branchOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const521   matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
522                   ConversionPatternRewriter &rewriter) const override {
523     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
524                                             branchOp.getTarget());
525     return success();
526   }
527 };
528 
529 class BranchConditionalConversionPattern
530     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
531 public:
532   using SPIRVToLLVMConversion<
533       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
534 
535   LogicalResult
matchAndRewrite(spirv::BranchConditionalOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const536   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
537                   ConversionPatternRewriter &rewriter) const override {
538     // If branch weights exist, map them to 32-bit integer vector.
539     ElementsAttr branchWeights = nullptr;
540     if (auto weights = op.branch_weights()) {
541       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
542       branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
543     }
544 
545     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
546         op, op.condition(), op.getTrueBlockArguments(),
547         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
548         op.getFalseBlock());
549     return success();
550   }
551 };
552 
553 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
554 /// is an aggregate type (struct or array). Otherwise, converts to
555 /// `llvm.extractelement` that operates on vectors.
556 class CompositeExtractPattern
557     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
558 public:
559   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
560 
561   LogicalResult
matchAndRewrite(spirv::CompositeExtractOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const562   matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
563                   ConversionPatternRewriter &rewriter) const override {
564     auto dstType = this->typeConverter.convertType(op.getType());
565     if (!dstType)
566       return failure();
567 
568     Type containerType = op.composite().getType();
569     if (containerType.isa<VectorType>()) {
570       Location loc = op.getLoc();
571       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
572       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
573       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
574           op, dstType, adaptor.composite(), index);
575       return success();
576     }
577     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
578         op, dstType, adaptor.composite(), op.indices());
579     return success();
580   }
581 };
582 
583 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
584 /// is an aggregate type (struct or array). Otherwise, converts to
585 /// `llvm.insertelement` that operates on vectors.
586 class CompositeInsertPattern
587     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
588 public:
589   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
590 
591   LogicalResult
matchAndRewrite(spirv::CompositeInsertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const592   matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
593                   ConversionPatternRewriter &rewriter) const override {
594     auto dstType = this->typeConverter.convertType(op.getType());
595     if (!dstType)
596       return failure();
597 
598     Type containerType = op.composite().getType();
599     if (containerType.isa<VectorType>()) {
600       Location loc = op.getLoc();
601       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
602       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
603       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
604           op, dstType, adaptor.composite(), adaptor.object(), index);
605       return success();
606     }
607     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
608         op, dstType, adaptor.composite(), adaptor.object(), op.indices());
609     return success();
610   }
611 };
612 
613 /// Converts SPIR-V operations that have straightforward LLVM equivalent
614 /// into LLVM dialect operations.
615 template <typename SPIRVOp, typename LLVMOp>
616 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
617 public:
618   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
619 
620   LogicalResult
matchAndRewrite(SPIRVOp operation,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const621   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
622                   ConversionPatternRewriter &rewriter) const override {
623     auto dstType = this->typeConverter.convertType(operation.getType());
624     if (!dstType)
625       return failure();
626     rewriter.template replaceOpWithNewOp<LLVMOp>(
627         operation, dstType, adaptor.getOperands(), operation->getAttrs());
628     return success();
629   }
630 };
631 
632 /// Converts `spv.ExecutionMode` into a global struct constant that holds
633 /// execution mode information.
634 class ExecutionModePattern
635     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
636 public:
637   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
638 
639   LogicalResult
matchAndRewrite(spirv::ExecutionModeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const640   matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
641                   ConversionPatternRewriter &rewriter) const override {
642     // First, create the global struct's name that would be associated with
643     // this entry point's execution mode. We set it to be:
644     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
645     ModuleOp module = op->getParentOfType<ModuleOp>();
646     IntegerAttr executionModeAttr = op.execution_modeAttr();
647     std::string moduleName;
648     if (module.getName().has_value())
649       moduleName = "_" + module.getName().value().str();
650     else
651       moduleName = "";
652     std::string executionModeInfoName =
653         llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
654                       op.fn().str(), executionModeAttr.getValue());
655 
656     MLIRContext *context = rewriter.getContext();
657     OpBuilder::InsertionGuard guard(rewriter);
658     rewriter.setInsertionPointToStart(module.getBody());
659 
660     // Create a struct type, corresponding to the C struct below.
661     // struct {
662     //   int32_t executionMode;
663     //   int32_t values[];          // optional values
664     // };
665     auto llvmI32Type = IntegerType::get(context, 32);
666     SmallVector<Type, 2> fields;
667     fields.push_back(llvmI32Type);
668     ArrayAttr values = op.values();
669     if (!values.empty()) {
670       auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
671       fields.push_back(arrayType);
672     }
673     auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
674 
675     // Create `llvm.mlir.global` with initializer region containing one block.
676     auto global = rewriter.create<LLVM::GlobalOp>(
677         UnknownLoc::get(context), structType, /*isConstant=*/true,
678         LLVM::Linkage::External, executionModeInfoName, Attribute(),
679         /*alignment=*/0);
680     Location loc = global.getLoc();
681     Region &region = global.getInitializerRegion();
682     Block *block = rewriter.createBlock(&region);
683 
684     // Initialize the struct and set the execution mode value.
685     rewriter.setInsertionPoint(block, block->begin());
686     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
687     Value executionMode =
688         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
689     structValue = rewriter.create<LLVM::InsertValueOp>(
690         loc, structType, structValue, executionMode,
691         ArrayAttr::get(context,
692                        {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
693 
694     // Insert extra operands if they exist into execution mode info struct.
695     for (unsigned i = 0, e = values.size(); i < e; ++i) {
696       auto attr = values.getValue()[i];
697       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
698       structValue = rewriter.create<LLVM::InsertValueOp>(
699           loc, structType, structValue, entry,
700           ArrayAttr::get(context,
701                          {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
702                           rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
703     }
704     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
705     rewriter.eraseOp(op);
706     return success();
707   }
708 };
709 
710 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global
711 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
712 /// This difference is handled by `spv.mlir.addressof` and
713 /// `llvm.mlir.addressof`ops that both return a pointer.
714 class GlobalVariablePattern
715     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
716 public:
717   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
718 
719   LogicalResult
matchAndRewrite(spirv::GlobalVariableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const720   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
721                   ConversionPatternRewriter &rewriter) const override {
722     // Currently, there is no support of initialization with a constant value in
723     // SPIR-V dialect. Specialization constants are not considered as well.
724     if (op.initializer())
725       return failure();
726 
727     auto srcType = op.type().cast<spirv::PointerType>();
728     auto dstType = typeConverter.convertType(srcType.getPointeeType());
729     if (!dstType)
730       return failure();
731 
732     // Limit conversion to the current invocation only or `StorageBuffer`
733     // required by SPIR-V runner.
734     // This is okay because multiple invocations are not supported yet.
735     auto storageClass = srcType.getStorageClass();
736     switch (storageClass) {
737     case spirv::StorageClass::Input:
738     case spirv::StorageClass::Private:
739     case spirv::StorageClass::Output:
740     case spirv::StorageClass::StorageBuffer:
741     case spirv::StorageClass::UniformConstant:
742       break;
743     default:
744       return failure();
745     }
746 
747     // LLVM dialect spec: "If the global value is a constant, storing into it is
748     // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
749     // storage class that is read-only.
750     bool isConstant = (storageClass == spirv::StorageClass::Input) ||
751                       (storageClass == spirv::StorageClass::UniformConstant);
752     // SPIR-V spec: "By default, functions and global variables are private to a
753     // module and cannot be accessed by other modules. However, a module may be
754     // written to export or import functions and global (module scope)
755     // variables.". Therefore, map 'Private' storage class to private linkage,
756     // 'Input' and 'Output' to external linkage.
757     auto linkage = storageClass == spirv::StorageClass::Private
758                        ? LLVM::Linkage::Private
759                        : LLVM::Linkage::External;
760     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
761         op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
762         /*alignment=*/0);
763 
764     // Attach location attribute if applicable
765     if (op.locationAttr())
766       newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
767 
768     return success();
769   }
770 };
771 
772 /// Converts SPIR-V cast ops that do not have straightforward LLVM
773 /// equivalent in LLVM dialect.
774 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
775 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
776 public:
777   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
778 
779   LogicalResult
matchAndRewrite(SPIRVOp operation,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const780   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
781                   ConversionPatternRewriter &rewriter) const override {
782 
783     Type fromType = operation.operand().getType();
784     Type toType = operation.getType();
785 
786     auto dstType = this->typeConverter.convertType(toType);
787     if (!dstType)
788       return failure();
789 
790     if (getBitWidth(fromType) < getBitWidth(toType)) {
791       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
792                                                       adaptor.getOperands());
793       return success();
794     }
795     if (getBitWidth(fromType) > getBitWidth(toType)) {
796       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
797                                                         adaptor.getOperands());
798       return success();
799     }
800     return failure();
801   }
802 };
803 
804 class FunctionCallPattern
805     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
806 public:
807   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
808 
809   LogicalResult
matchAndRewrite(spirv::FunctionCallOp callOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const810   matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
811                   ConversionPatternRewriter &rewriter) const override {
812     if (callOp.getNumResults() == 0) {
813       rewriter.replaceOpWithNewOp<LLVM::CallOp>(
814           callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
815       return success();
816     }
817 
818     // Function returns a single result.
819     auto dstType = typeConverter.convertType(callOp.getType(0));
820     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
821         callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
822     return success();
823   }
824 };
825 
826 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
827 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
828 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
829 public:
830   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
831 
832   LogicalResult
matchAndRewrite(SPIRVOp operation,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const833   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
834                   ConversionPatternRewriter &rewriter) const override {
835 
836     auto dstType = this->typeConverter.convertType(operation.getType());
837     if (!dstType)
838       return failure();
839 
840     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
841         operation, dstType, predicate, operation.operand1(),
842         operation.operand2());
843     return success();
844   }
845 };
846 
847 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
848 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
849 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
850 public:
851   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
852 
853   LogicalResult
matchAndRewrite(SPIRVOp operation,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const854   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
855                   ConversionPatternRewriter &rewriter) const override {
856 
857     auto dstType = this->typeConverter.convertType(operation.getType());
858     if (!dstType)
859       return failure();
860 
861     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
862         operation, dstType, predicate, operation.operand1(),
863         operation.operand2());
864     return success();
865   }
866 };
867 
868 class InverseSqrtPattern
869     : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
870 public:
871   using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
872 
873   LogicalResult
matchAndRewrite(spirv::GLInverseSqrtOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const874   matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
875                   ConversionPatternRewriter &rewriter) const override {
876     auto srcType = op.getType();
877     auto dstType = typeConverter.convertType(srcType);
878     if (!dstType)
879       return failure();
880 
881     Location loc = op.getLoc();
882     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
883     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
884     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
885     return success();
886   }
887 };
888 
889 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
890 template <typename SPIRVOp>
891 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
892 public:
893   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
894 
895   LogicalResult
matchAndRewrite(SPIRVOp op,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const896   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
897                   ConversionPatternRewriter &rewriter) const override {
898     if (!op.memory_access()) {
899       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
900                                     this->typeConverter, /*alignment=*/0,
901                                     /*isVolatile=*/false,
902                                     /*isNonTemporal=*/false);
903     }
904     auto memoryAccess = *op.memory_access();
905     switch (memoryAccess) {
906     case spirv::MemoryAccess::Aligned:
907     case spirv::MemoryAccess::None:
908     case spirv::MemoryAccess::Nontemporal:
909     case spirv::MemoryAccess::Volatile: {
910       unsigned alignment =
911           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
912       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
913       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
914       return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
915                                     this->typeConverter, alignment, isVolatile,
916                                     isNonTemporal);
917     }
918     default:
919       // There is no support of other memory access attributes.
920       return failure();
921     }
922   }
923 };
924 
925 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
926 template <typename SPIRVOp>
927 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
928 public:
929   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
930 
931   LogicalResult
matchAndRewrite(SPIRVOp notOp,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const932   matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
933                   ConversionPatternRewriter &rewriter) const override {
934     auto srcType = notOp.getType();
935     auto dstType = this->typeConverter.convertType(srcType);
936     if (!dstType)
937       return failure();
938 
939     Location loc = notOp.getLoc();
940     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
941     auto mask = srcType.template isa<VectorType>()
942                     ? rewriter.create<LLVM::ConstantOp>(
943                           loc, dstType,
944                           SplatElementsAttr::get(
945                               srcType.template cast<VectorType>(), minusOne))
946                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
947     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
948                                                       notOp.operand(), mask);
949     return success();
950   }
951 };
952 
953 /// A template pattern that erases the given `SPIRVOp`.
954 template <typename SPIRVOp>
955 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
956 public:
957   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
958 
959   LogicalResult
matchAndRewrite(SPIRVOp op,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const960   matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
961                   ConversionPatternRewriter &rewriter) const override {
962     rewriter.eraseOp(op);
963     return success();
964   }
965 };
966 
967 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
968 public:
969   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
970 
971   LogicalResult
matchAndRewrite(spirv::ReturnOp returnOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const972   matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
973                   ConversionPatternRewriter &rewriter) const override {
974     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
975                                                 ArrayRef<Value>());
976     return success();
977   }
978 };
979 
980 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
981 public:
982   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
983 
984   LogicalResult
matchAndRewrite(spirv::ReturnValueOp returnValueOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const985   matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
986                   ConversionPatternRewriter &rewriter) const override {
987     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
988                                                 adaptor.getOperands());
989     return success();
990   }
991 };
992 
993 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should
994 /// be reachable for conversion to succeed. The structure of the loop in LLVM
995 /// dialect will be the following:
996 ///
997 ///      +------------------------------------+
998 ///      | <code before spv.mlir.loop>        |
999 ///      | llvm.br ^header                    |
1000 ///      +------------------------------------+
1001 ///                           |
1002 ///   +----------------+      |
1003 ///   |                |      |
1004 ///   |                V      V
1005 ///   |  +------------------------------------+
1006 ///   |  | ^header:                           |
1007 ///   |  |   <header code>                    |
1008 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1009 ///   |  +------------------------------------+
1010 ///   |                    |
1011 ///   |                    |----------------------+
1012 ///   |                    |                      |
1013 ///   |                    V                      |
1014 ///   |  +------------------------------------+   |
1015 ///   |  | ^body:                             |   |
1016 ///   |  |   <body code>                      |   |
1017 ///   |  |   llvm.br ^continue                |   |
1018 ///   |  +------------------------------------+   |
1019 ///   |                    |                      |
1020 ///   |                    V                      |
1021 ///   |  +------------------------------------+   |
1022 ///   |  | ^continue:                         |   |
1023 ///   |  |   <continue code>                  |   |
1024 ///   |  |   llvm.br ^header                  |   |
1025 ///   |  +------------------------------------+   |
1026 ///   |               |                           |
1027 ///   +---------------+    +----------------------+
1028 ///                        |
1029 ///                        V
1030 ///      +------------------------------------+
1031 ///      | ^exit:                             |
1032 ///      |   llvm.br ^remaining               |
1033 ///      +------------------------------------+
1034 ///                        |
1035 ///                        V
1036 ///      +------------------------------------+
1037 ///      | ^remaining:                        |
1038 ///      |   <code after spv.mlir.loop>       |
1039 ///      +------------------------------------+
1040 ///
1041 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1042 public:
1043   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1044 
1045   LogicalResult
matchAndRewrite(spirv::LoopOp loopOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1046   matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1047                   ConversionPatternRewriter &rewriter) const override {
1048     // There is no support of loop control at the moment.
1049     if (loopOp.loop_control() != spirv::LoopControl::None)
1050       return failure();
1051 
1052     Location loc = loopOp.getLoc();
1053 
1054     // Split the current block after `spv.mlir.loop`. The remaining ops will be
1055     // used in `endBlock`.
1056     Block *currentBlock = rewriter.getBlock();
1057     auto position = Block::iterator(loopOp);
1058     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1059 
1060     // Remove entry block and create a branch in the current block going to the
1061     // header block.
1062     Block *entryBlock = loopOp.getEntryBlock();
1063     assert(entryBlock->getOperations().size() == 1);
1064     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1065     if (!brOp)
1066       return failure();
1067     Block *headerBlock = loopOp.getHeaderBlock();
1068     rewriter.setInsertionPointToEnd(currentBlock);
1069     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1070     rewriter.eraseBlock(entryBlock);
1071 
1072     // Branch from merge block to end block.
1073     Block *mergeBlock = loopOp.getMergeBlock();
1074     Operation *terminator = mergeBlock->getTerminator();
1075     ValueRange terminatorOperands = terminator->getOperands();
1076     rewriter.setInsertionPointToEnd(mergeBlock);
1077     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1078 
1079     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1080     rewriter.replaceOp(loopOp, endBlock->getArguments());
1081     return success();
1082   }
1083 };
1084 
1085 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header
1086 /// block. All blocks within selection should be reachable for conversion to
1087 /// succeed.
1088 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1089 public:
1090   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1091 
1092   LogicalResult
matchAndRewrite(spirv::SelectionOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1093   matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1094                   ConversionPatternRewriter &rewriter) const override {
1095     // There is no support for `Flatten` or `DontFlatten` selection control at
1096     // the moment. This are just compiler hints and can be performed during the
1097     // optimization passes.
1098     if (op.selection_control() != spirv::SelectionControl::None)
1099       return failure();
1100 
1101     // `spv.mlir.selection` should have at least two blocks: one selection
1102     // header block and one merge block. If no blocks are present, or control
1103     // flow branches straight to merge block (two blocks are present), the op is
1104     // redundant and it is erased.
1105     if (op.body().getBlocks().size() <= 2) {
1106       rewriter.eraseOp(op);
1107       return success();
1108     }
1109 
1110     Location loc = op.getLoc();
1111 
1112     // Split the current block after `spv.mlir.selection`. The remaining ops
1113     // will be used in `continueBlock`.
1114     auto *currentBlock = rewriter.getInsertionBlock();
1115     rewriter.setInsertionPointAfter(op);
1116     auto position = rewriter.getInsertionPoint();
1117     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1118 
1119     // Extract conditional branch information from the header block. By SPIR-V
1120     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1121     // op. Note that `spv.Switch op` is not supported at the moment in the
1122     // SPIR-V dialect. Remove this block when finished.
1123     auto *headerBlock = op.getHeaderBlock();
1124     assert(headerBlock->getOperations().size() == 1);
1125     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1126         headerBlock->getOperations().front());
1127     if (!condBrOp)
1128       return failure();
1129     rewriter.eraseBlock(headerBlock);
1130 
1131     // Branch from merge block to continue block.
1132     auto *mergeBlock = op.getMergeBlock();
1133     Operation *terminator = mergeBlock->getTerminator();
1134     ValueRange terminatorOperands = terminator->getOperands();
1135     rewriter.setInsertionPointToEnd(mergeBlock);
1136     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1137 
1138     // Link current block to `true` and `false` blocks within the selection.
1139     Block *trueBlock = condBrOp.getTrueBlock();
1140     Block *falseBlock = condBrOp.getFalseBlock();
1141     rewriter.setInsertionPointToEnd(currentBlock);
1142     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1143                                     condBrOp.trueTargetOperands(), falseBlock,
1144                                     condBrOp.falseTargetOperands());
1145 
1146     rewriter.inlineRegionBefore(op.body(), continueBlock);
1147     rewriter.replaceOp(op, continueBlock->getArguments());
1148     return success();
1149   }
1150 };
1151 
1152 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1153 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1154 /// `Shift` is zero or sign extended to match this specification. Cases when
1155 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1156 template <typename SPIRVOp, typename LLVMOp>
1157 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1158 public:
1159   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1160 
1161   LogicalResult
matchAndRewrite(SPIRVOp operation,typename SPIRVOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1162   matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1163                   ConversionPatternRewriter &rewriter) const override {
1164 
1165     auto dstType = this->typeConverter.convertType(operation.getType());
1166     if (!dstType)
1167       return failure();
1168 
1169     Type op1Type = operation.operand1().getType();
1170     Type op2Type = operation.operand2().getType();
1171 
1172     if (op1Type == op2Type) {
1173       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1174                                                    adaptor.getOperands());
1175       return success();
1176     }
1177 
1178     Location loc = operation.getLoc();
1179     Value extended;
1180     if (isUnsignedIntegerOrVector(op2Type)) {
1181       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1182                                                         adaptor.operand2());
1183     } else {
1184       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1185                                                         adaptor.operand2());
1186     }
1187     Value result = rewriter.template create<LLVMOp>(
1188         loc, dstType, adaptor.operand1(), extended);
1189     rewriter.replaceOp(operation, result);
1190     return success();
1191   }
1192 };
1193 
1194 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1195 public:
1196   using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1197 
1198   LogicalResult
matchAndRewrite(spirv::GLTanOp tanOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1199   matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1200                   ConversionPatternRewriter &rewriter) const override {
1201     auto dstType = typeConverter.convertType(tanOp.getType());
1202     if (!dstType)
1203       return failure();
1204 
1205     Location loc = tanOp.getLoc();
1206     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1207     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1208     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1209     return success();
1210   }
1211 };
1212 
1213 /// Convert `spv.Tanh` to
1214 ///
1215 ///   exp(2x) - 1
1216 ///   -----------
1217 ///   exp(2x) + 1
1218 ///
1219 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1220 public:
1221   using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1222 
1223   LogicalResult
matchAndRewrite(spirv::GLTanhOp tanhOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1224   matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1225                   ConversionPatternRewriter &rewriter) const override {
1226     auto srcType = tanhOp.getType();
1227     auto dstType = typeConverter.convertType(srcType);
1228     if (!dstType)
1229       return failure();
1230 
1231     Location loc = tanhOp.getLoc();
1232     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1233     Value multiplied =
1234         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1235     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1236     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1237     Value numerator =
1238         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1239     Value denominator =
1240         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1241     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1242                                               denominator);
1243     return success();
1244   }
1245 };
1246 
1247 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1248 public:
1249   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1250 
1251   LogicalResult
matchAndRewrite(spirv::VariableOp varOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1252   matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1253                   ConversionPatternRewriter &rewriter) const override {
1254     auto srcType = varOp.getType();
1255     // Initialization is supported for scalars and vectors only.
1256     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1257     auto init = varOp.initializer();
1258     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1259       return failure();
1260 
1261     auto dstType = typeConverter.convertType(srcType);
1262     if (!dstType)
1263       return failure();
1264 
1265     Location loc = varOp.getLoc();
1266     Value size = createI32ConstantOf(loc, rewriter, 1);
1267     if (!init) {
1268       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1269       return success();
1270     }
1271     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1272     rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
1273     rewriter.replaceOp(varOp, allocated);
1274     return success();
1275   }
1276 };
1277 
1278 //===----------------------------------------------------------------------===//
1279 // FuncOp conversion
1280 //===----------------------------------------------------------------------===//
1281 
1282 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1283 public:
1284   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1285 
1286   LogicalResult
matchAndRewrite(spirv::FuncOp funcOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1287   matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1288                   ConversionPatternRewriter &rewriter) const override {
1289 
1290     // Convert function signature. At the moment LLVMType converter is enough
1291     // for currently supported types.
1292     auto funcType = funcOp.getFunctionType();
1293     TypeConverter::SignatureConversion signatureConverter(
1294         funcType.getNumInputs());
1295     auto llvmType = typeConverter.convertFunctionSignature(
1296         funcType, /*isVariadic=*/false, signatureConverter);
1297     if (!llvmType)
1298       return failure();
1299 
1300     // Create a new `LLVMFuncOp`
1301     Location loc = funcOp.getLoc();
1302     StringRef name = funcOp.getName();
1303     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1304 
1305     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1306     MLIRContext *context = funcOp.getContext();
1307     switch (funcOp.function_control()) {
1308 #define DISPATCH(functionControl, llvmAttr)                                    \
1309   case functionControl:                                                        \
1310     newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
1311     break;
1312 
1313       DISPATCH(spirv::FunctionControl::Inline,
1314                StringAttr::get(context, "alwaysinline"));
1315       DISPATCH(spirv::FunctionControl::DontInline,
1316                StringAttr::get(context, "noinline"));
1317       DISPATCH(spirv::FunctionControl::Pure,
1318                StringAttr::get(context, "readonly"));
1319       DISPATCH(spirv::FunctionControl::Const,
1320                StringAttr::get(context, "readnone"));
1321 
1322 #undef DISPATCH
1323 
1324     // Default: if `spirv::FunctionControl::None`, then no attributes are
1325     // needed.
1326     default:
1327       break;
1328     }
1329 
1330     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1331                                 newFuncOp.end());
1332     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1333                                            &signatureConverter))) {
1334       return failure();
1335     }
1336     rewriter.eraseOp(funcOp);
1337     return success();
1338   }
1339 };
1340 
1341 //===----------------------------------------------------------------------===//
1342 // ModuleOp conversion
1343 //===----------------------------------------------------------------------===//
1344 
1345 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1346 public:
1347   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1348 
1349   LogicalResult
matchAndRewrite(spirv::ModuleOp spvModuleOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1350   matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1351                   ConversionPatternRewriter &rewriter) const override {
1352 
1353     auto newModuleOp =
1354         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1355     rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1356 
1357     // Remove the terminator block that was automatically added by builder
1358     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1359     rewriter.eraseOp(spvModuleOp);
1360     return success();
1361   }
1362 };
1363 
1364 //===----------------------------------------------------------------------===//
1365 // VectorShuffleOp conversion
1366 //===----------------------------------------------------------------------===//
1367 
1368 class VectorShufflePattern
1369     : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1370 public:
1371   using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1372   LogicalResult
matchAndRewrite(spirv::VectorShuffleOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1373   matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1374                   ConversionPatternRewriter &rewriter) const override {
1375     Location loc = op.getLoc();
1376     auto components = adaptor.components();
1377     auto vector1 = adaptor.vector1();
1378     auto vector2 = adaptor.vector2();
1379     int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1380     int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1381     if (vector1Size == vector2Size) {
1382       rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
1383                                                          components);
1384       return success();
1385     }
1386 
1387     auto dstType = typeConverter.convertType(op.getType());
1388     auto scalarType = dstType.cast<VectorType>().getElementType();
1389     auto componentsArray = components.getValue();
1390     auto *context = rewriter.getContext();
1391     auto llvmI32Type = IntegerType::get(context, 32);
1392     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1393     for (unsigned i = 0; i < componentsArray.size(); i++) {
1394       if (componentsArray[i].isa<IntegerAttr>())
1395         op.emitError("unable to support non-constant component");
1396 
1397       int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1398       if (indexVal == -1)
1399         continue;
1400 
1401       int offsetVal = 0;
1402       Value baseVector = vector1;
1403       if (indexVal >= vector1Size) {
1404         offsetVal = vector1Size;
1405         baseVector = vector2;
1406       }
1407 
1408       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1409           loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1410       Value index = rewriter.create<LLVM::ConstantOp>(
1411           loc, llvmI32Type,
1412           rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1413 
1414       auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1415           loc, scalarType, baseVector, index);
1416       targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1417                                                         extractOp, dstIndex);
1418     }
1419     rewriter.replaceOp(op, targetOp);
1420     return success();
1421   }
1422 };
1423 } // namespace
1424 
1425 //===----------------------------------------------------------------------===//
1426 // Pattern population
1427 //===----------------------------------------------------------------------===//
1428 
populateSPIRVToLLVMTypeConversion(LLVMTypeConverter & typeConverter)1429 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1430   typeConverter.addConversion([&](spirv::ArrayType type) {
1431     return convertArrayType(type, typeConverter);
1432   });
1433   typeConverter.addConversion([&](spirv::PointerType type) {
1434     return convertPointerType(type, typeConverter);
1435   });
1436   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1437     return convertRuntimeArrayType(type, typeConverter);
1438   });
1439   typeConverter.addConversion([&](spirv::StructType type) {
1440     return convertStructType(type, typeConverter);
1441   });
1442 }
1443 
populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter & typeConverter,RewritePatternSet & patterns)1444 void mlir::populateSPIRVToLLVMConversionPatterns(
1445     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1446   patterns.add<
1447       // Arithmetic ops
1448       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1449       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1450       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1451       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1452       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1453       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1454       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1455       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1456       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1457       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1458       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1459       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1460       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1461 
1462       // Bitwise ops
1463       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1464       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1465       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1466       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1467       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1468       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1469       NotPattern<spirv::NotOp>,
1470 
1471       // Cast ops
1472       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1473       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1474       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1475       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1476       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1477       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1478       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1479       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1480 
1481       // Comparison ops
1482       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1483       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1484       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1485       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1486       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1487       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1488       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1489       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1490       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1491       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1492       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1493                       LLVM::FCmpPredicate::uge>,
1494       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1495       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1496       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1497       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1498       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1499       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1500       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1501       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1502       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1503       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1504       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1505 
1506       // Constant op
1507       ConstantScalarAndVectorPattern,
1508 
1509       // Control Flow ops
1510       BranchConversionPattern, BranchConditionalConversionPattern,
1511       FunctionCallPattern, LoopPattern, SelectionPattern,
1512       ErasePattern<spirv::MergeOp>,
1513 
1514       // Entry points and execution mode are handled separately.
1515       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1516 
1517       // GLSL extended instruction set ops
1518       DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1519       DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1520       DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1521       DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1522       DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1523       DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1524       DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1525       DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1526       DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1527       DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1528       DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1529       DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1530       InverseSqrtPattern, TanPattern, TanhPattern,
1531 
1532       // Logical ops
1533       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1534       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1535       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1536       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1537       NotPattern<spirv::LogicalNotOp>,
1538 
1539       // Memory ops
1540       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1541       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1542       VariablePattern,
1543 
1544       // Miscellaneous ops
1545       CompositeExtractPattern, CompositeInsertPattern,
1546       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1547       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1548       VectorShufflePattern,
1549 
1550       // Shift ops
1551       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1552       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1553       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1554 
1555       // Return ops
1556       ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1557 }
1558 
populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter & typeConverter,RewritePatternSet & patterns)1559 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1560     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1561   patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1562 }
1563 
populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter & typeConverter,RewritePatternSet & patterns)1564 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1565     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1566   patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1567 }
1568 
1569 //===----------------------------------------------------------------------===//
1570 // Pre-conversion hooks
1571 //===----------------------------------------------------------------------===//
1572 
1573 /// Hook for descriptor set and binding number encoding.
1574 static constexpr StringRef kBinding = "binding";
1575 static constexpr StringRef kDescriptorSet = "descriptor_set";
encodeBindAttribute(ModuleOp module)1576 void mlir::encodeBindAttribute(ModuleOp module) {
1577   auto spvModules = module.getOps<spirv::ModuleOp>();
1578   for (auto spvModule : spvModules) {
1579     spvModule.walk([&](spirv::GlobalVariableOp op) {
1580       IntegerAttr descriptorSet =
1581           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1582       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1583       // For every global variable in the module, get the ones with descriptor
1584       // set and binding numbers.
1585       if (descriptorSet && binding) {
1586         // Encode these numbers into the variable's symbolic name. If the
1587         // SPIR-V module has a name, add it at the beginning.
1588         auto moduleAndName =
1589             spvModule.getName().has_value()
1590                 ? spvModule.getName().value().str() + "_" + op.sym_name().str()
1591                 : op.sym_name().str();
1592         std::string name =
1593             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1594                           std::to_string(descriptorSet.getInt()),
1595                           std::to_string(binding.getInt()));
1596         auto nameAttr = StringAttr::get(op->getContext(), name);
1597 
1598         // Replace all symbol uses and set the new symbol name. Finally, remove
1599         // descriptor set and binding attributes.
1600         if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1601           op.emitError("unable to replace all symbol uses for ") << name;
1602         SymbolTable::setSymbolName(op, nameAttr);
1603         op->removeAttr(kDescriptorSet);
1604         op->removeAttr(kBinding);
1605       }
1606     });
1607   }
1608 }
1609