1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
30 }
31 
32 // Helper to reduce vector type by *all* but one rank at back.
33 static VectorType reducedVectorTypeBack(VectorType tp) {
34   assert((tp.getRank() > 1) && "unlowerable vector type");
35   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
36 }
37 
38 // Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter,
40                        LLVMTypeConverter &typeConverter, Location loc,
41                        Value val1, Value val2, Type llvmType, int64_t rank,
42                        int64_t pos) {
43   if (rank == 1) {
44     auto idxType = rewriter.getIndexType();
45     auto constant = rewriter.create<LLVM::ConstantOp>(
46         loc, typeConverter.convertType(idxType),
47         rewriter.getIntegerAttr(idxType, pos));
48     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
49                                                   constant);
50   }
51   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
52                                               rewriter.getI64ArrayAttr(pos));
53 }
54 
55 // Helper that picks the proper sequence for extracting.
56 static Value extractOne(ConversionPatternRewriter &rewriter,
57                         LLVMTypeConverter &typeConverter, Location loc,
58                         Value val, Type llvmType, int64_t rank, int64_t pos) {
59   if (rank == 1) {
60     auto idxType = rewriter.getIndexType();
61     auto constant = rewriter.create<LLVM::ConstantOp>(
62         loc, typeConverter.convertType(idxType),
63         rewriter.getIntegerAttr(idxType, pos));
64     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
65                                                    constant);
66   }
67   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
68                                                rewriter.getI64ArrayAttr(pos));
69 }
70 
71 // Helper that returns data layout alignment of a memref.
72 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
73                                  MemRefType memrefType, unsigned &align) {
74   Type elementTy = typeConverter.convertType(memrefType.getElementType());
75   if (!elementTy)
76     return failure();
77 
78   // TODO: this should use the MLIR data layout when it becomes available and
79   // stop depending on translation.
80   llvm::LLVMContext llvmContext;
81   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
82               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
83   return success();
84 }
85 
86 // Return the minimal alignment value that satisfies all the AssumeAlignment
87 // uses of `value`. If no such uses exist, return 1.
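// For example (a hypothetical IR snippet; names and alignments are made up):
//
//   memref.assume_alignment %buf, 8 : memref<?xf32>
//   memref.assume_alignment %buf, 12 : memref<?xf32>
//
// yields lcm(1, 8, 12) = 24 for %buf.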
88 static unsigned getAssumedAlignment(Value value) {
89   unsigned align = 1;
90   for (auto &u : value.getUses()) {
91     Operation *owner = u.getOwner();
92     if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
93       align = mlir::lcm(align, op.alignment());
94   }
95   return align;
96 }
97 
// Helper that returns the data layout alignment of a memref associated with a
// load, store, scatter, or gather op, including any additional alignment
// information from assume_alignment ops on the base memref.
101 template <class OpAdaptor>
102 LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
103                                    OpAdaptor op, unsigned &align) {
104   if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
105     return failure();
106   align = std::max(align, getAssumedAlignment(op.base()));
107   return success();
108 }
109 
// Add an index vector component to a base pointer. This succeeds only if the
// strides can be computed, the last stride is unit, and the memory space is
// zero.
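// As a sketch (the exact printed types depend on the type converter): gathering
// a vector<16xf32> from a unit-stride memref<?xf32> yields `ptrs` of a
// vector-of-pointers type such as !llvm.vec<16 x ptr<f32>>, produced by a
// single llvm.getelementptr whose index operand is the index vector.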
112 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
113                                     Location loc, Value memref, Value base,
114                                     Value index, MemRefType memRefType,
115                                     VectorType vType, Value &ptrs) {
116   int64_t offset;
117   SmallVector<int64_t, 4> strides;
118   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
119   if (failed(successStrides) || strides.back() != 1 ||
120       memRefType.getMemorySpaceAsInt() != 0)
121     return failure();
122   auto pType = MemRefDescriptor(memref).getElementPtrType();
123   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
124   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
125   return success();
126 }
127 
128 // Casts a strided element pointer to a vector pointer.  The vector pointer
129 // will be in the same address space as the incoming memref type.
130 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
131                          Value ptr, MemRefType memRefType, Type vt) {
132   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
133   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
134 }
135 
136 namespace {
137 
138 /// Conversion pattern for a vector.bitcast.
139 class VectorBitCastOpConversion
140     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
141 public:
142   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
143 
144   LogicalResult
145   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
146                   ConversionPatternRewriter &rewriter) const override {
147     // Only 1-D vectors can be lowered to LLVM.
148     VectorType resultTy = bitCastOp.getType();
149     if (resultTy.getRank() != 1)
150       return failure();
151     Type newResultTy = typeConverter->convertType(resultTy);
152     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
153                                                  adaptor.getOperands()[0]);
154     return success();
155   }
156 };
157 
158 /// Conversion pattern for a vector.matrix_multiply.
159 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
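/// For example (attribute values are illustrative, syntax is approximate):
/// ```
///   %c = vector.matrix_multiply %a, %b
///        {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
///        : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// becomes the corresponding llvm.intr.matrix.multiply with the same operands
/// and attributes, modulo type conversion.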
160 class VectorMatmulOpConversion
161     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
162 public:
163   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
164 
165   LogicalResult
166   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
167                   ConversionPatternRewriter &rewriter) const override {
168     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
169         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
170         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
171         matmulOp.lhs_columns(), matmulOp.rhs_columns());
172     return success();
173   }
174 };
175 
176 /// Conversion pattern for a vector.flat_transpose.
177 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
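/// For example (attribute values are illustrative, syntax is approximate):
/// ```
///   %1 = vector.flat_transpose %0 {rows = 4 : i32, columns = 4 : i32}
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// becomes the corresponding llvm.intr.matrix.transpose with the same
/// rows/columns attributes, modulo type conversion.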
178 class VectorFlatTransposeOpConversion
179     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
180 public:
181   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
182 
183   LogicalResult
184   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override {
186     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
187         transOp, typeConverter->convertType(transOp.res().getType()),
188         adaptor.matrix(), transOp.rows(), transOp.columns());
189     return success();
190   }
191 };
192 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload, or vector.maskedstore with its respective LLVM
/// counterpart.
196 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
197                                  vector::LoadOpAdaptor adaptor,
198                                  VectorType vectorTy, Value ptr, unsigned align,
199                                  ConversionPatternRewriter &rewriter) {
200   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
201 }
202 
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204                                  vector::MaskedLoadOpAdaptor adaptor,
205                                  VectorType vectorTy, Value ptr, unsigned align,
206                                  ConversionPatternRewriter &rewriter) {
207   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
208       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
209 }
210 
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212                                  vector::StoreOpAdaptor adaptor,
213                                  VectorType vectorTy, Value ptr, unsigned align,
214                                  ConversionPatternRewriter &rewriter) {
215   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
216                                              ptr, align);
217 }
218 
219 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
220                                  vector::MaskedStoreOpAdaptor adaptor,
221                                  VectorType vectorTy, Value ptr, unsigned align,
222                                  ConversionPatternRewriter &rewriter) {
223   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
224       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
225 }
226 
227 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
228 /// vector.maskedstore.
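/// For example (a sketch; the converted pointer type depends on the memref's
/// address space and the type converter):
/// ```
///   %0 = vector.load %base[%i] : memref<200xf32>, vector<8xf32>
/// ```
/// resolves the strided element pointer for %base[%i], bitcasts it to
/// !llvm.ptr<vector<8xf32>>, and replaces the op with an aligned llvm.load.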
229 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
230 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
231 public:
232   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
233 
234   LogicalResult
235   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
236                   typename LoadOrStoreOp::Adaptor adaptor,
237                   ConversionPatternRewriter &rewriter) const override {
238     // Only 1-D vectors can be lowered to LLVM.
239     VectorType vectorTy = loadOrStoreOp.getVectorType();
240     if (vectorTy.getRank() > 1)
241       return failure();
242 
243     auto loc = loadOrStoreOp->getLoc();
244     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
245 
246     // Resolve alignment.
247     unsigned align;
248     if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
249                                     align)))
250       return failure();
251 
252     // Resolve address.
253     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
254                      .template cast<VectorType>();
255     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
256                                                adaptor.indices(), rewriter);
257     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
258 
259     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
260     return success();
261   }
262 };
263 
264 /// Conversion pattern for a vector.gather.
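/// The op is rewritten to an llvm.intr.masked.gather over a vector of pointers
/// computed from the base address and the index vector. A sketch (syntax and
/// types are approximate):
/// ```
///   %0 = vector.gather %base[%c0][%idxs], %mask, %pass
///        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///          into vector<16xf32>
/// ```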
265 class VectorGatherOpConversion
266     : public ConvertOpToLLVMPattern<vector::GatherOp> {
267 public:
268   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
269 
270   LogicalResult
271   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
272                   ConversionPatternRewriter &rewriter) const override {
273     auto loc = gather->getLoc();
274     MemRefType memRefType = gather.getMemRefType();
275 
276     // Resolve alignment.
277     unsigned align;
278     if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
279       return failure();
280 
281     // Resolve address.
282     Value ptrs;
283     VectorType vType = gather.getVectorType();
284     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
285                                      adaptor.indices(), rewriter);
286     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
287                               adaptor.index_vec(), memRefType, vType, ptrs)))
288       return failure();
289 
290     // Replace with the gather intrinsic.
291     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
292         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
293         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
294     return success();
295   }
296 };
297 
298 /// Conversion pattern for a vector.scatter.
299 class VectorScatterOpConversion
300     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
301 public:
302   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
303 
304   LogicalResult
305   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
306                   ConversionPatternRewriter &rewriter) const override {
307     auto loc = scatter->getLoc();
308     MemRefType memRefType = scatter.getMemRefType();
309 
310     // Resolve alignment.
311     unsigned align;
312     if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
313       return failure();
314 
315     // Resolve address.
316     Value ptrs;
317     VectorType vType = scatter.getVectorType();
318     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
319                                      adaptor.indices(), rewriter);
320     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
321                               adaptor.index_vec(), memRefType, vType, ptrs)))
322       return failure();
323 
324     // Replace with the scatter intrinsic.
325     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
326         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
327         rewriter.getI32IntegerAttr(align));
328     return success();
329   }
330 };
331 
332 /// Conversion pattern for a vector.expandload.
333 class VectorExpandLoadOpConversion
334     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
335 public:
336   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
337 
338   LogicalResult
339   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
340                   ConversionPatternRewriter &rewriter) const override {
341     auto loc = expand->getLoc();
342     MemRefType memRefType = expand.getMemRefType();
343 
344     // Resolve address.
345     auto vtype = typeConverter->convertType(expand.getVectorType());
346     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
347                                      adaptor.indices(), rewriter);
348 
349     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
350         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
351     return success();
352   }
353 };
354 
355 /// Conversion pattern for a vector.compressstore.
356 class VectorCompressStoreOpConversion
357     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
358 public:
359   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
360 
361   LogicalResult
362   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
363                   ConversionPatternRewriter &rewriter) const override {
364     auto loc = compress->getLoc();
365     MemRefType memRefType = compress.getMemRefType();
366 
367     // Resolve address.
368     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
369                                      adaptor.indices(), rewriter);
370 
371     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
372         compress, adaptor.valueToStore(), ptr, adaptor.mask());
373     return success();
374   }
375 };
376 
377 /// Conversion pattern for all vector reductions.
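/// For example (a sketch):
/// ```
///   %0 = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// maps onto llvm.intr.vector.reduce.add of the converted operand.
/// Floating-point "add"/"mul" reductions additionally thread an optional
/// accumulator and honor the `reassociateFPReductions` flag.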
378 class VectorReductionOpConversion
379     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
380 public:
381   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
382                                        bool reassociateFPRed)
383       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
384         reassociateFPReductions(reassociateFPRed) {}
385 
386   LogicalResult
387   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
388                   ConversionPatternRewriter &rewriter) const override {
389     auto kind = reductionOp.kind();
390     Type eltType = reductionOp.dest().getType();
391     Type llvmType = typeConverter->convertType(eltType);
392     Value operand = adaptor.getOperands()[0];
393     if (eltType.isIntOrIndex()) {
394       // Integer reductions: add/mul/min/max/and/or/xor.
395       if (kind == "add")
396         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
397                                                              llvmType, operand);
398       else if (kind == "mul")
399         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
400                                                              llvmType, operand);
401       else if (kind == "minui")
402         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
403             reductionOp, llvmType, operand);
404       else if (kind == "minsi")
405         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
406             reductionOp, llvmType, operand);
407       else if (kind == "maxui")
408         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
409             reductionOp, llvmType, operand);
410       else if (kind == "maxsi")
411         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
412             reductionOp, llvmType, operand);
413       else if (kind == "and")
414         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
415                                                              llvmType, operand);
416       else if (kind == "or")
417         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
418                                                             llvmType, operand);
419       else if (kind == "xor")
420         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
421                                                              llvmType, operand);
422       else
423         return failure();
424       return success();
425     }
426 
427     if (!eltType.isa<FloatType>())
428       return failure();
429 
430     // Floating-point reductions: add/mul/min/max
431     if (kind == "add") {
432       // Optional accumulator (or zero).
433       Value acc = adaptor.getOperands().size() > 1
434                       ? adaptor.getOperands()[1]
435                       : rewriter.create<LLVM::ConstantOp>(
436                             reductionOp->getLoc(), llvmType,
437                             rewriter.getZeroAttr(eltType));
438       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
439           reductionOp, llvmType, acc, operand,
440           rewriter.getBoolAttr(reassociateFPReductions));
441     } else if (kind == "mul") {
442       // Optional accumulator (or one).
443       Value acc = adaptor.getOperands().size() > 1
444                       ? adaptor.getOperands()[1]
445                       : rewriter.create<LLVM::ConstantOp>(
446                             reductionOp->getLoc(), llvmType,
447                             rewriter.getFloatAttr(eltType, 1.0));
448       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
449           reductionOp, llvmType, acc, operand,
450           rewriter.getBoolAttr(reassociateFPReductions));
451     } else if (kind == "minf")
452       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
453       // NaNs/-0.0/+0.0 in the same way.
454       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
455                                                             llvmType, operand);
456     else if (kind == "maxf")
457       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
458       // NaNs/-0.0/+0.0 in the same way.
459       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
460                                                             llvmType, operand);
461     else
462       return failure();
463     return success();
464   }
465 
466 private:
467   const bool reassociateFPReductions;
468 };
469 
470 class VectorShuffleOpConversion
471     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
472 public:
473   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
474 
475   LogicalResult
476   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
477                   ConversionPatternRewriter &rewriter) const override {
478     auto loc = shuffleOp->getLoc();
479     auto v1Type = shuffleOp.getV1VectorType();
480     auto v2Type = shuffleOp.getV2VectorType();
481     auto vectorType = shuffleOp.getVectorType();
482     Type llvmType = typeConverter->convertType(vectorType);
483     auto maskArrayAttr = shuffleOp.mask();
484 
485     // Bail if result type cannot be lowered.
486     if (!llvmType)
487       return failure();
488 
489     // Get rank and dimension sizes.
490     int64_t rank = vectorType.getRank();
491     assert(v1Type.getRank() == rank);
492     assert(v2Type.getRank() == rank);
493     int64_t v1Dim = v1Type.getDimSize(0);
494 
495     // For rank 1, where both operands have *exactly* the same vector type,
496     // there is direct shuffle support in LLVM. Use it!
497     if (rank == 1 && v1Type == v2Type) {
498       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
499           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
500       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
501       return success();
502     }
503 
    // For all other cases, extract and insert the values one at a time.
    Type eltType;
507     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
508       eltType = arrayType.getElementType();
509     else
510       eltType = llvmType.cast<VectorType>().getElementType();
511     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
512     int64_t insPos = 0;
513     for (auto en : llvm::enumerate(maskArrayAttr)) {
514       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
515       Value value = adaptor.v1();
516       if (extPos >= v1Dim) {
517         extPos -= v1Dim;
518         value = adaptor.v2();
519       }
520       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
521                                  eltType, rank, extPos);
522       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
523                          llvmType, rank, insPos++);
524     }
525     rewriter.replaceOp(shuffleOp, insert);
526     return success();
527   }
528 };
529 
530 class VectorExtractElementOpConversion
531     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
532 public:
533   using ConvertOpToLLVMPattern<
534       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
535 
536   LogicalResult
537   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
538                   ConversionPatternRewriter &rewriter) const override {
539     auto vectorType = extractEltOp.getVectorType();
540     auto llvmType = typeConverter->convertType(vectorType.getElementType());
541 
542     // Bail if result type cannot be lowered.
543     if (!llvmType)
544       return failure();
545 
546     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
547         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
548     return success();
549   }
550 };
551 
552 class VectorExtractOpConversion
553     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
554 public:
555   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
556 
557   LogicalResult
558   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
559                   ConversionPatternRewriter &rewriter) const override {
560     auto loc = extractOp->getLoc();
561     auto vectorType = extractOp.getVectorType();
562     auto resultType = extractOp.getResult().getType();
563     auto llvmResultType = typeConverter->convertType(resultType);
564     auto positionArrayAttr = extractOp.position();
565 
566     // Bail if result type cannot be lowered.
567     if (!llvmResultType)
568       return failure();
569 
570     // Extract entire vector. Should be handled by folder, but just to be safe.
571     if (positionArrayAttr.empty()) {
572       rewriter.replaceOp(extractOp, adaptor.vector());
573       return success();
574     }
575 
576     // One-shot extraction of vector from array (only requires extractvalue).
577     if (resultType.isa<VectorType>()) {
578       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
579           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
580       rewriter.replaceOp(extractOp, extracted);
581       return success();
582     }
583 
584     // Potential extraction of 1-D vector from array.
585     auto *context = extractOp->getContext();
586     Value extracted = adaptor.vector();
587     auto positionAttrs = positionArrayAttr.getValue();
588     if (positionAttrs.size() > 1) {
589       auto oneDVectorType = reducedVectorTypeBack(vectorType);
590       auto nMinusOnePositionAttrs =
591           ArrayAttr::get(context, positionAttrs.drop_back());
592       extracted = rewriter.create<LLVM::ExtractValueOp>(
593           loc, typeConverter->convertType(oneDVectorType), extracted,
594           nMinusOnePositionAttrs);
595     }
596 
597     // Remaining extraction of element from 1-D LLVM vector
598     auto position = positionAttrs.back().cast<IntegerAttr>();
599     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
600     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
601     extracted =
602         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
603     rewriter.replaceOp(extractOp, extracted);
604 
605     return success();
606   }
607 };
608 
609 /// Conversion pattern that turns a vector.fma on a 1-D vector
610 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank >= 2.
612 ///
613 /// Example:
614 /// ```
615 ///  vector.fma %a, %a, %a : vector<8xf32>
616 /// ```
617 /// is converted to:
618 /// ```
///  llvm.intr.fmuladd %va, %va, %va :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
622 /// ```
623 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
624 public:
625   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
626 
627   LogicalResult
628   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
629                   ConversionPatternRewriter &rewriter) const override {
630     VectorType vType = fmaOp.getVectorType();
631     if (vType.getRank() != 1)
632       return failure();
633     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
634                                                  adaptor.rhs(), adaptor.acc());
635     return success();
636   }
637 };
638 
639 class VectorInsertElementOpConversion
640     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
641 public:
642   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
643 
644   LogicalResult
645   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
646                   ConversionPatternRewriter &rewriter) const override {
647     auto vectorType = insertEltOp.getDestVectorType();
648     auto llvmType = typeConverter->convertType(vectorType);
649 
650     // Bail if result type cannot be lowered.
651     if (!llvmType)
652       return failure();
653 
654     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
655         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
656         adaptor.position());
657     return success();
658   }
659 };
660 
661 class VectorInsertOpConversion
662     : public ConvertOpToLLVMPattern<vector::InsertOp> {
663 public:
664   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
665 
666   LogicalResult
667   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
668                   ConversionPatternRewriter &rewriter) const override {
669     auto loc = insertOp->getLoc();
670     auto sourceType = insertOp.getSourceType();
671     auto destVectorType = insertOp.getDestVectorType();
672     auto llvmResultType = typeConverter->convertType(destVectorType);
673     auto positionArrayAttr = insertOp.position();
674 
675     // Bail if result type cannot be lowered.
676     if (!llvmResultType)
677       return failure();
678 
679     // Overwrite entire vector with value. Should be handled by folder, but
680     // just to be safe.
681     if (positionArrayAttr.empty()) {
682       rewriter.replaceOp(insertOp, adaptor.source());
683       return success();
684     }
685 
686     // One-shot insertion of a vector into an array (only requires insertvalue).
687     if (sourceType.isa<VectorType>()) {
688       Value inserted = rewriter.create<LLVM::InsertValueOp>(
689           loc, llvmResultType, adaptor.dest(), adaptor.source(),
690           positionArrayAttr);
691       rewriter.replaceOp(insertOp, inserted);
692       return success();
693     }
694 
695     // Potential extraction of 1-D vector from array.
696     auto *context = insertOp->getContext();
697     Value extracted = adaptor.dest();
698     auto positionAttrs = positionArrayAttr.getValue();
699     auto position = positionAttrs.back().cast<IntegerAttr>();
700     auto oneDVectorType = destVectorType;
701     if (positionAttrs.size() > 1) {
702       oneDVectorType = reducedVectorTypeBack(destVectorType);
703       auto nMinusOnePositionAttrs =
704           ArrayAttr::get(context, positionAttrs.drop_back());
705       extracted = rewriter.create<LLVM::ExtractValueOp>(
706           loc, typeConverter->convertType(oneDVectorType), extracted,
707           nMinusOnePositionAttrs);
708     }
709 
710     // Insertion of an element into a 1-D LLVM vector.
711     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
712     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
713     Value inserted = rewriter.create<LLVM::InsertElementOp>(
714         loc, typeConverter->convertType(oneDVectorType), extracted,
715         adaptor.source(), constant);
716 
717     // Potential insertion of resulting 1-D vector into array.
718     if (positionAttrs.size() > 1) {
719       auto nMinusOnePositionAttrs =
720           ArrayAttr::get(context, positionAttrs.drop_back());
721       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
722                                                       adaptor.dest(), inserted,
723                                                       nMinusOnePositionAttrs);
724     }
725 
726     rewriter.replaceOp(insertOp, inserted);
727     return success();
728   }
729 };
730 
731 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
732 ///
733 /// Example:
734 /// ```
735 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
736 /// ```
737 /// is rewritten into:
738 /// ```
///  %r = splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
750 ///  // %r3 holds the final value.
751 /// ```
752 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
753 public:
754   using OpRewritePattern<FMAOp>::OpRewritePattern;
755 
756   LogicalResult matchAndRewrite(FMAOp op,
757                                 PatternRewriter &rewriter) const override {
758     auto vType = op.getVectorType();
759     if (vType.getRank() < 2)
760       return failure();
761 
762     auto loc = op.getLoc();
763     auto elemType = vType.getElementType();
764     Value zero = rewriter.create<arith::ConstantOp>(
765         loc, elemType, rewriter.getZeroAttr(elemType));
766     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
767     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
768       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
769       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
770       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
771       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
772       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
773     }
774     rewriter.replaceOp(op, desc);
775     return success();
776   }
777 };
778 
/// Returns the strides if the memory underlying `memRefType` is known to be
/// contiguous: either an identity layout, or a static strided layout whose
/// strides match the shape.
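/// For example (illustrative values): a memref<4x8xf32> with the identity
/// layout has strides [8, 1] and is contiguous by definition; a static
/// memref<4x8xf32, offset: 0, strides: [8, 1]> is accepted because
/// 8 == 1 * 8, while strides: [16, 1] would be rejected.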
781 static llvm::Optional<SmallVector<int64_t, 4>>
782 computeContiguousStrides(MemRefType memRefType) {
783   int64_t offset;
784   SmallVector<int64_t, 4> strides;
785   if (failed(getStridesAndOffset(memRefType, strides, offset)))
786     return None;
787   if (!strides.empty() && strides.back() != 1)
788     return None;
789   // If no layout or identity layout, this is contiguous by definition.
790   if (memRefType.getLayout().isIdentity())
791     return strides;
792 
  // Otherwise, we must determine contiguity from the shape. This can only ever
  // work in static cases because MemRefType is underspecified: it cannot
  // represent contiguous dynamic shapes in any way other than the
  // empty/identity layout.
797   auto sizes = memRefType.getShape();
798   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
799     if (ShapedType::isDynamic(sizes[index + 1]) ||
800         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
801         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
802       return None;
803     if (strides[index] != strides[index + 1] * sizes[index + 1])
804       return None;
805   }
806   return strides;
807 }
808 
809 class VectorTypeCastOpConversion
810     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
811 public:
812   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
813 
814   LogicalResult
815   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
816                   ConversionPatternRewriter &rewriter) const override {
817     auto loc = castOp->getLoc();
818     MemRefType sourceMemRefType =
819         castOp.getOperand().getType().cast<MemRefType>();
820     MemRefType targetMemRefType = castOp.getType();
821 
822     // Only static shape casts supported atm.
823     if (!sourceMemRefType.hasStaticShape() ||
824         !targetMemRefType.hasStaticShape())
825       return failure();
826 
827     auto llvmSourceDescriptorTy =
828         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
829     if (!llvmSourceDescriptorTy)
830       return failure();
831     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
832 
833     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
834                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
835     if (!llvmTargetDescriptorTy)
836       return failure();
837 
838     // Only contiguous source buffers supported atm.
839     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
840     if (!sourceStrides)
841       return failure();
842     auto targetStrides = computeContiguousStrides(targetMemRefType);
843     if (!targetStrides)
844       return failure();
845     // Only support static strides for now, regardless of contiguity.
846     if (llvm::any_of(*targetStrides, [](int64_t stride) {
847           return ShapedType::isDynamicStrideOrOffset(stride);
848         }))
849       return failure();
850 
851     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
852 
853     // Create descriptor.
854     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
855     Type llvmTargetElementTy = desc.getElementPtrType();
856     // Set allocated ptr.
857     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
858     allocated =
859         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
860     desc.setAllocatedPtr(rewriter, loc, allocated);
861     // Set aligned ptr.
862     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
863     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
864     desc.setAlignedPtr(rewriter, loc, ptr);
865     // Fill offset 0.
866     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
867     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
868     desc.setOffset(rewriter, loc, zero);
869 
870     // Fill size and stride descriptors in memref.
871     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
872       int64_t index = indexedSize.index();
873       auto sizeAttr =
874           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
875       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
876       desc.setSize(rewriter, loc, index, size);
877       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
878                                                 (*targetStrides)[index]);
879       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
880       desc.setStride(rewriter, loc, index, stride);
881     }
882 
883     rewriter.replaceOp(castOp, {desc});
884     return success();
885   }
886 };
887 
888 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
889 public:
890   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
891 
892   // Proof-of-concept lowering implementation that relies on a small
893   // runtime support library, which only needs to provide a few
894   // printing methods (single value for all data types, opening/closing
895   // bracket, comma, newline). The lowering fully unrolls a vector
896   // in terms of these elementary printing operations. The advantage
897   // of this approach is that the library can remain unaware of all
898   // low-level implementation details of vectors while still supporting
899   // output of any shaped and dimensioned vector. Due to full unrolling,
900   // this approach is less suited for very large vectors though.
901   //
902   // TODO: rely solely on libc in future? something else?
903   //
904   LogicalResult
905   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
906                   ConversionPatternRewriter &rewriter) const override {
907     Type printType = printOp.getPrintType();
908 
909     if (typeConverter->convertType(printType) == nullptr)
910       return failure();
911 
912     // Make sure element type has runtime support.
913     PrintConversion conversion = PrintConversion::None;
914     VectorType vectorType = printType.dyn_cast<VectorType>();
915     Type eltType = vectorType ? vectorType.getElementType() : printType;
916     Operation *printer;
917     if (eltType.isF32()) {
918       printer =
919           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
920     } else if (eltType.isF64()) {
921       printer =
922           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
923     } else if (eltType.isIndex()) {
924       printer =
925           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
926     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
927       // Integers need a zero or sign extension on the operand
928       // (depending on the source type) as well as a signed or
929       // unsigned print method. Up to 64-bit is supported.
930       unsigned width = intTy.getWidth();
931       if (intTy.isUnsigned()) {
932         if (width <= 64) {
933           if (width < 64)
934             conversion = PrintConversion::ZeroExt64;
935           printer = LLVM::lookupOrCreatePrintU64Fn(
936               printOp->getParentOfType<ModuleOp>());
937         } else {
938           return failure();
939         }
940       } else {
941         assert(intTy.isSignless() || intTy.isSigned());
942         if (width <= 64) {
943           // Note that we *always* zero extend booleans (1-bit integers),
944           // so that true/false is printed as 1/0 rather than -1/0.
945           if (width == 1)
946             conversion = PrintConversion::ZeroExt64;
947           else if (width < 64)
948             conversion = PrintConversion::SignExt64;
949           printer = LLVM::lookupOrCreatePrintI64Fn(
950               printOp->getParentOfType<ModuleOp>());
951         } else {
952           return failure();
953         }
954       }
955     } else {
956       return failure();
957     }
958 
959     // Unroll vector into elementary print calls.
960     int64_t rank = vectorType ? vectorType.getRank() : 0;
961     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
962               conversion);
963     emitCall(rewriter, printOp->getLoc(),
964              LLVM::lookupOrCreatePrintNewlineFn(
965                  printOp->getParentOfType<ModuleOp>()));
966     rewriter.eraseOp(printOp);
967     return success();
968   }
969 
970 private:
971   enum class PrintConversion {
972     // clang-format off
973     None,
974     ZeroExt64,
975     SignExt64
976     // clang-format on
977   };
978 
979   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
980                  Value value, VectorType vectorType, Operation *printer,
981                  int64_t rank, PrintConversion conversion) const {
982     Location loc = op->getLoc();
983     if (rank == 0) {
984       switch (conversion) {
985       case PrintConversion::ZeroExt64:
986         value = rewriter.create<arith::ExtUIOp>(
987             loc, value, IntegerType::get(rewriter.getContext(), 64));
988         break;
989       case PrintConversion::SignExt64:
990         value = rewriter.create<arith::ExtSIOp>(
991             loc, value, IntegerType::get(rewriter.getContext(), 64));
992         break;
993       case PrintConversion::None:
994         break;
995       }
996       emitCall(rewriter, loc, printer, value);
997       return;
998     }
999 
1000     emitCall(rewriter, loc,
1001              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1002     Operation *printComma =
1003         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1004     int64_t dim = vectorType.getDimSize(0);
1005     for (int64_t d = 0; d < dim; ++d) {
1006       auto reducedType =
1007           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1008       auto llvmType = typeConverter->convertType(
1009           rank > 1 ? reducedType : vectorType.getElementType());
1010       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1011                                    llvmType, rank, d);
1012       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1013                 conversion);
1014       if (d != dim - 1)
1015         emitCall(rewriter, loc, printComma);
1016     }
1017     emitCall(rewriter, loc,
1018              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1019   }
1020 
1021   // Helper to emit a call.
1022   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1023                        Operation *ref, ValueRange params = ValueRange()) {
1024     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1025                                   params);
1026   }
1027 };
1028 
1029 } // namespace
1030 
1031 /// Populate the given list with patterns that convert from Vector to LLVM.
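/// A typical usage sketch inside a conversion pass (the surrounding pass
/// boilerplate is illustrative, not part of this file):
/// ```
///   LLVMTypeConverter converter(&getContext());
///   RewritePatternSet patterns(&getContext());
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(getContext());
///   if (failed(applyPartialConversion(getOperation(), target,
///                                     std::move(patterns))))
///     signalPassFailure();
/// ```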
1032 void mlir::populateVectorToLLVMConversionPatterns(
1033     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1034     bool reassociateFPReductions) {
1035   MLIRContext *ctx = converter.getDialect()->getContext();
1036   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1037   populateVectorInsertExtractStridedSliceTransforms(patterns);
1038   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1039   patterns
1040       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1041            VectorExtractElementOpConversion, VectorExtractOpConversion,
1042            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1043            VectorInsertOpConversion, VectorPrintOpConversion,
1044            VectorTypeCastOpConversion,
1045            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1046            VectorLoadStoreConversion<vector::MaskedLoadOp,
1047                                      vector::MaskedLoadOpAdaptor>,
1048            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1049            VectorLoadStoreConversion<vector::MaskedStoreOp,
1050                                      vector::MaskedStoreOpAdaptor>,
1051            VectorGatherOpConversion, VectorScatterOpConversion,
1052            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1053           converter);
1054   // Transfer ops with rank > 1 are handled by VectorToSCF.
1055   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1056 }
1057 
1058 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1059     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1060   patterns.add<VectorMatmulOpConversion>(converter);
1061   patterns.add<VectorFlatTransposeOpConversion>(converter);
1062 }
1063