1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by one rank at front.
38 static VectorType reducedVectorTypeFront(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
41 }
42 
43 // Helper to reduce vector type by *all* but one rank at back.
44 static VectorType reducedVectorTypeBack(VectorType tp) {
45   assert((tp.getRank() > 1) && "unlowerable vector type");
46   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
47 }
48 
49 // Helper that picks the proper sequence for inserting.
50 static Value insertOne(ConversionPatternRewriter &rewriter,
51                        LLVMTypeConverter &typeConverter, Location loc,
52                        Value val1, Value val2, Type llvmType, int64_t rank,
53                        int64_t pos) {
54   if (rank == 1) {
55     auto idxType = rewriter.getIndexType();
56     auto constant = rewriter.create<LLVM::ConstantOp>(
57         loc, typeConverter.convertType(idxType),
58         rewriter.getIntegerAttr(idxType, pos));
59     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
60                                                   constant);
61   }
62   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
63                                               rewriter.getI64ArrayAttr(pos));
64 }
65 
66 // Helper that picks the proper sequence for inserting.
67 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
68                        Value into, int64_t offset) {
69   auto vectorType = into.getType().cast<VectorType>();
70   if (vectorType.getRank() > 1)
71     return rewriter.create<InsertOp>(loc, from, into, offset);
72   return rewriter.create<vector::InsertElementOp>(
73       loc, vectorType, from, into,
74       rewriter.create<ConstantIndexOp>(loc, offset));
75 }
76 
77 // Helper that picks the proper sequence for extracting.
78 static Value extractOne(ConversionPatternRewriter &rewriter,
79                         LLVMTypeConverter &typeConverter, Location loc,
80                         Value val, Type llvmType, int64_t rank, int64_t pos) {
81   if (rank == 1) {
82     auto idxType = rewriter.getIndexType();
83     auto constant = rewriter.create<LLVM::ConstantOp>(
84         loc, typeConverter.convertType(idxType),
85         rewriter.getIntegerAttr(idxType, pos));
86     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
87                                                    constant);
88   }
89   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
90                                                rewriter.getI64ArrayAttr(pos));
91 }
92 
93 // Helper that picks the proper sequence for extracting.
94 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
95                         int64_t offset) {
96   auto vectorType = vector.getType().cast<VectorType>();
97   if (vectorType.getRank() > 1)
98     return rewriter.create<ExtractOp>(loc, vector, offset);
99   return rewriter.create<vector::ExtractElementOp>(
100       loc, vectorType.getElementType(), vector,
101       rewriter.create<ConstantIndexOp>(loc, offset));
102 }
103 
104 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
105 // TODO: Better support for attribute subtype forwarding + slicing.
106 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
107                                               unsigned dropFront = 0,
108                                               unsigned dropBack = 0) {
109   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
110   auto range = arrayAttr.getAsRange<IntegerAttr>();
111   SmallVector<int64_t, 4> res;
112   res.reserve(arrayAttr.size() - dropFront - dropBack);
113   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
114        it != eit; ++it)
115     res.push_back((*it).getValue().getSExtValue());
116   return res;
117 }
118 
119 // Helper that returns data layout alignment of an operation with memref.
120 template <typename T>
121 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
122                                  unsigned &align) {
123   Type elementTy =
124       typeConverter.convertType(op.getMemRefType().getElementType());
125   if (!elementTy)
126     return failure();
127 
128   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
129   align = dataLayout.getPrefTypeAlignment(
130       LLVM::convertLLVMType(elementTy.cast<LLVM::LLVMType>()));
131   return success();
132 }
133 
134 // Helper that returns vector of pointers given a base and an index vector.
135 LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
136                              LLVMTypeConverter &typeConverter, Location loc,
137                              Value memref, Value indices, MemRefType memRefType,
138                              VectorType vType, Type iType, Value &ptrs) {
139   // Inspect stride and offset structure.
140   //
141   // TODO: flat memory only for now, generalize
142   //
143   int64_t offset;
144   SmallVector<int64_t, 4> strides;
145   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
146   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
147       offset != 0 || memRefType.getMemorySpace() != 0)
148     return failure();
149 
150   // Create a vector of pointers from base and indices.
151   MemRefDescriptor memRefDescriptor(memref);
152   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
153   int64_t size = vType.getDimSize(0);
154   auto pType = memRefDescriptor.getElementType();
155   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
156   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
157   return success();
158 }
159 
160 static LogicalResult
161 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
162                                  LLVMTypeConverter &typeConverter, Location loc,
163                                  TransferReadOp xferOp,
164                                  ArrayRef<Value> operands, Value dataPtr) {
165   unsigned align;
166   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
167     return failure();
168   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
169   return success();
170 }
171 
172 static LogicalResult
173 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
174                             LLVMTypeConverter &typeConverter, Location loc,
175                             TransferReadOp xferOp, ArrayRef<Value> operands,
176                             Value dataPtr, Value mask) {
177   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
178   VectorType fillType = xferOp.getVectorType();
179   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
180   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
181 
182   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
183   if (!vecTy)
184     return failure();
185 
186   unsigned align;
187   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
188     return failure();
189 
190   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
191       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
192       rewriter.getI32IntegerAttr(align));
193   return success();
194 }
195 
196 static LogicalResult
197 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
198                                  LLVMTypeConverter &typeConverter, Location loc,
199                                  TransferWriteOp xferOp,
200                                  ArrayRef<Value> operands, Value dataPtr) {
201   unsigned align;
202   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
203     return failure();
204   auto adaptor = TransferWriteOpAdaptor(operands);
205   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
206                                              align);
207   return success();
208 }
209 
210 static LogicalResult
211 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
212                             LLVMTypeConverter &typeConverter, Location loc,
213                             TransferWriteOp xferOp, ArrayRef<Value> operands,
214                             Value dataPtr, Value mask) {
215   unsigned align;
216   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
217     return failure();
218 
219   auto adaptor = TransferWriteOpAdaptor(operands);
220   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
221       xferOp, adaptor.vector(), dataPtr, mask,
222       rewriter.getI32IntegerAttr(align));
223   return success();
224 }
225 
// Overload-resolution helper used by generic transfer-op lowering code:
// builds the operand adaptor matching the read variant of the transfer op.
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}
230 
// Overload-resolution helper used by generic transfer-op lowering code:
// builds the operand adaptor matching the write variant of the transfer op.
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}
235 
236 namespace {
237 
238 /// Conversion pattern for a vector.matrix_multiply.
239 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
240 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
241 public:
242   explicit VectorMatmulOpConversion(MLIRContext *context,
243                                     LLVMTypeConverter &typeConverter)
244       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
245                              typeConverter) {}
246 
247   LogicalResult
248   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
249                   ConversionPatternRewriter &rewriter) const override {
250     auto matmulOp = cast<vector::MatmulOp>(op);
251     auto adaptor = vector::MatmulOpAdaptor(operands);
252     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
253         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
254         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
255         matmulOp.rhs_columns());
256     return success();
257   }
258 };
259 
260 /// Conversion pattern for a vector.flat_transpose.
261 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
262 class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
263 public:
264   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
265                                            LLVMTypeConverter &typeConverter)
266       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
267                              context, typeConverter) {}
268 
269   LogicalResult
270   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
271                   ConversionPatternRewriter &rewriter) const override {
272     auto transOp = cast<vector::FlatTransposeOp>(op);
273     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
274     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
275         transOp, typeConverter.convertType(transOp.res().getType()),
276         adaptor.matrix(), transOp.rows(), transOp.columns());
277     return success();
278   }
279 };
280 
/// Conversion pattern for a vector.gather.
/// Lowers to llvm.masked.gather over a per-lane pointer vector built by
/// getIndexedPtrs; fails on memrefs that are not flat (see that helper).
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
                              adaptor.indices(), gather.getMemRefType(), vType,
                              iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    // The pass-through operand is optional on vector.gather: forward an
    // explicitly empty range when it is absent so the intrinsic is built
    // without a pass-through value.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
                                                          : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
319 
320 /// Conversion pattern for a vector.scatter.
321 class VectorScatterOpConversion : public ConvertToLLVMPattern {
322 public:
323   explicit VectorScatterOpConversion(MLIRContext *context,
324                                      LLVMTypeConverter &typeConverter)
325       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
326                              typeConverter) {}
327 
328   LogicalResult
329   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
330                   ConversionPatternRewriter &rewriter) const override {
331     auto loc = op->getLoc();
332     auto scatter = cast<vector::ScatterOp>(op);
333     auto adaptor = vector::ScatterOpAdaptor(operands);
334 
335     // Resolve alignment.
336     unsigned align;
337     if (failed(getMemRefAlignment(typeConverter, scatter, align)))
338       return failure();
339 
340     // Get index ptrs.
341     VectorType vType = scatter.getValueVectorType();
342     Type iType = scatter.getIndicesVectorType().getElementType();
343     Value ptrs;
344     if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
345                               adaptor.indices(), scatter.getMemRefType(), vType,
346                               iType, ptrs)))
347       return failure();
348 
349     // Replace with the scatter intrinsic.
350     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
351         scatter, adaptor.value(), ptrs, adaptor.mask(),
352         rewriter.getI32IntegerAttr(align));
353     return success();
354   }
355 };
356 
/// Conversion pattern for all vector reductions.
///
/// Dispatches on the string `kind` attribute and the scalar result type to
/// the matching llvm.experimental.vector.reduce.* intrinsic. Only signless
/// i32/i64 and f32/f64 element types are handled; any other type or an
/// unrecognized kind makes the pattern fail to match.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    // The scalar destination type selects the intrinsic family.
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // min/max lower to the *signed* variants (smin/smax).
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      // Only fadd/fmul take an accumulator operand; min/max do not.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  // When true, the fadd/fmul intrinsics are emitted with the reassociation
  // flag set, allowing faster but potentially less precise FP reductions.
  const bool reassociateFPReductions;
};
438 
/// Conversion pattern for a vector.shuffle.
/// Rank-1 shuffles of identically-typed operands map directly onto LLVM's
/// shufflevector; all other cases are expanded into per-position
/// extract/insert pairs driven by the mask attribute.
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries index the concatenation of v1 and v2: entries below the
    // leading dimension of v1 select from v1, the rest (rebased) from v2.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};
496 
497 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
498 public:
499   explicit VectorExtractElementOpConversion(MLIRContext *context,
500                                             LLVMTypeConverter &typeConverter)
501       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
502                              context, typeConverter) {}
503 
504   LogicalResult
505   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
506                   ConversionPatternRewriter &rewriter) const override {
507     auto adaptor = vector::ExtractElementOpAdaptor(operands);
508     auto extractEltOp = cast<vector::ExtractElementOp>(op);
509     auto vectorType = extractEltOp.getVectorType();
510     auto llvmType = typeConverter.convertType(vectorType.getElementType());
511 
512     // Bail if result type cannot be lowered.
513     if (!llvmType)
514       return failure();
515 
516     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
517         op, llvmType, adaptor.vector(), adaptor.position());
518     return success();
519   }
520 };
521 
/// Conversion pattern for a vector.extract (static n-D position).
/// A vector-valued result is a single llvm.extractvalue over the array
/// levels; a scalar result additionally needs an llvm.extractelement on the
/// innermost 1-D vector.
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position coordinate walk the LLVM array levels down
    // to the innermost 1-D vector.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    // using the last position coordinate, materialized as an i64 constant.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};
576 
577 /// Conversion pattern that turns a vector.fma on a 1-D vector
578 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
579 /// This does not match vectors of n >= 2 rank.
580 ///
581 /// Example:
582 /// ```
583 ///  vector.fma %a, %a, %a : vector<8xf32>
584 /// ```
585 /// is converted to:
586 /// ```
587 ///  llvm.intr.fmuladd %va, %va, %va:
588 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
589 ///    -> !llvm<"<8 x float>">
590 /// ```
591 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
592 public:
593   explicit VectorFMAOp1DConversion(MLIRContext *context,
594                                    LLVMTypeConverter &typeConverter)
595       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
596                              typeConverter) {}
597 
598   LogicalResult
599   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
600                   ConversionPatternRewriter &rewriter) const override {
601     auto adaptor = vector::FMAOpAdaptor(operands);
602     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
603     VectorType vType = fmaOp.getVectorType();
604     if (vType.getRank() != 1)
605       return failure();
606     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
607                                                  adaptor.rhs(), adaptor.acc());
608     return success();
609   }
610 };
611 
612 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
613 public:
614   explicit VectorInsertElementOpConversion(MLIRContext *context,
615                                            LLVMTypeConverter &typeConverter)
616       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
617                              context, typeConverter) {}
618 
619   LogicalResult
620   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
621                   ConversionPatternRewriter &rewriter) const override {
622     auto adaptor = vector::InsertElementOpAdaptor(operands);
623     auto insertEltOp = cast<vector::InsertElementOp>(op);
624     auto vectorType = insertEltOp.getDestVectorType();
625     auto llvmType = typeConverter.convertType(vectorType);
626 
627     // Bail if result type cannot be lowered.
628     if (!llvmType)
629       return failure();
630 
631     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
632         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
633     return success();
634   }
635 };
636 
/// Conversion pattern for a vector.insert (static n-D position).
/// Inserting a vector value is a single llvm.insertvalue; inserting a scalar
/// requires extracting the innermost 1-D vector, inserting the element there,
/// and (for rank > 1) re-inserting that vector into the enclosing array.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position coordinate walk the LLVM array levels down
    // to the innermost 1-D vector that holds the target element.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    // The last position coordinate is materialized as an i64 constant.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    // Re-insert the updated 1-D vector at the same array position it was
    // extracted from.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};
703 
704 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
705 ///
706 /// Example:
707 /// ```
708 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
709 /// ```
710 /// is rewritten into:
711 /// ```
712 ///  %r = splat %f0: vector<2x4xf32>
713 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
714 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
715 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
716 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
717 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
718 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
719 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
720 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
721 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
722 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
723 ///  // %r3 holds the final value.
724 /// ```
725 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
726 public:
727   using OpRewritePattern<FMAOp>::OpRewritePattern;
728 
729   LogicalResult matchAndRewrite(FMAOp op,
730                                 PatternRewriter &rewriter) const override {
731     auto vType = op.getVectorType();
732     if (vType.getRank() < 2)
733       return failure();
734 
735     auto loc = op.getLoc();
736     auto elemType = vType.getElementType();
737     Value zero = rewriter.create<ConstantOp>(loc, elemType,
738                                              rewriter.getZeroAttr(elemType));
739     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
740     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
741       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
742       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
743       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
744       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
745       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
746     }
747     rewriter.replaceOp(op, desc);
748     return success();
749   }
750 };
751 
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Equal-rank ops are handled by the same-rank pattern instead.
    if (rankDiff == 0)
      return failure();

    // Number of trailing dimensions shared between source and destination,
    // i.e. the rank of the subvector to extract (== srcType.getRank()).
    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it. The leading `rankDiff` offsets select the subvector, so drop the
    // trailing `rankRest` entries from the offsets array.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks. It uses the trailing `rankRest` offsets (drop the leading
    // `rankDiff` entries already consumed by the extract above).
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    // Re-insert the updated subvector at the same leading offsets.
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
806 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is unrolled along its most major dimension:
//   1. each (n-1)-D subvector of the source is extracted
//   2. if the ranks are still > 1, a new, smaller InsertStridedSliceOp is
//      created to insert that subvector into the matching destination
//      subvector (this recursion is picked up again by this same pattern)
//   3. each resulting subvector is inserted at its strided offset into the
//      running result vector
//   4. the op is replaced by the final result vector.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Rank mismatch is handled by the different-rank pattern instead.
    if (rankDiff != 0)
      return failure();

    // Same type means the source fully overwrites the destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    // `idx` indexes the source; `off` is the strided position in the dest.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
874 
875 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
876 public:
877   explicit VectorTypeCastOpConversion(MLIRContext *context,
878                                       LLVMTypeConverter &typeConverter)
879       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
880                              typeConverter) {}
881 
882   LogicalResult
883   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
884                   ConversionPatternRewriter &rewriter) const override {
885     auto loc = op->getLoc();
886     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
887     MemRefType sourceMemRefType =
888         castOp.getOperand().getType().cast<MemRefType>();
889     MemRefType targetMemRefType =
890         castOp.getResult().getType().cast<MemRefType>();
891 
892     // Only static shape casts supported atm.
893     if (!sourceMemRefType.hasStaticShape() ||
894         !targetMemRefType.hasStaticShape())
895       return failure();
896 
897     auto llvmSourceDescriptorTy =
898         operands[0].getType().dyn_cast<LLVM::LLVMType>();
899     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
900       return failure();
901     MemRefDescriptor sourceMemRef(operands[0]);
902 
903     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
904                                       .dyn_cast_or_null<LLVM::LLVMType>();
905     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
906       return failure();
907 
908     int64_t offset;
909     SmallVector<int64_t, 4> strides;
910     auto successStrides =
911         getStridesAndOffset(sourceMemRefType, strides, offset);
912     bool isContiguous = (strides.back() == 1);
913     if (isContiguous) {
914       auto sizes = sourceMemRefType.getShape();
915       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
916         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
917           isContiguous = false;
918           break;
919         }
920       }
921     }
922     // Only contiguous source tensors supported atm.
923     if (failed(successStrides) || !isContiguous)
924       return failure();
925 
926     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
927 
928     // Create descriptor.
929     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
930     Type llvmTargetElementTy = desc.getElementType();
931     // Set allocated ptr.
932     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
933     allocated =
934         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
935     desc.setAllocatedPtr(rewriter, loc, allocated);
936     // Set aligned ptr.
937     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
938     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
939     desc.setAlignedPtr(rewriter, loc, ptr);
940     // Fill offset 0.
941     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
942     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
943     desc.setOffset(rewriter, loc, zero);
944 
945     // Fill size and stride descriptors in memref.
946     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
947       int64_t index = indexedSize.index();
948       auto sizeAttr =
949           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
950       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
951       desc.setSize(rewriter, loc, index, size);
952       auto strideAttr =
953           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
954       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
955       desc.setStride(rewriter, loc, index, stride);
956     }
957 
958     rewriter.replaceOp(op, {desc});
959     return success();
960   }
961 };
962 
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here; other
    // shapes must be lowered to this form by earlier patterns.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // Only the minor identity permutation map is supported, i.e. the transfer
    // reads/writes along the most minor dimension without transposition.
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // An unmasked transfer needs no bounds computation: lower directly to a
    // plain vector load/store and stop here.
    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    // The constant was built with a standard op; cast it into the LLVM
    // dialect so it can mix with the LLVM ops below.
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};
1056 
/// Lowers vector.print to calls into a small runtime print-support library.
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support (currently just Float/Double).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    // i1 values are widened in emitRanks and printed via the i32 entry point.
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively emits the print calls for `value`: at rank 0 the scalar goes
  // straight to the element printer; otherwise each subvector (or element) is
  // extracted and printed between brackets, comma-separated.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() ==
          LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // At rank 1 the extracted value is a scalar element, so the reduced
      // vector type is null and the element type is converted instead.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Declares a void runtime function at module scope on first use; later
  // calls return the existing declaration.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};
1202 
1203 /// Progressive lowering of ExtractStridedSliceOp to either:
1204 ///   1. extractelement + insertelement for the 1-D case
1205 ///   2. extract + optional strided_slice + insert for the n-D case.
1206 class VectorStridedSliceOpConversion
1207     : public OpRewritePattern<ExtractStridedSliceOp> {
1208 public:
1209   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1210 
1211   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1212                                 PatternRewriter &rewriter) const override {
1213     auto dstType = op.getResult().getType().cast<VectorType>();
1214 
1215     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1216 
1217     int64_t offset =
1218         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1219     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1220     int64_t stride =
1221         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1222 
1223     auto loc = op.getLoc();
1224     auto elemType = dstType.getElementType();
1225     assert(elemType.isSignlessIntOrIndexOrFloat());
1226     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1227                                              rewriter.getZeroAttr(elemType));
1228     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1229     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1230          off += stride, ++idx) {
1231       Value extracted = extractOne(rewriter, loc, op.vector(), off);
1232       if (op.offsets().getValue().size() > 1) {
1233         extracted = rewriter.create<ExtractStridedSliceOp>(
1234             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
1235             getI64SubArray(op.sizes(), /* dropFront=*/1),
1236             getI64SubArray(op.strides(), /* dropFront=*/1));
1237       }
1238       res = insertOne(rewriter, loc, extracted, res, idx);
1239     }
1240     rewriter.replaceOp(op, {res});
1241     return success();
1242   }
1243   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1244   /// bounded as the rank is strictly decreasing.
1245   bool hasBoundedRewriteRecursion() const final { return true; }
1246 };
1247 
1248 } // namespace
1249 
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  // Progressive-lowering rewrites that stay within the vector dialect; these
  // only need the context, not the type converter.
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorStridedSliceOpConversion>(ctx);
  // Reduction conversion is parameterized separately: `reassociateFPReductions`
  // allows faster but non-bit-exact FP reductions.
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  // Conversions into the LLVM dialect proper; these need the type converter.
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion>(ctx, converter);
  // clang-format on
}
1277 
1278 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1279     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1280   MLIRContext *ctx = converter.getDialect()->getContext();
1281   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1282   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1283 }
1284 
1285 namespace {
/// Module pass that lowers vector dialect ops into the LLVM dialect.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  // Copies the user-facing option into the generated pass-option member.
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
1293 } // namespace
1294 
1295 void LowerVectorToLLVMPass::runOnOperation() {
1296   // Perform progressive lowering of operations on slices and
1297   // all contraction operations. Also applies folding and DCE.
1298   {
1299     OwningRewritePatternList patterns;
1300     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1301     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1302     populateVectorContractLoweringPatterns(patterns, &getContext());
1303     applyPatternsAndFoldGreedily(getOperation(), patterns);
1304   }
1305 
1306   // Convert to the LLVM IR dialect.
1307   LLVMTypeConverter converter(&getContext());
1308   OwningRewritePatternList patterns;
1309   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1310   populateVectorToLLVMConversionPatterns(converter, patterns,
1311                                          reassociateFPReductions);
1312   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1313   populateStdToLLVMConversionPatterns(converter, patterns);
1314 
1315   LLVMConversionTarget target(getContext());
1316   if (failed(applyPartialConversion(getOperation(), target, patterns))) {
1317     signalPassFailure();
1318   }
1319 }
1320 
/// Creates the module pass that lowers vector dialect ops to the LLVM
/// dialect, configured by `options` (e.g. FP-reduction reassociation).
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
1325