1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
30 }
31 
32 // Helper to reduce vector type by *all* but one rank at back.
33 static VectorType reducedVectorTypeBack(VectorType tp) {
34   assert((tp.getRank() > 1) && "unlowerable vector type");
35   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
36 }
37 
38 // Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter,
40                        LLVMTypeConverter &typeConverter, Location loc,
41                        Value val1, Value val2, Type llvmType, int64_t rank,
42                        int64_t pos) {
43   assert(rank > 0 && "0-D vector corner case should have been handled already");
44   if (rank == 1) {
45     auto idxType = rewriter.getIndexType();
46     auto constant = rewriter.create<LLVM::ConstantOp>(
47         loc, typeConverter.convertType(idxType),
48         rewriter.getIntegerAttr(idxType, pos));
49     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50                                                   constant);
51   }
52   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
53                                               rewriter.getI64ArrayAttr(pos));
54 }
55 
56 // Helper that picks the proper sequence for extracting.
57 static Value extractOne(ConversionPatternRewriter &rewriter,
58                         LLVMTypeConverter &typeConverter, Location loc,
59                         Value val, Type llvmType, int64_t rank, int64_t pos) {
60   if (rank <= 1) {
61     auto idxType = rewriter.getIndexType();
62     auto constant = rewriter.create<LLVM::ConstantOp>(
63         loc, typeConverter.convertType(idxType),
64         rewriter.getIntegerAttr(idxType, pos));
65     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
66                                                    constant);
67   }
68   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
69                                                rewriter.getI64ArrayAttr(pos));
70 }
71 
72 // Helper that returns data layout alignment of a memref.
73 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
74                                  MemRefType memrefType, unsigned &align) {
75   Type elementTy = typeConverter.convertType(memrefType.getElementType());
76   if (!elementTy)
77     return failure();
78 
79   // TODO: this should use the MLIR data layout when it becomes available and
80   // stop depending on translation.
81   llvm::LLVMContext llvmContext;
82   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
83               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
84   return success();
85 }
86 
87 // Add an index vector component to a base pointer. This almost always succeeds
88 // unless the last stride is non-unit or the memory space is not zero.
89 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
90                                     Location loc, Value memref, Value base,
91                                     Value index, MemRefType memRefType,
92                                     VectorType vType, Value &ptrs) {
93   int64_t offset;
94   SmallVector<int64_t, 4> strides;
95   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
96   if (failed(successStrides) || strides.back() != 1 ||
97       memRefType.getMemorySpaceAsInt() != 0)
98     return failure();
99   auto pType = MemRefDescriptor(memref).getElementPtrType();
100   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
101   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
102   return success();
103 }
104 
105 // Casts a strided element pointer to a vector pointer.  The vector pointer
106 // will be in the same address space as the incoming memref type.
107 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
108                          Value ptr, MemRefType memRefType, Type vt) {
109   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
110   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
111 }
112 
113 namespace {
114 
115 /// Conversion pattern for a vector.bitcast.
116 class VectorBitCastOpConversion
117     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
118 public:
119   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
120 
121   LogicalResult
122   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
123                   ConversionPatternRewriter &rewriter) const override {
124     // Only 0-D and 1-D vectors can be lowered to LLVM.
125     VectorType resultTy = bitCastOp.getResultVectorType();
126     if (resultTy.getRank() > 1)
127       return failure();
128     Type newResultTy = typeConverter->convertType(resultTy);
129     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
130                                                  adaptor.getOperands()[0]);
131     return success();
132   }
133 };
134 
135 /// Conversion pattern for a vector.matrix_multiply.
136 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
137 class VectorMatmulOpConversion
138     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
139 public:
140   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
141 
142   LogicalResult
143   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
144                   ConversionPatternRewriter &rewriter) const override {
145     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
146         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
147         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
148         matmulOp.lhs_columns(), matmulOp.rhs_columns());
149     return success();
150   }
151 };
152 
153 /// Conversion pattern for a vector.flat_transpose.
154 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
155 class VectorFlatTransposeOpConversion
156     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
157 public:
158   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
159 
160   LogicalResult
161   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
162                   ConversionPatternRewriter &rewriter) const override {
163     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
164         transOp, typeConverter->convertType(transOp.res().getType()),
165         adaptor.matrix(), transOp.rows(), transOp.columns());
166     return success();
167   }
168 };
169 
170 /// Overloaded utility that replaces a vector.load, vector.store,
171 /// vector.maskedload and vector.maskedstore with their respective LLVM
172 /// couterparts.
173 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
174                                  vector::LoadOpAdaptor adaptor,
175                                  VectorType vectorTy, Value ptr, unsigned align,
176                                  ConversionPatternRewriter &rewriter) {
177   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
178 }
179 
180 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
181                                  vector::MaskedLoadOpAdaptor adaptor,
182                                  VectorType vectorTy, Value ptr, unsigned align,
183                                  ConversionPatternRewriter &rewriter) {
184   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
185       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
186 }
187 
188 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
189                                  vector::StoreOpAdaptor adaptor,
190                                  VectorType vectorTy, Value ptr, unsigned align,
191                                  ConversionPatternRewriter &rewriter) {
192   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
193                                              ptr, align);
194 }
195 
196 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
197                                  vector::MaskedStoreOpAdaptor adaptor,
198                                  VectorType vectorTy, Value ptr, unsigned align,
199                                  ConversionPatternRewriter &rewriter) {
200   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
201       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
202 }
203 
204 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
205 /// vector.maskedstore.
206 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
207 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
208 public:
209   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
210 
211   LogicalResult
212   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
213                   typename LoadOrStoreOp::Adaptor adaptor,
214                   ConversionPatternRewriter &rewriter) const override {
215     // Only 1-D vectors can be lowered to LLVM.
216     VectorType vectorTy = loadOrStoreOp.getVectorType();
217     if (vectorTy.getRank() > 1)
218       return failure();
219 
220     auto loc = loadOrStoreOp->getLoc();
221     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
222 
223     // Resolve alignment.
224     unsigned align;
225     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
226       return failure();
227 
228     // Resolve address.
229     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
230                      .template cast<VectorType>();
231     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
232                                                adaptor.indices(), rewriter);
233     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
234 
235     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
236     return success();
237   }
238 };
239 
240 /// Conversion pattern for a vector.gather.
241 class VectorGatherOpConversion
242     : public ConvertOpToLLVMPattern<vector::GatherOp> {
243 public:
244   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
245 
246   LogicalResult
247   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
248                   ConversionPatternRewriter &rewriter) const override {
249     auto loc = gather->getLoc();
250     MemRefType memRefType = gather.getMemRefType();
251 
252     // Resolve alignment.
253     unsigned align;
254     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
255       return failure();
256 
257     // Resolve address.
258     Value ptrs;
259     VectorType vType = gather.getVectorType();
260     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
261                                      adaptor.indices(), rewriter);
262     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
263                               adaptor.index_vec(), memRefType, vType, ptrs)))
264       return failure();
265 
266     // Replace with the gather intrinsic.
267     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
268         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
269         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
270     return success();
271   }
272 };
273 
274 /// Conversion pattern for a vector.scatter.
275 class VectorScatterOpConversion
276     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
277 public:
278   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
279 
280   LogicalResult
281   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
282                   ConversionPatternRewriter &rewriter) const override {
283     auto loc = scatter->getLoc();
284     MemRefType memRefType = scatter.getMemRefType();
285 
286     // Resolve alignment.
287     unsigned align;
288     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
289       return failure();
290 
291     // Resolve address.
292     Value ptrs;
293     VectorType vType = scatter.getVectorType();
294     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
295                                      adaptor.indices(), rewriter);
296     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
297                               adaptor.index_vec(), memRefType, vType, ptrs)))
298       return failure();
299 
300     // Replace with the scatter intrinsic.
301     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
302         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
303         rewriter.getI32IntegerAttr(align));
304     return success();
305   }
306 };
307 
308 /// Conversion pattern for a vector.expandload.
309 class VectorExpandLoadOpConversion
310     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
311 public:
312   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
313 
314   LogicalResult
315   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
316                   ConversionPatternRewriter &rewriter) const override {
317     auto loc = expand->getLoc();
318     MemRefType memRefType = expand.getMemRefType();
319 
320     // Resolve address.
321     auto vtype = typeConverter->convertType(expand.getVectorType());
322     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
323                                      adaptor.indices(), rewriter);
324 
325     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
326         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
327     return success();
328   }
329 };
330 
331 /// Conversion pattern for a vector.compressstore.
332 class VectorCompressStoreOpConversion
333     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
334 public:
335   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
336 
337   LogicalResult
338   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
339                   ConversionPatternRewriter &rewriter) const override {
340     auto loc = compress->getLoc();
341     MemRefType memRefType = compress.getMemRefType();
342 
343     // Resolve address.
344     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
345                                      adaptor.indices(), rewriter);
346 
347     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
348         compress, adaptor.valueToStore(), ptr, adaptor.mask());
349     return success();
350   }
351 };
352 
353 /// Conversion pattern for all vector reductions.
354 class VectorReductionOpConversion
355     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
356 public:
357   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
358                                        bool reassociateFPRed)
359       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
360         reassociateFPReductions(reassociateFPRed) {}
361 
362   LogicalResult
363   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
364                   ConversionPatternRewriter &rewriter) const override {
365     auto kind = reductionOp.kind();
366     Type eltType = reductionOp.dest().getType();
367     Type llvmType = typeConverter->convertType(eltType);
368     Value operand = adaptor.getOperands()[0];
369     if (eltType.isIntOrIndex()) {
370       // Integer reductions: add/mul/min/max/and/or/xor.
371       if (kind == "add")
372         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
373                                                              llvmType, operand);
374       else if (kind == "mul")
375         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
376                                                              llvmType, operand);
377       else if (kind == "minui")
378         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
379             reductionOp, llvmType, operand);
380       else if (kind == "minsi")
381         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
382             reductionOp, llvmType, operand);
383       else if (kind == "maxui")
384         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
385             reductionOp, llvmType, operand);
386       else if (kind == "maxsi")
387         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
388             reductionOp, llvmType, operand);
389       else if (kind == "and")
390         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
391                                                              llvmType, operand);
392       else if (kind == "or")
393         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
394                                                             llvmType, operand);
395       else if (kind == "xor")
396         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
397                                                              llvmType, operand);
398       else
399         return failure();
400       return success();
401     }
402 
403     if (!eltType.isa<FloatType>())
404       return failure();
405 
406     // Floating-point reductions: add/mul/min/max
407     if (kind == "add") {
408       // Optional accumulator (or zero).
409       Value acc = adaptor.getOperands().size() > 1
410                       ? adaptor.getOperands()[1]
411                       : rewriter.create<LLVM::ConstantOp>(
412                             reductionOp->getLoc(), llvmType,
413                             rewriter.getZeroAttr(eltType));
414       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
415           reductionOp, llvmType, acc, operand,
416           rewriter.getBoolAttr(reassociateFPReductions));
417     } else if (kind == "mul") {
418       // Optional accumulator (or one).
419       Value acc = adaptor.getOperands().size() > 1
420                       ? adaptor.getOperands()[1]
421                       : rewriter.create<LLVM::ConstantOp>(
422                             reductionOp->getLoc(), llvmType,
423                             rewriter.getFloatAttr(eltType, 1.0));
424       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
425           reductionOp, llvmType, acc, operand,
426           rewriter.getBoolAttr(reassociateFPReductions));
427     } else if (kind == "minf")
428       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
429       // NaNs/-0.0/+0.0 in the same way.
430       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
431                                                             llvmType, operand);
432     else if (kind == "maxf")
433       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
434       // NaNs/-0.0/+0.0 in the same way.
435       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
436                                                             llvmType, operand);
437     else
438       return failure();
439     return success();
440   }
441 
442 private:
443   const bool reassociateFPReductions;
444 };
445 
/// Conversion pattern for a vector.shuffle.
///
/// When the operands are 1-D and have exactly the same vector type, the op
/// maps directly onto an LLVM::ShuffleVectorOp with the same mask.
///
/// Otherwise the result is assembled one mask entry at a time, starting from
/// an LLVM::UndefOp of the converted result type. Each mask value indexes the
/// concatenation of v1 and v2 along the leading dimension: values below v1's
/// leading size select from v1, larger ones select from v2 after subtracting
/// that size. Each selected entry is extracted and inserted at the next
/// consecutive result position.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    // Mask values >= v1Dim refer to v2 (after subtracting v1Dim).
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, move the selected entries over one at a time.
    // The entry type is the element type of the converted result: an LLVM
    // array's element (rank > 1) or the 1-D vector's scalar element type.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    // Next result position to fill; advances once per mask entry.
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      // extractOne/insertOne pick extractelement/insertelement at rank 1 and
      // extractvalue/insertvalue on the aggregate at higher ranks.
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
504 
505 class VectorExtractElementOpConversion
506     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
507 public:
508   using ConvertOpToLLVMPattern<
509       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
510 
511   LogicalResult
512   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
513                   ConversionPatternRewriter &rewriter) const override {
514     auto vectorType = extractEltOp.getVectorType();
515     auto llvmType = typeConverter->convertType(vectorType.getElementType());
516 
517     // Bail if result type cannot be lowered.
518     if (!llvmType)
519       return failure();
520 
521     if (vectorType.getRank() == 0) {
522       Location loc = extractEltOp.getLoc();
523       auto idxType = rewriter.getIndexType();
524       auto zero = rewriter.create<LLVM::ConstantOp>(
525           loc, typeConverter->convertType(idxType),
526           rewriter.getIntegerAttr(idxType, 0));
527       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
528           extractEltOp, llvmType, adaptor.vector(), zero);
529       return success();
530     }
531 
532     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
533         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
534     return success();
535   }
536 };
537 
538 class VectorExtractOpConversion
539     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
540 public:
541   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
542 
543   LogicalResult
544   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
545                   ConversionPatternRewriter &rewriter) const override {
546     auto loc = extractOp->getLoc();
547     auto vectorType = extractOp.getVectorType();
548     auto resultType = extractOp.getResult().getType();
549     auto llvmResultType = typeConverter->convertType(resultType);
550     auto positionArrayAttr = extractOp.position();
551 
552     // Bail if result type cannot be lowered.
553     if (!llvmResultType)
554       return failure();
555 
556     // Extract entire vector. Should be handled by folder, but just to be safe.
557     if (positionArrayAttr.empty()) {
558       rewriter.replaceOp(extractOp, adaptor.vector());
559       return success();
560     }
561 
562     // One-shot extraction of vector from array (only requires extractvalue).
563     if (resultType.isa<VectorType>()) {
564       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
565           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
566       rewriter.replaceOp(extractOp, extracted);
567       return success();
568     }
569 
570     // Potential extraction of 1-D vector from array.
571     auto *context = extractOp->getContext();
572     Value extracted = adaptor.vector();
573     auto positionAttrs = positionArrayAttr.getValue();
574     if (positionAttrs.size() > 1) {
575       auto oneDVectorType = reducedVectorTypeBack(vectorType);
576       auto nMinusOnePositionAttrs =
577           ArrayAttr::get(context, positionAttrs.drop_back());
578       extracted = rewriter.create<LLVM::ExtractValueOp>(
579           loc, typeConverter->convertType(oneDVectorType), extracted,
580           nMinusOnePositionAttrs);
581     }
582 
583     // Remaining extraction of element from 1-D LLVM vector
584     auto position = positionAttrs.back().cast<IntegerAttr>();
585     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
586     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
587     extracted =
588         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
589     rewriter.replaceOp(extractOp, extracted);
590 
591     return success();
592   }
593 };
594 
595 /// Conversion pattern that turns a vector.fma on a 1-D vector
596 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
597 /// This does not match vectors of n >= 2 rank.
598 ///
599 /// Example:
600 /// ```
601 ///  vector.fma %a, %a, %a : vector<8xf32>
602 /// ```
603 /// is converted to:
604 /// ```
605 ///  llvm.intr.fmuladd %va, %va, %va:
606 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
607 ///    -> !llvm."<8 x f32>">
608 /// ```
609 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
610 public:
611   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
612 
613   LogicalResult
614   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
615                   ConversionPatternRewriter &rewriter) const override {
616     VectorType vType = fmaOp.getVectorType();
617     if (vType.getRank() != 1)
618       return failure();
619     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
620                                                  adaptor.rhs(), adaptor.acc());
621     return success();
622   }
623 };
624 
625 class VectorInsertElementOpConversion
626     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
627 public:
628   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
629 
630   LogicalResult
631   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
632                   ConversionPatternRewriter &rewriter) const override {
633     auto vectorType = insertEltOp.getDestVectorType();
634     auto llvmType = typeConverter->convertType(vectorType);
635 
636     // Bail if result type cannot be lowered.
637     if (!llvmType)
638       return failure();
639 
640     if (vectorType.getRank() == 0) {
641       Location loc = insertEltOp.getLoc();
642       auto idxType = rewriter.getIndexType();
643       auto zero = rewriter.create<LLVM::ConstantOp>(
644           loc, typeConverter->convertType(idxType),
645           rewriter.getIntegerAttr(idxType, 0));
646       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
647           insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
648       return success();
649     }
650 
651     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
652         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
653         adaptor.position());
654     return success();
655   }
656 };
657 
/// Conversion pattern for a vector.insert, operating on the converted (LLVM)
/// destination value:
///  - an empty position replaces the op by the source itself;
///  - a vector-typed source is written with a single llvm.insertvalue at the
///    full position;
///  - a scalar source is written in three steps. First, when the position has
///    more than one entry, the innermost 1-D vector is extracted with
///    llvm.extractvalue. The scalar is then placed with llvm.insertelement at
///    an i64 constant. Finally, that 1-D vector is written back into the
///    destination with llvm.insertvalue.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // With a single-entry position the destination itself is the 1-D vector,
    // so `extracted` and `oneDVectorType` keep their initial values.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    // The updated 1-D vector goes back at the same outer position it was
    // extracted from.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
727 
728 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
729 ///
730 /// Example:
731 /// ```
732 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
733 /// ```
734 /// is rewritten into:
735 /// ```
736 ///  %r = splat %f0: vector<2x4xf32>
737 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
738 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
739 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
740 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
741 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
742 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
743 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
744 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
745 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
746 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
747 ///  // %r3 holds the final value.
748 /// ```
749 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
750 public:
751   using OpRewritePattern<FMAOp>::OpRewritePattern;
752 
753   void initialize() {
754     // This pattern recursively unpacks one dimension at a time. The recursion
755     // bounded as the rank is strictly decreasing.
756     setHasBoundedRewriteRecursion();
757   }
758 
759   LogicalResult matchAndRewrite(FMAOp op,
760                                 PatternRewriter &rewriter) const override {
761     auto vType = op.getVectorType();
762     if (vType.getRank() < 2)
763       return failure();
764 
765     auto loc = op.getLoc();
766     auto elemType = vType.getElementType();
767     Value zero = rewriter.create<arith::ConstantOp>(
768         loc, elemType, rewriter.getZeroAttr(elemType));
769     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
770     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
771       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
772       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
773       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
774       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
775       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
776     }
777     rewriter.replaceOp(op, desc);
778     return success();
779   }
780 };
781 
782 /// Returns the strides if the memory underlying `memRefType` has a contiguous
783 /// static layout.
784 static llvm::Optional<SmallVector<int64_t, 4>>
785 computeContiguousStrides(MemRefType memRefType) {
786   int64_t offset;
787   SmallVector<int64_t, 4> strides;
788   if (failed(getStridesAndOffset(memRefType, strides, offset)))
789     return None;
790   if (!strides.empty() && strides.back() != 1)
791     return None;
792   // If no layout or identity layout, this is contiguous by definition.
793   if (memRefType.getLayout().isIdentity())
794     return strides;
795 
796   // Otherwise, we must determine contiguity form shapes. This can only ever
797   // work in static cases because MemRefType is underspecified to represent
798   // contiguous dynamic shapes in other ways than with just empty/identity
799   // layout.
800   auto sizes = memRefType.getShape();
801   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
802     if (ShapedType::isDynamic(sizes[index + 1]) ||
803         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
804         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
805       return None;
806     if (strides[index] != strides[index + 1] * sizes[index + 1])
807       return None;
808   }
809   return strides;
810 }
811 
812 class VectorTypeCastOpConversion
813     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
814 public:
815   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
816 
817   LogicalResult
818   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
819                   ConversionPatternRewriter &rewriter) const override {
820     auto loc = castOp->getLoc();
821     MemRefType sourceMemRefType =
822         castOp.getOperand().getType().cast<MemRefType>();
823     MemRefType targetMemRefType = castOp.getType();
824 
825     // Only static shape casts supported atm.
826     if (!sourceMemRefType.hasStaticShape() ||
827         !targetMemRefType.hasStaticShape())
828       return failure();
829 
830     auto llvmSourceDescriptorTy =
831         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
832     if (!llvmSourceDescriptorTy)
833       return failure();
834     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
835 
836     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
837                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
838     if (!llvmTargetDescriptorTy)
839       return failure();
840 
841     // Only contiguous source buffers supported atm.
842     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
843     if (!sourceStrides)
844       return failure();
845     auto targetStrides = computeContiguousStrides(targetMemRefType);
846     if (!targetStrides)
847       return failure();
848     // Only support static strides for now, regardless of contiguity.
849     if (llvm::any_of(*targetStrides, [](int64_t stride) {
850           return ShapedType::isDynamicStrideOrOffset(stride);
851         }))
852       return failure();
853 
854     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
855 
856     // Create descriptor.
857     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
858     Type llvmTargetElementTy = desc.getElementPtrType();
859     // Set allocated ptr.
860     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
861     allocated =
862         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
863     desc.setAllocatedPtr(rewriter, loc, allocated);
864     // Set aligned ptr.
865     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
866     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
867     desc.setAlignedPtr(rewriter, loc, ptr);
868     // Fill offset 0.
869     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
870     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
871     desc.setOffset(rewriter, loc, zero);
872 
873     // Fill size and stride descriptors in memref.
874     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
875       int64_t index = indexedSize.index();
876       auto sizeAttr =
877           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
878       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
879       desc.setSize(rewriter, loc, index, size);
880       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
881                                                 (*targetStrides)[index]);
882       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
883       desc.setStride(rewriter, loc, index, stride);
884     }
885 
886     rewriter.replaceOp(castOp, {desc});
887     return success();
888   }
889 };
890 
891 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
892 public:
893   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
894 
895   // Proof-of-concept lowering implementation that relies on a small
896   // runtime support library, which only needs to provide a few
897   // printing methods (single value for all data types, opening/closing
898   // bracket, comma, newline). The lowering fully unrolls a vector
899   // in terms of these elementary printing operations. The advantage
900   // of this approach is that the library can remain unaware of all
901   // low-level implementation details of vectors while still supporting
902   // output of any shaped and dimensioned vector. Due to full unrolling,
903   // this approach is less suited for very large vectors though.
904   //
905   // TODO: rely solely on libc in future? something else?
906   //
907   LogicalResult
908   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
909                   ConversionPatternRewriter &rewriter) const override {
910     Type printType = printOp.getPrintType();
911 
912     if (typeConverter->convertType(printType) == nullptr)
913       return failure();
914 
915     // Make sure element type has runtime support.
916     PrintConversion conversion = PrintConversion::None;
917     VectorType vectorType = printType.dyn_cast<VectorType>();
918     Type eltType = vectorType ? vectorType.getElementType() : printType;
919     Operation *printer;
920     if (eltType.isF32()) {
921       printer =
922           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
923     } else if (eltType.isF64()) {
924       printer =
925           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
926     } else if (eltType.isIndex()) {
927       printer =
928           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
929     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
930       // Integers need a zero or sign extension on the operand
931       // (depending on the source type) as well as a signed or
932       // unsigned print method. Up to 64-bit is supported.
933       unsigned width = intTy.getWidth();
934       if (intTy.isUnsigned()) {
935         if (width <= 64) {
936           if (width < 64)
937             conversion = PrintConversion::ZeroExt64;
938           printer = LLVM::lookupOrCreatePrintU64Fn(
939               printOp->getParentOfType<ModuleOp>());
940         } else {
941           return failure();
942         }
943       } else {
944         assert(intTy.isSignless() || intTy.isSigned());
945         if (width <= 64) {
946           // Note that we *always* zero extend booleans (1-bit integers),
947           // so that true/false is printed as 1/0 rather than -1/0.
948           if (width == 1)
949             conversion = PrintConversion::ZeroExt64;
950           else if (width < 64)
951             conversion = PrintConversion::SignExt64;
952           printer = LLVM::lookupOrCreatePrintI64Fn(
953               printOp->getParentOfType<ModuleOp>());
954         } else {
955           return failure();
956         }
957       }
958     } else {
959       return failure();
960     }
961 
962     // Unroll vector into elementary print calls.
963     int64_t rank = vectorType ? vectorType.getRank() : 0;
964     Type type = vectorType ? vectorType : eltType;
965     emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
966               conversion);
967     emitCall(rewriter, printOp->getLoc(),
968              LLVM::lookupOrCreatePrintNewlineFn(
969                  printOp->getParentOfType<ModuleOp>()));
970     rewriter.eraseOp(printOp);
971     return success();
972   }
973 
974 private:
975   enum class PrintConversion {
976     // clang-format off
977     None,
978     ZeroExt64,
979     SignExt64
980     // clang-format on
981   };
982 
983   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
984                  Value value, Type type, Operation *printer, int64_t rank,
985                  PrintConversion conversion) const {
986     VectorType vectorType = type.dyn_cast<VectorType>();
987     Location loc = op->getLoc();
988     if (!vectorType) {
989       assert(rank == 0 && "The scalar case expects rank == 0");
990       switch (conversion) {
991       case PrintConversion::ZeroExt64:
992         value = rewriter.create<arith::ExtUIOp>(
993             loc, value, IntegerType::get(rewriter.getContext(), 64));
994         break;
995       case PrintConversion::SignExt64:
996         value = rewriter.create<arith::ExtSIOp>(
997             loc, value, IntegerType::get(rewriter.getContext(), 64));
998         break;
999       case PrintConversion::None:
1000         break;
1001       }
1002       emitCall(rewriter, loc, printer, value);
1003       return;
1004     }
1005 
1006     emitCall(rewriter, loc,
1007              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1008     Operation *printComma =
1009         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1010 
1011     if (rank <= 1) {
1012       auto reducedType = vectorType.getElementType();
1013       auto llvmType = typeConverter->convertType(reducedType);
1014       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1015       for (int64_t d = 0; d < dim; ++d) {
1016         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1017                                      llvmType, /*rank=*/0, /*pos=*/d);
1018         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1019                   conversion);
1020         if (d != dim - 1)
1021           emitCall(rewriter, loc, printComma);
1022       }
1023       emitCall(
1024           rewriter, loc,
1025           LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1026       return;
1027     }
1028 
1029     int64_t dim = vectorType.getDimSize(0);
1030     for (int64_t d = 0; d < dim; ++d) {
1031       auto reducedType = reducedVectorTypeFront(vectorType);
1032       auto llvmType = typeConverter->convertType(reducedType);
1033       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1034                                    llvmType, rank, d);
1035       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1036                 conversion);
1037       if (d != dim - 1)
1038         emitCall(rewriter, loc, printComma);
1039     }
1040     emitCall(rewriter, loc,
1041              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1042   }
1043 
1044   // Helper to emit a call.
1045   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1046                        Operation *ref, ValueRange params = ValueRange()) {
1047     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1048                                   params);
1049   }
1050 };
1051 
1052 } // namespace
1053 
1054 /// Populate the given list with patterns that convert from Vector to LLVM.
1055 void mlir::populateVectorToLLVMConversionPatterns(
1056     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1057     bool reassociateFPReductions) {
1058   MLIRContext *ctx = converter.getDialect()->getContext();
1059   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1060   populateVectorInsertExtractStridedSliceTransforms(patterns);
1061   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1062   patterns
1063       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1064            VectorExtractElementOpConversion, VectorExtractOpConversion,
1065            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1066            VectorInsertOpConversion, VectorPrintOpConversion,
1067            VectorTypeCastOpConversion,
1068            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1069            VectorLoadStoreConversion<vector::MaskedLoadOp,
1070                                      vector::MaskedLoadOpAdaptor>,
1071            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1072            VectorLoadStoreConversion<vector::MaskedStoreOp,
1073                                      vector::MaskedStoreOpAdaptor>,
1074            VectorGatherOpConversion, VectorScatterOpConversion,
1075            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1076           converter);
1077   // Transfer ops with rank > 1 are handled by VectorToSCF.
1078   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1079 }
1080 
1081 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1082     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1083   patterns.add<VectorMatmulOpConversion>(converter);
1084   patterns.add<VectorFlatTransposeOpConversion>(converter);
1085 }
1086