1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by one rank at front.
38 static VectorType reducedVectorTypeFront(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
41 }
42 
43 // Helper to reduce vector type by *all* but one rank at back.
44 static VectorType reducedVectorTypeBack(VectorType tp) {
45   assert((tp.getRank() > 1) && "unlowerable vector type");
46   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
47 }
48 
49 // Helper that picks the proper sequence for inserting.
50 static Value insertOne(ConversionPatternRewriter &rewriter,
51                        LLVMTypeConverter &typeConverter, Location loc,
52                        Value val1, Value val2, Type llvmType, int64_t rank,
53                        int64_t pos) {
54   if (rank == 1) {
55     auto idxType = rewriter.getIndexType();
56     auto constant = rewriter.create<LLVM::ConstantOp>(
57         loc, typeConverter.convertType(idxType),
58         rewriter.getIntegerAttr(idxType, pos));
59     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
60                                                   constant);
61   }
62   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
63                                               rewriter.getI64ArrayAttr(pos));
64 }
65 
66 // Helper that picks the proper sequence for inserting.
67 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
68                        Value into, int64_t offset) {
69   auto vectorType = into.getType().cast<VectorType>();
70   if (vectorType.getRank() > 1)
71     return rewriter.create<InsertOp>(loc, from, into, offset);
72   return rewriter.create<vector::InsertElementOp>(
73       loc, vectorType, from, into,
74       rewriter.create<ConstantIndexOp>(loc, offset));
75 }
76 
77 // Helper that picks the proper sequence for extracting.
78 static Value extractOne(ConversionPatternRewriter &rewriter,
79                         LLVMTypeConverter &typeConverter, Location loc,
80                         Value val, Type llvmType, int64_t rank, int64_t pos) {
81   if (rank == 1) {
82     auto idxType = rewriter.getIndexType();
83     auto constant = rewriter.create<LLVM::ConstantOp>(
84         loc, typeConverter.convertType(idxType),
85         rewriter.getIntegerAttr(idxType, pos));
86     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
87                                                    constant);
88   }
89   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
90                                                rewriter.getI64ArrayAttr(pos));
91 }
92 
93 // Helper that picks the proper sequence for extracting.
94 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
95                         int64_t offset) {
96   auto vectorType = vector.getType().cast<VectorType>();
97   if (vectorType.getRank() > 1)
98     return rewriter.create<ExtractOp>(loc, vector, offset);
99   return rewriter.create<vector::ExtractElementOp>(
100       loc, vectorType.getElementType(), vector,
101       rewriter.create<ConstantIndexOp>(loc, offset));
102 }
103 
104 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
105 // TODO: Better support for attribute subtype forwarding + slicing.
106 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
107                                               unsigned dropFront = 0,
108                                               unsigned dropBack = 0) {
109   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
110   auto range = arrayAttr.getAsRange<IntegerAttr>();
111   SmallVector<int64_t, 4> res;
112   res.reserve(arrayAttr.size() - dropFront - dropBack);
113   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
114        it != eit; ++it)
115     res.push_back((*it).getValue().getSExtValue());
116   return res;
117 }
118 
// Helper that returns data layout alignment of an operation with memref.
// On success, sets `align` to the preferred alignment of the lowered memref
// element type according to the target module's LLVM data layout; fails when
// the element type cannot be converted to an LLVM type.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // Query the LLVM data layout for the preferred alignment of the lowered
  // element type.
  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
  align = dataLayout.getPrefTypeAlignment(
      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
  return success();
}
133 
// Helper that returns vector of pointers given a base and an index vector.
// On success, sets `ptrs` to a 1-D vector of `vType.getDimSize(0)` pointers,
// one GEP of `base` per lane of `indices`. Fails for memrefs that are not
// flat: non-unit stride, non-zero offset, or non-default memory space.
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                             LLVMTypeConverter &typeConverter, Location loc,
                             Value memref, Value indices, MemRefType memRefType,
                             VectorType vType, Type iType, Value &ptrs) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();

  // Base pointer.
  MemRefDescriptor memRefDescriptor(memref);
  Value base = memRefDescriptor.alignedPtr(rewriter, loc);

  // Create a vector of pointers from base and indices.
  //
  // TODO: this step serializes the address computations unfortunately,
  //       ideally we would like to add splat(base) + index_vector
  //       in SIMD form, but this does not match well with current
  //       constraints of the standard and vector dialect....
  //
  int64_t size = vType.getDimSize(0);
  auto pType = memRefDescriptor.getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
  auto idxType = typeConverter.convertType(iType);
  // Start from undef and fill one GEP'ed pointer per lane.
  ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
  for (int64_t i = 0; i < size; i++) {
    Value off =
        extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
    Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
    ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
  }
  return success();
}
174 
// Replaces a vector.transfer_read with a plain llvm.load from `dataPtr`,
// using the data layout alignment of the memref element type. `loc` and
// `operands` are unused here; they keep the signature symmetric with the
// TransferWriteOp overload.
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}
186 
// Replaces a vector.transfer_read with an llvm.intr.masked.load from
// `dataPtr`. Disabled lanes of `mask` are filled from a splat of the
// transfer's padding value.
static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  // Build the pass-through vector by splatting the padding value, then cast
  // it into the LLVM dialect so it can feed the intrinsic.
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  // Bail if the vector type cannot be lowered.
  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}
210 
// Replaces a vector.transfer_write with a plain llvm.store of the adapted
// vector operand to `dataPtr`, using the data layout alignment of the memref
// element type. `loc` is unused; it keeps the signature symmetric with the
// TransferReadOp overload.
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}
224 
// Replaces a vector.transfer_write with an llvm.intr.masked.store of the
// adapted vector operand to `dataPtr`; only lanes enabled in `mask` are
// written.
static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}
240 
241 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
242                                                   ArrayRef<Value> operands) {
243   return TransferReadOpAdaptor(operands);
244 }
245 
246 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
247                                                    ArrayRef<Value> operands) {
248   return TransferWriteOpAdaptor(operands);
249 }
250 
251 namespace {
252 
253 /// Conversion pattern for a vector.matrix_multiply.
254 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
255 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
256 public:
257   explicit VectorMatmulOpConversion(MLIRContext *context,
258                                     LLVMTypeConverter &typeConverter)
259       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
260                              typeConverter) {}
261 
262   LogicalResult
263   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
264                   ConversionPatternRewriter &rewriter) const override {
265     auto matmulOp = cast<vector::MatmulOp>(op);
266     auto adaptor = vector::MatmulOpAdaptor(operands);
267     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
268         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
269         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
270         matmulOp.rhs_columns());
271     return success();
272   }
273 };
274 
275 /// Conversion pattern for a vector.flat_transpose.
276 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
277 class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
278 public:
279   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
280                                            LLVMTypeConverter &typeConverter)
281       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
282                              context, typeConverter) {}
283 
284   LogicalResult
285   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
286                   ConversionPatternRewriter &rewriter) const override {
287     auto transOp = cast<vector::FlatTransposeOp>(op);
288     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
289     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
290         transOp, typeConverter.convertType(transOp.res().getType()),
291         adaptor.matrix(), transOp.rows(), transOp.columns());
292     return success();
293   }
294 };
295 
/// Conversion pattern for a vector.gather.
/// Lowers to llvm.intr.masked.gather over a per-lane vector of pointers
/// computed from the base memref and the index vector.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs. Fails (and the pattern bails) for non-flat memrefs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
                              adaptor.indices(), gather.getMemRefType(), vType,
                              iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic. The pass-through operand is
    // optional on the source op, so forward an empty range when absent.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
                                                          : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
334 
/// Conversion pattern for a vector.scatter.
/// Lowers to llvm.intr.masked.scatter over a per-lane vector of pointers
/// computed from the base memref and the index vector.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs. Fails (and the pattern bails) for non-flat memrefs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
                              adaptor.indices(), scatter.getMemRefType(), vType,
                              iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
371 
/// Conversion pattern for all vector reductions.
/// Maps the reduction `kind` string onto the corresponding LLVM
/// `experimental.vector.reduce` intrinsic. Only signless i32/i64 and f32/f64
/// element types are handled; anything else fails the match.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    // The scalar result type selects the intrinsic family.
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        // fmin/fmax take no accumulator and no reassociation flag.
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  // Forwarded to the fadd/fmul intrinsics as their reassociation flag.
  const bool reassociateFPReductions;
};
453 
/// Conversion pattern for a vector.shuffle. Rank-1 shuffles of identically
/// typed operands map directly onto llvm.shufflevector; all other cases are
/// expanded into per-element extract/insert pairs driven by the mask.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries >= v1Dim index into the second operand (rebased by v1Dim).
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};
511 
512 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
513 public:
514   explicit VectorExtractElementOpConversion(MLIRContext *context,
515                                             LLVMTypeConverter &typeConverter)
516       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
517                              context, typeConverter) {}
518 
519   LogicalResult
520   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
521                   ConversionPatternRewriter &rewriter) const override {
522     auto adaptor = vector::ExtractElementOpAdaptor(operands);
523     auto extractEltOp = cast<vector::ExtractElementOp>(op);
524     auto vectorType = extractEltOp.getVectorType();
525     auto llvmType = typeConverter.convertType(vectorType.getElementType());
526 
527     // Bail if result type cannot be lowered.
528     if (!llvmType)
529       return failure();
530 
531     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
532         op, llvmType, adaptor.vector(), adaptor.position());
533     return success();
534   }
535 };
536 
/// Conversion pattern for a vector.extract. Vector results come straight
/// from llvm.extractvalue; scalar results first peel off all but the last
/// position with an extractvalue, then use llvm.extractelement for the
/// innermost 1-D access.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // Use all but the last position to reach the innermost 1-D vector.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};
591 
592 /// Conversion pattern that turns a vector.fma on a 1-D vector
593 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
594 /// This does not match vectors of n >= 2 rank.
595 ///
596 /// Example:
597 /// ```
598 ///  vector.fma %a, %a, %a : vector<8xf32>
599 /// ```
600 /// is converted to:
601 /// ```
602 ///  llvm.intr.fmuladd %va, %va, %va:
603 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
604 ///    -> !llvm<"<8 x float>">
605 /// ```
606 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
607 public:
608   explicit VectorFMAOp1DConversion(MLIRContext *context,
609                                    LLVMTypeConverter &typeConverter)
610       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
611                              typeConverter) {}
612 
613   LogicalResult
614   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
615                   ConversionPatternRewriter &rewriter) const override {
616     auto adaptor = vector::FMAOpAdaptor(operands);
617     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
618     VectorType vType = fmaOp.getVectorType();
619     if (vType.getRank() != 1)
620       return failure();
621     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
622                                                  adaptor.rhs(), adaptor.acc());
623     return success();
624   }
625 };
626 
627 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
628 public:
629   explicit VectorInsertElementOpConversion(MLIRContext *context,
630                                            LLVMTypeConverter &typeConverter)
631       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
632                              context, typeConverter) {}
633 
634   LogicalResult
635   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
636                   ConversionPatternRewriter &rewriter) const override {
637     auto adaptor = vector::InsertElementOpAdaptor(operands);
638     auto insertEltOp = cast<vector::InsertElementOp>(op);
639     auto vectorType = insertEltOp.getDestVectorType();
640     auto llvmType = typeConverter.convertType(vectorType);
641 
642     // Bail if result type cannot be lowered.
643     if (!llvmType)
644       return failure();
645 
646     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
647         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
648     return success();
649   }
650 };
651 
/// Conversion pattern for a vector.insert. Vector sources go straight into
/// llvm.insertvalue; scalar sources first extract the innermost 1-D vector,
/// insert the element with llvm.insertelement, then (for nested positions)
/// put the updated 1-D vector back with llvm.insertvalue.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // Use all but the last position to reach the innermost 1-D vector.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};
718 
719 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
720 ///
721 /// Example:
722 /// ```
723 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
724 /// ```
725 /// is rewritten into:
726 /// ```
///  %r0 = splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r1 = vector.insert %vd, %r0[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r2 = vector.insert %vd2, %r1[1] : vector<4xf32> into vector<2x4xf32>
///  // %r2 holds the final value.
739 /// ```
740 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
741 public:
742   using OpRewritePattern<FMAOp>::OpRewritePattern;
743 
744   LogicalResult matchAndRewrite(FMAOp op,
745                                 PatternRewriter &rewriter) const override {
746     auto vType = op.getVectorType();
747     if (vType.getRank() < 2)
748       return failure();
749 
750     auto loc = op.getLoc();
751     auto elemType = vType.getElementType();
752     Value zero = rewriter.create<ConstantOp>(loc, elemType,
753                                              rewriter.getZeroAttr(elemType));
754     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
755     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
756       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
757       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
758       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
759       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
760       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
761     }
762     rewriter.replaceOp(op, desc);
763     return success();
764   }
765 };
766 
767 // When ranks are different, InsertStridedSlice needs to extract a properly
768 // ranked vector from the destination vector into which to insert. This pattern
769 // only takes care of this part and forwards the rest of the conversion to
770 // another pattern that converts InsertStridedSlice for operands of the same
771 // rank.
772 //
773 // RewritePattern for InsertStridedSliceOp where source and destination vectors
774 // have different ranks. In this case:
775 //   1. the proper subvector is extracted from the destination vector
776 //   2. a new InsertStridedSlice op is created to insert the source in the
777 //   destination subvector
778 //   3. the destination subvector is inserted back in the proper place
779 //   4. the op is replaced by the result of step 3.
780 // The new InsertStridedSlice from step 2. will be picked up by a
781 // `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to peel off without offsets; bail.
    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Equal ranks are handled by
    // `VectorInsertStridedSliceOpSameRankRewritePattern`.
    if (rankDiff == 0)
      return failure();

    // Rank of the subvector that matches the source rank.
    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it. The leading `rankDiff` offsets select the subvector; the second
    // getI64SubArray argument drops the trailing `rankRest` offsets (the
    // original comment mislabeled it as a second `dropFront`).
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    // Put the updated subvector back at the leading offsets.
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
821 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is lowered slice by slice along the most major
// dimension:
//   1. the subvector at the current position is extracted from the source
//   2. if it is itself a vector, the matching subvector is extracted from the
//      destination and a new, rank-reduced InsertStridedSliceOp combines them
//   3. the combined subvector is inserted at the strided offset in the result.
// The rank-reduced InsertStridedSliceOp from step 2. is again handled by this
// pattern; the recursion terminates because the rank strictly decreases.
/// Progressive lowering for InsertStridedSliceOp with matching source and
/// destination ranks: peels the most major dimension per application and
/// recurses via rank-reduced InsertStridedSliceOps until the element level.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to peel off without offsets; bail.
    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Mismatched ranks are handled by
    // `VectorInsertStridedSliceOpDifferentRankRewritePattern`.
    if (rankDiff != 0)
      return failure();

    // Identical types: the insertion replaces the whole destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Offset / size / stride along the most major dimension.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    // `idx` walks the source positions, `off` the (strided) destination ones.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
889 
890 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
891 public:
892   explicit VectorTypeCastOpConversion(MLIRContext *context,
893                                       LLVMTypeConverter &typeConverter)
894       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
895                              typeConverter) {}
896 
897   LogicalResult
898   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
899                   ConversionPatternRewriter &rewriter) const override {
900     auto loc = op->getLoc();
901     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
902     MemRefType sourceMemRefType =
903         castOp.getOperand().getType().cast<MemRefType>();
904     MemRefType targetMemRefType =
905         castOp.getResult().getType().cast<MemRefType>();
906 
907     // Only static shape casts supported atm.
908     if (!sourceMemRefType.hasStaticShape() ||
909         !targetMemRefType.hasStaticShape())
910       return failure();
911 
912     auto llvmSourceDescriptorTy =
913         operands[0].getType().dyn_cast<LLVM::LLVMType>();
914     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
915       return failure();
916     MemRefDescriptor sourceMemRef(operands[0]);
917 
918     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
919                                       .dyn_cast_or_null<LLVM::LLVMType>();
920     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
921       return failure();
922 
923     int64_t offset;
924     SmallVector<int64_t, 4> strides;
925     auto successStrides =
926         getStridesAndOffset(sourceMemRefType, strides, offset);
927     bool isContiguous = (strides.back() == 1);
928     if (isContiguous) {
929       auto sizes = sourceMemRefType.getShape();
930       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
931         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
932           isContiguous = false;
933           break;
934         }
935       }
936     }
937     // Only contiguous source tensors supported atm.
938     if (failed(successStrides) || !isContiguous)
939       return failure();
940 
941     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
942 
943     // Create descriptor.
944     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
945     Type llvmTargetElementTy = desc.getElementType();
946     // Set allocated ptr.
947     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
948     allocated =
949         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
950     desc.setAllocatedPtr(rewriter, loc, allocated);
951     // Set aligned ptr.
952     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
953     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
954     desc.setAlignedPtr(rewriter, loc, ptr);
955     // Fill offset 0.
956     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
957     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
958     desc.setOffset(rewriter, loc, zero);
959 
960     // Fill size and stride descriptors in memref.
961     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
962       int64_t index = indexedSize.index();
963       auto sizeAttr =
964           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
965       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
966       desc.setSize(rewriter, loc, index, size);
967       auto strideAttr =
968           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
969       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
970       desc.setStride(rewriter, loc, index, stride);
971     }
972 
973     rewriter.replaceOp(op, {desc});
974     return success();
975   }
976 };
977 
978 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
979 /// sequence of:
980 /// 1. Bitcast or addrspacecast to vector form.
981 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
982 /// 3. Create a mask where offsetVector is compared against memref upper bound.
983 /// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // Only the minor identity permutation map is supported.
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // Unmasked dimension: a regular vector load/store suffices, skip the
    // mask computation below entirely.
    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    // Cast from the standard dialect constant into the LLVM dialect type.
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};
1071 
/// Lowers vector.print into calls to a small runtime print library (declared
/// lazily in the enclosing module).
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    // Bail if the printed type cannot be lowered.
    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support (currently just Float/Double).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    // i1 is routed through print_i32 (widened in emitRanks below).
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively emits print calls: at rank 0 a single scalar is printed
  // (with i1 widened to i32 first); at higher ranks each element of the
  // outermost dimension is printed between brackets and commas.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() ==
          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Peel one rank off the front; at rank 1 the nested value is a scalar.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Declares a void function with the given name/params at module scope
  // if it does not exist yet.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};
1217 
1218 /// Progressive lowering of ExtractStridedSliceOp to either:
1219 ///   1. extractelement + insertelement for the 1-D case
1220 ///   2. extract + optional strided_slice + insert for the n-D case.
1221 class VectorStridedSliceOpConversion
1222     : public OpRewritePattern<ExtractStridedSliceOp> {
1223 public:
1224   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1225 
1226   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1227                                 PatternRewriter &rewriter) const override {
1228     auto dstType = op.getResult().getType().cast<VectorType>();
1229 
1230     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1231 
1232     int64_t offset =
1233         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1234     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1235     int64_t stride =
1236         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1237 
1238     auto loc = op.getLoc();
1239     auto elemType = dstType.getElementType();
1240     assert(elemType.isSignlessIntOrIndexOrFloat());
1241     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1242                                              rewriter.getZeroAttr(elemType));
1243     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1244     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1245          off += stride, ++idx) {
1246       Value extracted = extractOne(rewriter, loc, op.vector(), off);
1247       if (op.offsets().getValue().size() > 1) {
1248         extracted = rewriter.create<ExtractStridedSliceOp>(
1249             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
1250             getI64SubArray(op.sizes(), /* dropFront=*/1),
1251             getI64SubArray(op.strides(), /* dropFront=*/1));
1252       }
1253       res = insertOne(rewriter, loc, extracted, res, idx);
1254     }
1255     rewriter.replaceOp(op, {res});
1256     return success();
1257   }
1258   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1259   /// bounded as the rank is strictly decreasing.
1260   bool hasBoundedRewriteRecursion() const final { return true; }
1261 };
1262 
1263 } // namespace
1264 
1265 /// Populate the given list with patterns that convert from Vector to LLVM.
1266 void mlir::populateVectorToLLVMConversionPatterns(
1267     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1268     bool reassociateFPReductions) {
1269   MLIRContext *ctx = converter.getDialect()->getContext();
1270   // clang-format off
1271   patterns.insert<VectorFMAOpNDRewritePattern,
1272                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1273                   VectorInsertStridedSliceOpSameRankRewritePattern,
1274                   VectorStridedSliceOpConversion>(ctx);
1275   patterns.insert<VectorReductionOpConversion>(
1276       ctx, converter, reassociateFPReductions);
1277   patterns
1278       .insert<VectorShuffleOpConversion,
1279               VectorExtractElementOpConversion,
1280               VectorExtractOpConversion,
1281               VectorFMAOp1DConversion,
1282               VectorInsertElementOpConversion,
1283               VectorInsertOpConversion,
1284               VectorPrintOpConversion,
1285               VectorTransferConversion<TransferReadOp>,
1286               VectorTransferConversion<TransferWriteOp>,
1287               VectorTypeCastOpConversion,
1288               VectorGatherOpConversion,
1289               VectorScatterOpConversion>(ctx, converter);
1290   // clang-format on
1291 }
1292 
1293 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1294     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1295   MLIRContext *ctx = converter.getDialect()->getContext();
1296   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1297   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1298 }
1299 
1300 namespace {
/// Pass lowering the vector dialect (plus standard ops) to the LLVM dialect.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    // Forward the option controlling FP reduction reassociation to the
    // generated pass member.
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
1308 } // namespace
1309 
1310 void LowerVectorToLLVMPass::runOnOperation() {
1311   // Perform progressive lowering of operations on slices and
1312   // all contraction operations. Also applies folding and DCE.
1313   {
1314     OwningRewritePatternList patterns;
1315     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1316     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1317     populateVectorContractLoweringPatterns(patterns, &getContext());
1318     applyPatternsAndFoldGreedily(getOperation(), patterns);
1319   }
1320 
1321   // Convert to the LLVM IR dialect.
1322   LLVMTypeConverter converter(&getContext());
1323   OwningRewritePatternList patterns;
1324   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1325   populateVectorToLLVMConversionPatterns(converter, patterns,
1326                                          reassociateFPReductions);
1327   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1328   populateStdToLLVMConversionPatterns(converter, patterns);
1329 
1330   LLVMConversionTarget target(getContext());
1331   if (failed(applyPartialConversion(getOperation(), target, patterns))) {
1332     signalPassFailure();
1333   }
1334 }
1335 
/// Factory for the vector-to-LLVM lowering pass, configured with `options`
/// (currently only the FP reduction reassociation flag).
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
1340