1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/Dialect/Vector/VectorRewritePatterns.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 
27 // Helper to reduce vector type by one rank at front.
28 static VectorType reducedVectorTypeFront(VectorType tp) {
29   assert((tp.getRank() > 1) && "unlowerable vector type");
30   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
31 }
32 
33 // Helper to reduce vector type by *all* but one rank at back.
34 static VectorType reducedVectorTypeBack(VectorType tp) {
35   assert((tp.getRank() > 1) && "unlowerable vector type");
36   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
37 }
38 
39 // Helper that picks the proper sequence for inserting.
40 static Value insertOne(ConversionPatternRewriter &rewriter,
41                        LLVMTypeConverter &typeConverter, Location loc,
42                        Value val1, Value val2, Type llvmType, int64_t rank,
43                        int64_t pos) {
44   if (rank == 1) {
45     auto idxType = rewriter.getIndexType();
46     auto constant = rewriter.create<LLVM::ConstantOp>(
47         loc, typeConverter.convertType(idxType),
48         rewriter.getIntegerAttr(idxType, pos));
49     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50                                                   constant);
51   }
52   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
53                                               rewriter.getI64ArrayAttr(pos));
54 }
55 
56 // Helper that picks the proper sequence for extracting.
57 static Value extractOne(ConversionPatternRewriter &rewriter,
58                         LLVMTypeConverter &typeConverter, Location loc,
59                         Value val, Type llvmType, int64_t rank, int64_t pos) {
60   if (rank == 1) {
61     auto idxType = rewriter.getIndexType();
62     auto constant = rewriter.create<LLVM::ConstantOp>(
63         loc, typeConverter.convertType(idxType),
64         rewriter.getIntegerAttr(idxType, pos));
65     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
66                                                    constant);
67   }
68   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
69                                                rewriter.getI64ArrayAttr(pos));
70 }
71 
72 // Helper that returns data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
75   Type elementTy = typeConverter.convertType(memrefType.getElementType());
76   if (!elementTy)
77     return failure();
78 
79   // TODO: this should use the MLIR data layout when it becomes available and
80   // stop depending on translation.
81   llvm::LLVMContext llvmContext;
82   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
83               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
84   return success();
85 }
86 
87 // Return the minimal alignment value that satisfies all the AssumeAlignment
88 // uses of `value`. If no such uses exist, return 1.
89 static unsigned getAssumedAlignment(Value value) {
90   unsigned align = 1;
91   for (auto &u : value.getUses()) {
92     Operation *owner = u.getOwner();
93     if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
94       align = mlir::lcm(align, op.alignment());
95   }
96   return align;
97 }
98 
99 // Helper that returns data layout alignment of a memref associated with a
100 // load, store, scatter, or gather op, including additional information from
// assume_alignment calls on the source of the transfer.
102 template <class OpAdaptor>
103 LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
104                                    OpAdaptor op, unsigned &align) {
105   if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
106     return failure();
107   align = std::max(align, getAssumedAlignment(op.base()));
108   return success();
109 }
110 
// Add an index vector component to a base pointer. This fails only when the
// last stride is non-unit or the memory space is not zero.
113 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
114                                     Location loc, Value memref, Value base,
115                                     Value index, MemRefType memRefType,
116                                     VectorType vType, Value &ptrs) {
117   int64_t offset;
118   SmallVector<int64_t, 4> strides;
119   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
120   if (failed(successStrides) || strides.back() != 1 ||
121       memRefType.getMemorySpaceAsInt() != 0)
122     return failure();
123   auto pType = MemRefDescriptor(memref).getElementPtrType();
124   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
125   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
126   return success();
127 }
128 
129 // Casts a strided element pointer to a vector pointer.  The vector pointer
130 // will be in the same address space as the incoming memref type.
131 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
132                          Value ptr, MemRefType memRefType, Type vt) {
133   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
134   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
135 }
136 
137 namespace {
138 
139 /// Conversion pattern for a vector.bitcast.
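/// For illustration (shapes are arbitrary), a rank-1 cast such as
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xi64>
/// ```
/// is expected to lower to a single llvm.bitcast of the converted operand.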
140 class VectorBitCastOpConversion
141     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
142 public:
143   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
144 
145   LogicalResult
146   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
147                   ConversionPatternRewriter &rewriter) const override {
148     // Only 1-D vectors can be lowered to LLVM.
149     VectorType resultTy = bitCastOp.getType();
150     if (resultTy.getRank() != 1)
151       return failure();
152     Type newResultTy = typeConverter->convertType(resultTy);
153     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
154                                                  adaptor.getOperands()[0]);
155     return success();
156   }
157 };
158 
159 /// Conversion pattern for a vector.matrix_multiply.
160 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
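/// For illustration (flattened 1-D operands with arbitrary sizes):
/// ```
///   %c = vector.matrix_multiply %a, %b
///        { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///        : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// maps to a single llvm.intr.matrix.multiply with the same shape attributes.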
161 class VectorMatmulOpConversion
162     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
163 public:
164   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
165 
166   LogicalResult
167   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
168                   ConversionPatternRewriter &rewriter) const override {
169     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
170         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
171         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
172         matmulOp.lhs_columns(), matmulOp.rhs_columns());
173     return success();
174   }
175 };
176 
177 /// Conversion pattern for a vector.flat_transpose.
178 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
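/// For illustration (an arbitrary 4x4 shape encoded as a flat vector):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// maps to a single llvm.intr.matrix.transpose with the same attributes.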
179 class VectorFlatTransposeOpConversion
180     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
181 public:
182   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
183 
184   LogicalResult
185   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
186                   ConversionPatternRewriter &rewriter) const override {
187     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
188         transOp, typeConverter->convertType(transOp.res().getType()),
189         adaptor.matrix(), transOp.rows(), transOp.columns());
190     return success();
191   }
192 };
193 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload, or vector.maskedstore with its corresponding LLVM
/// counterpart.
197 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
198                                  vector::LoadOpAdaptor adaptor,
199                                  VectorType vectorTy, Value ptr, unsigned align,
200                                  ConversionPatternRewriter &rewriter) {
201   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
202 }
203 
204 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
205                                  vector::MaskedLoadOpAdaptor adaptor,
206                                  VectorType vectorTy, Value ptr, unsigned align,
207                                  ConversionPatternRewriter &rewriter) {
208   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
209       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
210 }
211 
212 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
213                                  vector::StoreOpAdaptor adaptor,
214                                  VectorType vectorTy, Value ptr, unsigned align,
215                                  ConversionPatternRewriter &rewriter) {
216   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
217                                              ptr, align);
218 }
219 
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221                                  vector::MaskedStoreOpAdaptor adaptor,
222                                  VectorType vectorTy, Value ptr, unsigned align,
223                                  ConversionPatternRewriter &rewriter) {
224   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
225       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
226 }
227 
228 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
229 /// vector.maskedstore.
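/// For illustration (arbitrary memref and vector types), a 1-D load such as
/// ```
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
/// ```
/// is expected to become an llvm.load through a pointer bitcast to the
/// converted vector type; the masked variants map to the corresponding LLVM
/// masked intrinsics.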
230 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
231 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
232 public:
233   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
234 
235   LogicalResult
236   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237                   typename LoadOrStoreOp::Adaptor adaptor,
238                   ConversionPatternRewriter &rewriter) const override {
239     // Only 1-D vectors can be lowered to LLVM.
240     VectorType vectorTy = loadOrStoreOp.getVectorType();
241     if (vectorTy.getRank() > 1)
242       return failure();
243 
244     auto loc = loadOrStoreOp->getLoc();
245     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
246 
247     // Resolve alignment.
248     unsigned align;
249     if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
250                                     align)))
251       return failure();
252 
253     // Resolve address.
254     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
255                      .template cast<VectorType>();
256     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
257                                                adaptor.indices(), rewriter);
258     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
259 
260     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
261     return success();
262   }
263 };
264 
265 /// Conversion pattern for a vector.gather.
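/// For illustration (arbitrary types; syntax as in the Vector dialect docs):
/// ```
///   %g = vector.gather %base[%c0][%idx], %mask, %pass
///       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
/// ```
/// is expected to become an llvm.intr.masked.gather over a vector of pointers
/// computed by a single GEP off the strided base pointer.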
266 class VectorGatherOpConversion
267     : public ConvertOpToLLVMPattern<vector::GatherOp> {
268 public:
269   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
270 
271   LogicalResult
272   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
273                   ConversionPatternRewriter &rewriter) const override {
274     auto loc = gather->getLoc();
275     MemRefType memRefType = gather.getMemRefType();
276 
277     // Resolve alignment.
278     unsigned align;
279     if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
280       return failure();
281 
282     // Resolve address.
283     Value ptrs;
284     VectorType vType = gather.getVectorType();
285     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
286                                      adaptor.indices(), rewriter);
287     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
288                               adaptor.index_vec(), memRefType, vType, ptrs)))
289       return failure();
290 
291     // Replace with the gather intrinsic.
292     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
293         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
294         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
295     return success();
296   }
297 };
298 
299 /// Conversion pattern for a vector.scatter.
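/// For illustration (arbitrary types; syntax as in the Vector dialect docs):
/// ```
///   vector.scatter %base[%c0][%idx], %mask, %val
///       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
/// ```
/// is expected to become an llvm.intr.masked.scatter over a vector of
/// pointers computed by a single GEP off the strided base pointer.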
300 class VectorScatterOpConversion
301     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
302 public:
303   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
304 
305   LogicalResult
306   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
307                   ConversionPatternRewriter &rewriter) const override {
308     auto loc = scatter->getLoc();
309     MemRefType memRefType = scatter.getMemRefType();
310 
311     // Resolve alignment.
312     unsigned align;
313     if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
314       return failure();
315 
316     // Resolve address.
317     Value ptrs;
318     VectorType vType = scatter.getVectorType();
319     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
320                                      adaptor.indices(), rewriter);
321     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
322                               adaptor.index_vec(), memRefType, vType, ptrs)))
323       return failure();
324 
325     // Replace with the scatter intrinsic.
326     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
327         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
328         rewriter.getI32IntegerAttr(align));
329     return success();
330   }
331 };
332 
333 /// Conversion pattern for a vector.expandload.
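/// For illustration (arbitrary types; syntax as in the Vector dialect docs):
/// ```
///   %0 = vector.expandload %base[%c0], %mask, %pass
///       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// is expected to map to llvm.intr.masked.expandload on the strided element
/// pointer.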
334 class VectorExpandLoadOpConversion
335     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
336 public:
337   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
338 
339   LogicalResult
340   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
341                   ConversionPatternRewriter &rewriter) const override {
342     auto loc = expand->getLoc();
343     MemRefType memRefType = expand.getMemRefType();
344 
345     // Resolve address.
346     auto vtype = typeConverter->convertType(expand.getVectorType());
347     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
348                                      adaptor.indices(), rewriter);
349 
350     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
351         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
352     return success();
353   }
354 };
355 
356 /// Conversion pattern for a vector.compressstore.
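/// For illustration (arbitrary types; syntax as in the Vector dialect docs):
/// ```
///   vector.compressstore %base[%c0], %mask, %val
///       : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// ```
/// is expected to map to llvm.intr.masked.compressstore on the strided
/// element pointer.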
357 class VectorCompressStoreOpConversion
358     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
359 public:
360   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
361 
362   LogicalResult
363   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
364                   ConversionPatternRewriter &rewriter) const override {
365     auto loc = compress->getLoc();
366     MemRefType memRefType = compress.getMemRefType();
367 
368     // Resolve address.
369     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
370                                      adaptor.indices(), rewriter);
371 
372     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
373         compress, adaptor.valueToStore(), ptr, adaptor.mask());
374     return success();
375   }
376 };
377 
378 /// Conversion pattern for all vector reductions.
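/// For illustration (arbitrary element type and size), an integer reduction
/// such as
/// ```
///   %0 = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// maps to llvm.intr.vector.reduce.add; the floating-point add/mul cases
/// additionally thread an optional accumulator and the reassociation flag.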
379 class VectorReductionOpConversion
380     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
381 public:
382   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
383                                        bool reassociateFPRed)
384       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
385         reassociateFPReductions(reassociateFPRed) {}
386 
387   LogicalResult
388   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
389                   ConversionPatternRewriter &rewriter) const override {
390     auto kind = reductionOp.kind();
391     Type eltType = reductionOp.dest().getType();
392     Type llvmType = typeConverter->convertType(eltType);
393     Value operand = adaptor.getOperands()[0];
394     if (eltType.isIntOrIndex()) {
395       // Integer reductions: add/mul/min/max/and/or/xor.
396       if (kind == "add")
397         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
398                                                              llvmType, operand);
399       else if (kind == "mul")
400         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
401                                                              llvmType, operand);
402       else if (kind == "minui")
403         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
404             reductionOp, llvmType, operand);
405       else if (kind == "minsi")
406         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
407             reductionOp, llvmType, operand);
408       else if (kind == "maxui")
409         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
410             reductionOp, llvmType, operand);
411       else if (kind == "maxsi")
412         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
413             reductionOp, llvmType, operand);
414       else if (kind == "and")
415         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
416                                                              llvmType, operand);
417       else if (kind == "or")
418         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
419                                                             llvmType, operand);
420       else if (kind == "xor")
421         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
422                                                              llvmType, operand);
423       else
424         return failure();
425       return success();
426     }
427 
428     if (!eltType.isa<FloatType>())
429       return failure();
430 
    // Floating-point reductions: add/mul/min/max.
432     if (kind == "add") {
433       // Optional accumulator (or zero).
434       Value acc = adaptor.getOperands().size() > 1
435                       ? adaptor.getOperands()[1]
436                       : rewriter.create<LLVM::ConstantOp>(
437                             reductionOp->getLoc(), llvmType,
438                             rewriter.getZeroAttr(eltType));
439       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
440           reductionOp, llvmType, acc, operand,
441           rewriter.getBoolAttr(reassociateFPReductions));
442     } else if (kind == "mul") {
443       // Optional accumulator (or one).
444       Value acc = adaptor.getOperands().size() > 1
445                       ? adaptor.getOperands()[1]
446                       : rewriter.create<LLVM::ConstantOp>(
447                             reductionOp->getLoc(), llvmType,
448                             rewriter.getFloatAttr(eltType, 1.0));
449       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
450           reductionOp, llvmType, acc, operand,
451           rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf") {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    } else if (kind == "maxf") {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    } else {
      return failure();
    }
464     return success();
465   }
466 
467 private:
468   const bool reassociateFPReductions;
469 };
470 
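/// Conversion pattern for a vector.shuffle. A rank-1 shuffle on operands of
/// the same vector type lowers to a single llvm.shufflevector; all other
/// cases are unrolled into individual extract/insert pairs.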
471 class VectorShuffleOpConversion
472     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
473 public:
474   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
475 
476   LogicalResult
477   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
478                   ConversionPatternRewriter &rewriter) const override {
479     auto loc = shuffleOp->getLoc();
480     auto v1Type = shuffleOp.getV1VectorType();
481     auto v2Type = shuffleOp.getV2VectorType();
482     auto vectorType = shuffleOp.getVectorType();
483     Type llvmType = typeConverter->convertType(vectorType);
484     auto maskArrayAttr = shuffleOp.mask();
485 
486     // Bail if result type cannot be lowered.
487     if (!llvmType)
488       return failure();
489 
490     // Get rank and dimension sizes.
491     int64_t rank = vectorType.getRank();
492     assert(v1Type.getRank() == rank);
493     assert(v2Type.getRank() == rank);
494     int64_t v1Dim = v1Type.getDimSize(0);
495 
496     // For rank 1, where both operands have *exactly* the same vector type,
497     // there is direct shuffle support in LLVM. Use it!
498     if (rank == 1 && v1Type == v2Type) {
499       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
500           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
501       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
502       return success();
503     }
504 
    // For all other cases, extract and insert the values one at a time.
    Type eltType;
508     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
509       eltType = arrayType.getElementType();
510     else
511       eltType = llvmType.cast<VectorType>().getElementType();
512     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
513     int64_t insPos = 0;
514     for (auto en : llvm::enumerate(maskArrayAttr)) {
515       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
516       Value value = adaptor.v1();
517       if (extPos >= v1Dim) {
518         extPos -= v1Dim;
519         value = adaptor.v2();
520       }
521       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
522                                  eltType, rank, extPos);
523       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
524                          llvmType, rank, insPos++);
525     }
526     rewriter.replaceOp(shuffleOp, insert);
527     return success();
528   }
529 };
530 
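/// Conversion pattern for a vector.extractelement; maps 1-1 onto
/// llvm.extractelement.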
531 class VectorExtractElementOpConversion
532     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
533 public:
534   using ConvertOpToLLVMPattern<
535       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
536 
537   LogicalResult
538   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
539                   ConversionPatternRewriter &rewriter) const override {
540     auto vectorType = extractEltOp.getVectorType();
541     auto llvmType = typeConverter->convertType(vectorType.getElementType());
542 
543     // Bail if result type cannot be lowered.
544     if (!llvmType)
545       return failure();
546 
547     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
548         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
549     return success();
550   }
551 };
552 
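/// Conversion pattern for a vector.extract. Leading indices are resolved with
/// llvm.extractvalue on the array part of the converted type; the trailing
/// scalar extraction from the innermost 1-D vector uses llvm.extractelement.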
553 class VectorExtractOpConversion
554     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
555 public:
556   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
557 
558   LogicalResult
559   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
560                   ConversionPatternRewriter &rewriter) const override {
561     auto loc = extractOp->getLoc();
562     auto vectorType = extractOp.getVectorType();
563     auto resultType = extractOp.getResult().getType();
564     auto llvmResultType = typeConverter->convertType(resultType);
565     auto positionArrayAttr = extractOp.position();
566 
567     // Bail if result type cannot be lowered.
568     if (!llvmResultType)
569       return failure();
570 
571     // Extract entire vector. Should be handled by folder, but just to be safe.
572     if (positionArrayAttr.empty()) {
573       rewriter.replaceOp(extractOp, adaptor.vector());
574       return success();
575     }
576 
577     // One-shot extraction of vector from array (only requires extractvalue).
578     if (resultType.isa<VectorType>()) {
579       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
580           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
581       rewriter.replaceOp(extractOp, extracted);
582       return success();
583     }
584 
585     // Potential extraction of 1-D vector from array.
586     auto *context = extractOp->getContext();
587     Value extracted = adaptor.vector();
588     auto positionAttrs = positionArrayAttr.getValue();
589     if (positionAttrs.size() > 1) {
590       auto oneDVectorType = reducedVectorTypeBack(vectorType);
591       auto nMinusOnePositionAttrs =
592           ArrayAttr::get(context, positionAttrs.drop_back());
593       extracted = rewriter.create<LLVM::ExtractValueOp>(
594           loc, typeConverter->convertType(oneDVectorType), extracted,
595           nMinusOnePositionAttrs);
596     }
597 
    // Remaining extraction of an element from the 1-D LLVM vector.
599     auto position = positionAttrs.back().cast<IntegerAttr>();
600     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
601     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
602     extracted =
603         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
604     rewriter.replaceOp(extractOp, extracted);
605 
606     return success();
607   }
608 };
609 
610 /// Conversion pattern that turns a vector.fma on a 1-D vector
611 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
612 /// This does not match vectors of n >= 2 rank.
613 ///
614 /// Example:
615 /// ```
616 ///  vector.fma %a, %a, %a : vector<8xf32>
617 /// ```
618 /// is converted to:
619 /// ```
///  llvm.intr.fmuladd %va, %va, %va :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
623 /// ```
624 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
625 public:
626   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
627 
628   LogicalResult
629   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
630                   ConversionPatternRewriter &rewriter) const override {
631     VectorType vType = fmaOp.getVectorType();
632     if (vType.getRank() != 1)
633       return failure();
634     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
635                                                  adaptor.rhs(), adaptor.acc());
636     return success();
637   }
638 };
639 
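/// Conversion pattern for a vector.insertelement; maps 1-1 onto
/// llvm.insertelement.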
640 class VectorInsertElementOpConversion
641     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
642 public:
643   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
644 
645   LogicalResult
646   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
647                   ConversionPatternRewriter &rewriter) const override {
648     auto vectorType = insertEltOp.getDestVectorType();
649     auto llvmType = typeConverter->convertType(vectorType);
650 
651     // Bail if result type cannot be lowered.
652     if (!llvmType)
653       return failure();
654 
655     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
656         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
657         adaptor.position());
658     return success();
659   }
660 };
661 
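/// Conversion pattern for a vector.insert. A vector source is placed with a
/// single llvm.insertvalue; a scalar source is inserted into the innermost
/// 1-D vector with llvm.insertelement and, if needed, the updated vector is
/// re-inserted into the enclosing array.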
662 class VectorInsertOpConversion
663     : public ConvertOpToLLVMPattern<vector::InsertOp> {
664 public:
665   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
666 
667   LogicalResult
668   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
669                   ConversionPatternRewriter &rewriter) const override {
670     auto loc = insertOp->getLoc();
671     auto sourceType = insertOp.getSourceType();
672     auto destVectorType = insertOp.getDestVectorType();
673     auto llvmResultType = typeConverter->convertType(destVectorType);
674     auto positionArrayAttr = insertOp.position();
675 
676     // Bail if result type cannot be lowered.
677     if (!llvmResultType)
678       return failure();
679 
680     // Overwrite entire vector with value. Should be handled by folder, but
681     // just to be safe.
682     if (positionArrayAttr.empty()) {
683       rewriter.replaceOp(insertOp, adaptor.source());
684       return success();
685     }
686 
687     // One-shot insertion of a vector into an array (only requires insertvalue).
688     if (sourceType.isa<VectorType>()) {
689       Value inserted = rewriter.create<LLVM::InsertValueOp>(
690           loc, llvmResultType, adaptor.dest(), adaptor.source(),
691           positionArrayAttr);
692       rewriter.replaceOp(insertOp, inserted);
693       return success();
694     }
695 
696     // Potential extraction of 1-D vector from array.
697     auto *context = insertOp->getContext();
698     Value extracted = adaptor.dest();
699     auto positionAttrs = positionArrayAttr.getValue();
700     auto position = positionAttrs.back().cast<IntegerAttr>();
701     auto oneDVectorType = destVectorType;
702     if (positionAttrs.size() > 1) {
703       oneDVectorType = reducedVectorTypeBack(destVectorType);
704       auto nMinusOnePositionAttrs =
705           ArrayAttr::get(context, positionAttrs.drop_back());
706       extracted = rewriter.create<LLVM::ExtractValueOp>(
707           loc, typeConverter->convertType(oneDVectorType), extracted,
708           nMinusOnePositionAttrs);
709     }
710 
711     // Insertion of an element into a 1-D LLVM vector.
712     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
713     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
714     Value inserted = rewriter.create<LLVM::InsertElementOp>(
715         loc, typeConverter->convertType(oneDVectorType), extracted,
716         adaptor.source(), constant);
717 
718     // Potential insertion of resulting 1-D vector into array.
719     if (positionAttrs.size() > 1) {
720       auto nMinusOnePositionAttrs =
721           ArrayAttr::get(context, positionAttrs.drop_back());
722       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
723                                                       adaptor.dest(), inserted,
724                                                       nMinusOnePositionAttrs);
725     }
726 
727     rewriter.replaceOp(insertOp, inserted);
728     return success();
729   }
730 };
731 
732 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
733 ///
734 /// Example:
735 /// ```
736 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
737 /// ```
738 /// is rewritten into:
739 /// ```
740 ///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
751 ///  // %r3 holds the final value.
752 /// ```
753 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
754 public:
755   using OpRewritePattern<FMAOp>::OpRewritePattern;
756 
757   LogicalResult matchAndRewrite(FMAOp op,
758                                 PatternRewriter &rewriter) const override {
759     auto vType = op.getVectorType();
760     if (vType.getRank() < 2)
761       return failure();
762 
763     auto loc = op.getLoc();
764     auto elemType = vType.getElementType();
765     Value zero = rewriter.create<arith::ConstantOp>(
766         loc, elemType, rewriter.getZeroAttr(elemType));
767     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
768     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
769       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
770       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
771       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
772       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
773       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
774     }
775     rewriter.replaceOp(op, desc);
776     return success();
777   }
778 };
779 
780 /// Returns the strides if the memory underlying `memRefType` has a contiguous
781 /// static layout.
782 static llvm::Optional<SmallVector<int64_t, 4>>
783 computeContiguousStrides(MemRefType memRefType) {
784   int64_t offset;
785   SmallVector<int64_t, 4> strides;
786   if (failed(getStridesAndOffset(memRefType, strides, offset)))
787     return None;
788   if (!strides.empty() && strides.back() != 1)
789     return None;
790   // If no layout or identity layout, this is contiguous by definition.
791   if (memRefType.getLayout().isIdentity())
792     return strides;
793 
  // Otherwise, we must determine contiguity from the shapes. This can only
  // work in static cases, because MemRefType cannot represent contiguous
  // dynamic shapes in any way other than with an empty/identity layout.
798   auto sizes = memRefType.getShape();
799   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
800     if (ShapedType::isDynamic(sizes[index + 1]) ||
801         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
802         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
803       return None;
804     if (strides[index] != strides[index + 1] * sizes[index + 1])
805       return None;
806   }
807   return strides;
808 }
809 
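/// Conversion pattern for a vector.type_cast, e.g. (illustrative):
/// ```
///   %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// The pattern rebuilds the target memref descriptor from the source one:
/// both pointers are bitcast to the target element pointer type, the offset
/// is reset to zero, and the static sizes and strides are re-materialized.
/// Only memrefs with static shapes and contiguous layouts are supported.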
810 class VectorTypeCastOpConversion
811     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
812 public:
813   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
814 
815   LogicalResult
816   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
817                   ConversionPatternRewriter &rewriter) const override {
818     auto loc = castOp->getLoc();
819     MemRefType sourceMemRefType =
820         castOp.getOperand().getType().cast<MemRefType>();
821     MemRefType targetMemRefType = castOp.getType();
822 
    // Only static shape casts are supported at the moment.
824     if (!sourceMemRefType.hasStaticShape() ||
825         !targetMemRefType.hasStaticShape())
826       return failure();
827 
828     auto llvmSourceDescriptorTy =
829         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
830     if (!llvmSourceDescriptorTy)
831       return failure();
832     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
833 
834     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
835                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
836     if (!llvmTargetDescriptorTy)
837       return failure();
838 
    // Only contiguous source buffers are supported at the moment.
840     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
841     if (!sourceStrides)
842       return failure();
843     auto targetStrides = computeContiguousStrides(targetMemRefType);
844     if (!targetStrides)
845       return failure();
846     // Only support static strides for now, regardless of contiguity.
847     if (llvm::any_of(*targetStrides, [](int64_t stride) {
848           return ShapedType::isDynamicStrideOrOffset(stride);
849         }))
850       return failure();
851 
852     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
853 
854     // Create descriptor.
855     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
856     Type llvmTargetElementTy = desc.getElementPtrType();
857     // Set allocated ptr.
858     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
859     allocated =
860         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
861     desc.setAllocatedPtr(rewriter, loc, allocated);
862     // Set aligned ptr.
863     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
864     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
865     desc.setAlignedPtr(rewriter, loc, ptr);
866     // Fill offset 0.
867     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
868     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
869     desc.setOffset(rewriter, loc, zero);
870 
871     // Fill size and stride descriptors in memref.
872     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
873       int64_t index = indexedSize.index();
874       auto sizeAttr =
875           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
876       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
877       desc.setSize(rewriter, loc, index, size);
878       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
879                                                 (*targetStrides)[index]);
880       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
881       desc.setStride(rewriter, loc, index, stride);
882     }
883 
884     rewriter.replaceOp(castOp, {desc});
885     return success();
886   }
887 };
888 
class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
890 public:
891   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
892 
893   // Proof-of-concept lowering implementation that relies on a small
894   // runtime support library, which only needs to provide a few
895   // printing methods (single value for all data types, opening/closing
896   // bracket, comma, newline). The lowering fully unrolls a vector
897   // in terms of these elementary printing operations. The advantage
898   // of this approach is that the library can remain unaware of all
899   // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
902   //
903   // TODO: rely solely on libc in future? something else?
904   //
905   LogicalResult
906   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
907                   ConversionPatternRewriter &rewriter) const override {
908     Type printType = printOp.getPrintType();
909 
910     if (typeConverter->convertType(printType) == nullptr)
911       return failure();
912 
913     // Make sure element type has runtime support.
914     PrintConversion conversion = PrintConversion::None;
915     VectorType vectorType = printType.dyn_cast<VectorType>();
916     Type eltType = vectorType ? vectorType.getElementType() : printType;
917     Operation *printer;
918     if (eltType.isF32()) {
919       printer =
920           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
921     } else if (eltType.isF64()) {
922       printer =
923           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
924     } else if (eltType.isIndex()) {
925       printer =
926           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
927     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
928       // Integers need a zero or sign extension on the operand
929       // (depending on the source type) as well as a signed or
930       // unsigned print method. Up to 64-bit is supported.
931       unsigned width = intTy.getWidth();
932       if (intTy.isUnsigned()) {
933         if (width <= 64) {
934           if (width < 64)
935             conversion = PrintConversion::ZeroExt64;
936           printer = LLVM::lookupOrCreatePrintU64Fn(
937               printOp->getParentOfType<ModuleOp>());
938         } else {
939           return failure();
940         }
941       } else {
942         assert(intTy.isSignless() || intTy.isSigned());
943         if (width <= 64) {
944           // Note that we *always* zero extend booleans (1-bit integers),
945           // so that true/false is printed as 1/0 rather than -1/0.
946           if (width == 1)
947             conversion = PrintConversion::ZeroExt64;
948           else if (width < 64)
949             conversion = PrintConversion::SignExt64;
950           printer = LLVM::lookupOrCreatePrintI64Fn(
951               printOp->getParentOfType<ModuleOp>());
952         } else {
953           return failure();
954         }
955       }
956     } else {
957       return failure();
958     }
959 
960     // Unroll vector into elementary print calls.
961     int64_t rank = vectorType ? vectorType.getRank() : 0;
962     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
963               conversion);
964     emitCall(rewriter, printOp->getLoc(),
965              LLVM::lookupOrCreatePrintNewlineFn(
966                  printOp->getParentOfType<ModuleOp>()));
967     rewriter.eraseOp(printOp);
968     return success();
969   }
970 
971 private:
972   enum class PrintConversion {
973     // clang-format off
974     None,
975     ZeroExt64,
976     SignExt64
977     // clang-format on
978   };
979 
980   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
981                  Value value, VectorType vectorType, Operation *printer,
982                  int64_t rank, PrintConversion conversion) const {
983     Location loc = op->getLoc();
984     if (rank == 0) {
985       switch (conversion) {
986       case PrintConversion::ZeroExt64:
987         value = rewriter.create<arith::ExtUIOp>(
988             loc, value, IntegerType::get(rewriter.getContext(), 64));
989         break;
990       case PrintConversion::SignExt64:
991         value = rewriter.create<arith::ExtSIOp>(
992             loc, value, IntegerType::get(rewriter.getContext(), 64));
993         break;
994       case PrintConversion::None:
995         break;
996       }
997       emitCall(rewriter, loc, printer, value);
998       return;
999     }
1000 
1001     emitCall(rewriter, loc,
1002              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1003     Operation *printComma =
1004         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1005     int64_t dim = vectorType.getDimSize(0);
1006     for (int64_t d = 0; d < dim; ++d) {
1007       auto reducedType =
1008           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1009       auto llvmType = typeConverter->convertType(
1010           rank > 1 ? reducedType : vectorType.getElementType());
1011       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1012                                    llvmType, rank, d);
1013       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1014                 conversion);
1015       if (d != dim - 1)
1016         emitCall(rewriter, loc, printComma);
1017     }
1018     emitCall(rewriter, loc,
1019              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1020   }
1021 
1022   // Helper to emit a call.
1023   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1024                        Operation *ref, ValueRange params = ValueRange()) {
1025     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1026                                   params);
1027   }
1028 };
1029 
1030 } // namespace
1031 
1032 /// Populate the given list with patterns that convert from Vector to LLVM.
1033 void mlir::populateVectorToLLVMConversionPatterns(
1034     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1035     bool reassociateFPReductions) {
1036   MLIRContext *ctx = converter.getDialect()->getContext();
1037   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1038   populateVectorInsertExtractStridedSliceTransforms(patterns);
1039   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1040   patterns
1041       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1042            VectorExtractElementOpConversion, VectorExtractOpConversion,
1043            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1044            VectorInsertOpConversion, VectorPrintOpConversion,
1045            VectorTypeCastOpConversion,
1046            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1047            VectorLoadStoreConversion<vector::MaskedLoadOp,
1048                                      vector::MaskedLoadOpAdaptor>,
1049            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1050            VectorLoadStoreConversion<vector::MaskedStoreOp,
1051                                      vector::MaskedStoreOpAdaptor>,
1052            VectorGatherOpConversion, VectorScatterOpConversion,
1053            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1054           converter);
1055   // Transfer ops with rank > 1 are handled by VectorToSCF.
1056   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1057 }
1058 
1059 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1060     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1061   patterns.add<VectorMatmulOpConversion>(converter);
1062   patterns.add<VectorFlatTransposeOpConversion>(converter);
1063 }
1064