//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
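// E.g. vector<2x3x4xf32> reduces to vector<3x4xf32>.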
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
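// E.g. vector<2x3x4xf32> reduces to vector<4xf32>.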
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
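/// E.g. (illustrative), multiplying a 4x16 matrix by a 16x3 matrix, both
/// flattened into 1-D vectors:
/// ```
///   %C = vector.matrix_multiply %A, %B
///     { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///     : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```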
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
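/// E.g. (illustrative), transposing a flattened 4x4 matrix:
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///     : vector<16xf32> -> vector<16xf32>
/// ```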
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
                                                          : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
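/// E.g. (illustrative), summing all elements of a 1-D vector:
/// ```
///   %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```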
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
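//
// E.g. (illustrative), for
//   %2 = vector.insert_strided_slice %0, %1
//     {offsets = [0, 0, 2], strides = [1, 1]}
//     : vector<2x4xf32> into vector<16x4x8xf32>
// the rank-2 subvector at offset [0] is extracted from %1, the source %0 is
// inserted into it at offsets [0, 2], and the result is inserted back at [0].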
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case the op is unrolled along its most major
// dimension:
//   1. each (n-1)-D slice of the source is extracted
//   2. the matching (n-1)-D subvector is extracted from the destination and a
//   new, lower-rank InsertStridedSlice op inserts the source slice into it
//   3. the result is inserted back at the proper offset of the destination
// The lower-rank InsertStridedSlice ops created in step 2. are recursively
// rewritten by this same pattern until the element level is reached.
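//
// E.g. (illustrative), for
//   %2 = vector.insert_strided_slice %0, %1
//     {offsets = [2, 2], strides = [1, 1]}
//     : vector<2x4xf32> into vector<16x8xf32>
// the two rank-1 slices of %0 are inserted at destination offsets 2 and 3,
// each through a recursive rank-1 insert_strided_slice at offset [2].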
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

/// Returns true if the memory underlying `memRefType` has a contiguous layout.
/// Strides are written to `strides`.
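/// E.g. a memref<4x8xf32> with the default identity layout has strides
/// [8, 1] and is contiguous; strides of [16, 1] for the same shape are not.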
static bool isContiguous(MemRefType memRefType,
                         SmallVectorImpl<int64_t> &strides) {
  int64_t offset;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  bool isContiguous = (strides.back() == 1);
  if (isContiguous) {
    auto sizes = memRefType.getShape();
    // Check every adjacent stride pair, including the last one.
    for (int index = 0, e = strides.size() - 1; index < e; ++index) {
      if (strides[index] != strides[index + 1] * sizes[index + 1]) {
        isContiguous = false;
        break;
      }
    }
  }
  return succeeded(successStrides) && isContiguous;
}

class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source memrefs supported atm.
    SmallVector<int64_t, 4> strides;
    if (!isContiguous(sourceMemRefType, strides))
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
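///
/// E.g. (sketch) a masked 1-D read
/// ```
///   %f = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// conceptually becomes a bitcast of the data pointer to !llvm<"<8 x float>*">,
/// a comparison of [ %i .. %i + 7 ] against the memref dimension to form the
/// mask, and an llvm.intr.masked.load with a splat of %pad as pass-through.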
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();
    // Only contiguous source memrefs supported atm.
    SmallVector<int64_t, 4> strides;
    if (!isContiguous(xferOp.getMemRefType(), strides))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter);
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};

class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // into these elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output
  // of vectors of any shape and rank. Because of the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
  // TODO: rely solely on libc in future? something else?
  //
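  // As an illustration (a sketch; SSA names are made up), printing a value
  // %v : vector<2xf32> unrolls into the call sequence:
  //   llvm.call @print_open()    : () -> ()
  //   llvm.call @print_f32(%e0)  : (!llvm.float) -> ()
  //   llvm.call @print_comma()   : () -> ()
  //   llvm.call @print_f32(%e1)  : (!llvm.float) -> ()
  //   llvm.call @print_close()   : () -> ()
  //   llvm.call @print_newline() : () -> ()
  // where %e0 and %e1 are the two extracted elements.
  //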
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support; currently
    // i1, i32, i64, f32, and f64 are supported.
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        // Note that the value already carries an LLVM type at this point,
        // so LLVM dialect constant and select ops must be used rather
        // than their standard dialect counterparts.
        auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
        Value trueVal = rewriter.create<LLVM::ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        Value falseVal = rewriter.create<LLVM::ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<LLVM::SelectOp>(loc, i32Type, value, trueVal,
                                                falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(
            LLVM::LLVMType::getVoidTy(op->getContext()), params,
            /*isVarArg=*/false));
  }
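  // Note: on first use this declares the runtime symbol at module scope,
  // e.g. for getPrintFloat below it creates (sketch):
  //   llvm.func @print_f32(!llvm.float)
  // and subsequent lookups return the existing declaration.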

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    return getPrint(op, "print_i32",
                    LLVM::LLVMType::getInt32Ty(op->getContext()));
  }
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "print_i64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "print_f32",
                    LLVM::LLVMType::getFloatTy(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "print_f64",
                    LLVM::LLVMType::getDoubleTy(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. a direct shuffle when the slice has a single offset, or
///   2. extract + lower-rank strided_slice + insert for the n-D case.
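/// For example (a sketch, with attributes abbreviated), the single-offset case
///   %1 = vector.extract_strided_slice %0
///        {offsets = [2], sizes = [2], strides = [1]}
///        : vector<4xf32> to vector<2xf32>
/// lowers to a single shuffle:
///   %1 = vector.shuffle %0, %0 [2, 3] : vector<4xf32>, vector<4xf32>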
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat() &&
           "expected signless int, index, or float element type");

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower-ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive ExtractStridedSliceOps, but the recursion
  /// is bounded because the rank is strictly decreasing.
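  /// For example (sketch): slicing a vector<4x8xf32> with two offsets recurses
  /// once into vector<8xf32> sub-slices, where the single-offset shuffle path
  /// applies and the recursion terminates.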
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
    signalPassFailure();
  }
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
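
// Example registration (a sketch; `pm` is a caller-owned PassManager):
//   pm.addPass(mlir::createConvertVectorToLLVMPass(options));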