//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
111834ad4aSRiver Riddle #include "../PassDetail.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
135c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
178345b86dSNicolas Vasilache #include "mlir/IR/AffineMap.h"
185c0c51a9SNicolas Vasilache #include "mlir/IR/Builders.h"
1973ca690dSRiver Riddle #include "mlir/IR/BuiltinDialect.h"
205c0c51a9SNicolas Vasilache #include "mlir/IR/StandardTypes.h"
21ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
225c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
23b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
245c0c51a9SNicolas Vasilache #include "mlir/Transforms/Passes.h"
255c0c51a9SNicolas Vasilache #include "llvm/IR/DerivedTypes.h"
265c0c51a9SNicolas Vasilache #include "llvm/IR/Module.h"
275c0c51a9SNicolas Vasilache #include "llvm/IR/Type.h"
285c0c51a9SNicolas Vasilache #include "llvm/Support/Allocator.h"
295c0c51a9SNicolas Vasilache #include "llvm/Support/ErrorHandling.h"
305c0c51a9SNicolas Vasilache 
315c0c51a9SNicolas Vasilache using namespace mlir;
3265678d93SNicolas Vasilache using namespace mlir::vector;
335c0c51a9SNicolas Vasilache 
349826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
359826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
369826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
379826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
389826fe5cSAart Bik }
399826fe5cSAart Bik 
409826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
419826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
429826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
439826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
449826fe5cSAart Bik }
459826fe5cSAart Bik 
461c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
47e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
480f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
490f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
500f04384dSAlex Zinenko                        int64_t pos) {
511c81adf3SAart Bik   if (rank == 1) {
521c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
531c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
540f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
551c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
561c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
571c81adf3SAart Bik                                                   constant);
581c81adf3SAart Bik   }
591c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
601c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
611c81adf3SAart Bik }
621c81adf3SAart Bik 
632d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
642d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
652d515e49SNicolas Vasilache                        Value into, int64_t offset) {
662d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
672d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
682d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
692d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
702d515e49SNicolas Vasilache       loc, vectorType, from, into,
712d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
722d515e49SNicolas Vasilache }
732d515e49SNicolas Vasilache 
741c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
75e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
760f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
770f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
781c81adf3SAart Bik   if (rank == 1) {
791c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
801c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
810f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
821c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
831c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
841c81adf3SAart Bik                                                    constant);
851c81adf3SAart Bik   }
861c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
871c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
881c81adf3SAart Bik }
891c81adf3SAart Bik 
902d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
912d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
922d515e49SNicolas Vasilache                         int64_t offset) {
932d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
942d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
952d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
962d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
972d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
982d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
992d515e49SNicolas Vasilache }
1002d515e49SNicolas Vasilache 
1012d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
1029db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
1032d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
1042d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
1052d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
1062d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
1072d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
1082d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
1092d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
1102d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1112d515e49SNicolas Vasilache        it != eit; ++it)
1122d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1132d515e49SNicolas Vasilache   return res;
1142d515e49SNicolas Vasilache }
1152d515e49SNicolas Vasilache 
116060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
117060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
118060c9dd1Saartbik //
119060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
120060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
121060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
122060c9dd1Saartbik //       is very conservative on the boundary conditions.
123060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
124060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
125060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
126060c9dd1Saartbik   auto loc = op->getLoc();
127060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
128060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
129060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
130060c9dd1Saartbik   Value indices;
131060c9dd1Saartbik   Type idxType;
132060c9dd1Saartbik   if (enableIndexOptimizations) {
1330c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1340c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1350c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
136060c9dd1Saartbik     idxType = rewriter.getI32Type();
137060c9dd1Saartbik   } else {
1380c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1390c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1400c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
141060c9dd1Saartbik     idxType = rewriter.getI64Type();
142060c9dd1Saartbik   }
143060c9dd1Saartbik   // Add in an offset if requested.
144060c9dd1Saartbik   if (off) {
145060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
146060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
147060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
148060c9dd1Saartbik   }
149060c9dd1Saartbik   // Construct the vector comparison.
150060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
151060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
152060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
153060c9dd1Saartbik }
154060c9dd1Saartbik 
15519dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
15619dbb230Saartbik template <typename T>
15719dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
15819dbb230Saartbik                                  unsigned &align) {
1595f9e0466SNicolas Vasilache   Type elementTy =
16019dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1615f9e0466SNicolas Vasilache   if (!elementTy)
1625f9e0466SNicolas Vasilache     return failure();
1635f9e0466SNicolas Vasilache 
164b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
165b2ab375dSAlex Zinenko   // stop depending on translation.
16687a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
16787a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
168b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
169168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1705f9e0466SNicolas Vasilache   return success();
1715f9e0466SNicolas Vasilache }
1725f9e0466SNicolas Vasilache 
173e8dcf5f8Saartbik // Helper that returns the base address of a memref.
174b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
175e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
17619dbb230Saartbik   // Inspect stride and offset structure.
17719dbb230Saartbik   //
17819dbb230Saartbik   // TODO: flat memory only for now, generalize
17919dbb230Saartbik   //
18019dbb230Saartbik   int64_t offset;
18119dbb230Saartbik   SmallVector<int64_t, 4> strides;
18219dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
18319dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
18419dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
18519dbb230Saartbik     return failure();
186e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
187e8dcf5f8Saartbik   return success();
188e8dcf5f8Saartbik }
18919dbb230Saartbik 
190e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
191b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
192b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
193b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
194e8dcf5f8Saartbik   Value base;
195e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
196e8dcf5f8Saartbik     return failure();
1973a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
198e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
199e8dcf5f8Saartbik   return success();
200e8dcf5f8Saartbik }
201e8dcf5f8Saartbik 
20239379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
203b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
204b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
205b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
20639379916Saartbik   Value base;
20739379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
20839379916Saartbik     return failure();
20939379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
21039379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
21139379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
21239379916Saartbik   return success();
21339379916Saartbik }
21439379916Saartbik 
215e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
216e8dcf5f8Saartbik // vector.
217b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
218b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
219b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
220b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
221e8dcf5f8Saartbik   Value base;
222e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
223e8dcf5f8Saartbik     return failure();
2243a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
225e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2261485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
22719dbb230Saartbik   return success();
22819dbb230Saartbik }
22919dbb230Saartbik 
2305f9e0466SNicolas Vasilache static LogicalResult
2315f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2325f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2335f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2345f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
235affbc0cdSNicolas Vasilache   unsigned align;
23619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
237affbc0cdSNicolas Vasilache     return failure();
238affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2395f9e0466SNicolas Vasilache   return success();
2405f9e0466SNicolas Vasilache }
2415f9e0466SNicolas Vasilache 
2425f9e0466SNicolas Vasilache static LogicalResult
2435f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2445f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2455f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2465f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2475f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2485f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2495f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2505f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2515f9e0466SNicolas Vasilache 
2525f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2535f9e0466SNicolas Vasilache   if (!vecTy)
2545f9e0466SNicolas Vasilache     return failure();
2555f9e0466SNicolas Vasilache 
2565f9e0466SNicolas Vasilache   unsigned align;
25719dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2585f9e0466SNicolas Vasilache     return failure();
2595f9e0466SNicolas Vasilache 
2605f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2615f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2625f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2635f9e0466SNicolas Vasilache   return success();
2645f9e0466SNicolas Vasilache }
2655f9e0466SNicolas Vasilache 
2665f9e0466SNicolas Vasilache static LogicalResult
2675f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2685f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2695f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2705f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
271affbc0cdSNicolas Vasilache   unsigned align;
27219dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
273affbc0cdSNicolas Vasilache     return failure();
2742d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
275affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
276affbc0cdSNicolas Vasilache                                              align);
2775f9e0466SNicolas Vasilache   return success();
2785f9e0466SNicolas Vasilache }
2795f9e0466SNicolas Vasilache 
2805f9e0466SNicolas Vasilache static LogicalResult
2815f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2825f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2835f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2845f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2855f9e0466SNicolas Vasilache   unsigned align;
28619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2875f9e0466SNicolas Vasilache     return failure();
2885f9e0466SNicolas Vasilache 
2892d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2905f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2915f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2925f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2935f9e0466SNicolas Vasilache   return success();
2945f9e0466SNicolas Vasilache }
2955f9e0466SNicolas Vasilache 
2962d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2972d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2982d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2995f9e0466SNicolas Vasilache }
3005f9e0466SNicolas Vasilache 
3012d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
3022d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
3032d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
3045f9e0466SNicolas Vasilache }
3055f9e0466SNicolas Vasilache 
30690c01357SBenjamin Kramer namespace {
307e83b7b99Saartbik 
30863b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
30963b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
31063b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern {
31163b683a8SNicolas Vasilache public:
31263b683a8SNicolas Vasilache   explicit VectorMatmulOpConversion(MLIRContext *context,
31363b683a8SNicolas Vasilache                                     LLVMTypeConverter &typeConverter)
31463b683a8SNicolas Vasilache       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
31563b683a8SNicolas Vasilache                              typeConverter) {}
31663b683a8SNicolas Vasilache 
3173145427dSRiver Riddle   LogicalResult
31863b683a8SNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
31963b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
32063b683a8SNicolas Vasilache     auto matmulOp = cast<vector::MatmulOp>(op);
3212d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
32263b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
32363b683a8SNicolas Vasilache         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
32463b683a8SNicolas Vasilache         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
32563b683a8SNicolas Vasilache         matmulOp.rhs_columns());
3263145427dSRiver Riddle     return success();
32763b683a8SNicolas Vasilache   }
32863b683a8SNicolas Vasilache };
32963b683a8SNicolas Vasilache 
330c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
331c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
332c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
333c295a65dSaartbik public:
334c295a65dSaartbik   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
335c295a65dSaartbik                                            LLVMTypeConverter &typeConverter)
336c295a65dSaartbik       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
337c295a65dSaartbik                              context, typeConverter) {}
338c295a65dSaartbik 
339c295a65dSaartbik   LogicalResult
340c295a65dSaartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
341c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
342c295a65dSaartbik     auto transOp = cast<vector::FlatTransposeOp>(op);
3432d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
344c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
345c295a65dSaartbik         transOp, typeConverter.convertType(transOp.res().getType()),
346c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
347c295a65dSaartbik     return success();
348c295a65dSaartbik   }
349c295a65dSaartbik };
350c295a65dSaartbik 
35139379916Saartbik /// Conversion pattern for a vector.maskedload.
35239379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
35339379916Saartbik public:
35439379916Saartbik   explicit VectorMaskedLoadOpConversion(MLIRContext *context,
35539379916Saartbik                                         LLVMTypeConverter &typeConverter)
35639379916Saartbik       : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
35739379916Saartbik                              typeConverter) {}
35839379916Saartbik 
35939379916Saartbik   LogicalResult
36039379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
36139379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
36239379916Saartbik     auto loc = op->getLoc();
36339379916Saartbik     auto load = cast<vector::MaskedLoadOp>(op);
36439379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
36539379916Saartbik 
36639379916Saartbik     // Resolve alignment.
36739379916Saartbik     unsigned align;
36839379916Saartbik     if (failed(getMemRefAlignment(typeConverter, load, align)))
36939379916Saartbik       return failure();
37039379916Saartbik 
37139379916Saartbik     auto vtype = typeConverter.convertType(load.getResultVectorType());
37239379916Saartbik     Value ptr;
37339379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
37439379916Saartbik                           vtype, ptr)))
37539379916Saartbik       return failure();
37639379916Saartbik 
37739379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
37839379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
37939379916Saartbik         rewriter.getI32IntegerAttr(align));
38039379916Saartbik     return success();
38139379916Saartbik   }
38239379916Saartbik };
38339379916Saartbik 
38439379916Saartbik /// Conversion pattern for a vector.maskedstore.
38539379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
38639379916Saartbik public:
38739379916Saartbik   explicit VectorMaskedStoreOpConversion(MLIRContext *context,
38839379916Saartbik                                          LLVMTypeConverter &typeConverter)
38939379916Saartbik       : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
39039379916Saartbik                              typeConverter) {}
39139379916Saartbik 
39239379916Saartbik   LogicalResult
39339379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
39439379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
39539379916Saartbik     auto loc = op->getLoc();
39639379916Saartbik     auto store = cast<vector::MaskedStoreOp>(op);
39739379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
39839379916Saartbik 
39939379916Saartbik     // Resolve alignment.
40039379916Saartbik     unsigned align;
40139379916Saartbik     if (failed(getMemRefAlignment(typeConverter, store, align)))
40239379916Saartbik       return failure();
40339379916Saartbik 
40439379916Saartbik     auto vtype = typeConverter.convertType(store.getValueVectorType());
40539379916Saartbik     Value ptr;
40639379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
40739379916Saartbik                           vtype, ptr)))
40839379916Saartbik       return failure();
40939379916Saartbik 
41039379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
41139379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
41239379916Saartbik         rewriter.getI32IntegerAttr(align));
41339379916Saartbik     return success();
41439379916Saartbik   }
41539379916Saartbik };
41639379916Saartbik 
41719dbb230Saartbik /// Conversion pattern for a vector.gather.
41819dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern {
41919dbb230Saartbik public:
42019dbb230Saartbik   explicit VectorGatherOpConversion(MLIRContext *context,
42119dbb230Saartbik                                     LLVMTypeConverter &typeConverter)
42219dbb230Saartbik       : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
42319dbb230Saartbik                              typeConverter) {}
42419dbb230Saartbik 
42519dbb230Saartbik   LogicalResult
42619dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
42719dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
42819dbb230Saartbik     auto loc = op->getLoc();
42919dbb230Saartbik     auto gather = cast<vector::GatherOp>(op);
43019dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
43119dbb230Saartbik 
43219dbb230Saartbik     // Resolve alignment.
43319dbb230Saartbik     unsigned align;
43419dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, gather, align)))
43519dbb230Saartbik       return failure();
43619dbb230Saartbik 
43719dbb230Saartbik     // Get index ptrs.
43819dbb230Saartbik     VectorType vType = gather.getResultVectorType();
43919dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
44019dbb230Saartbik     Value ptrs;
441e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
442e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
44319dbb230Saartbik       return failure();
44419dbb230Saartbik 
44519dbb230Saartbik     // Replace with the gather intrinsic.
44619dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
4470c2a4d3cSBenjamin Kramer         gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
4480c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
44919dbb230Saartbik     return success();
45019dbb230Saartbik   }
45119dbb230Saartbik };
45219dbb230Saartbik 
45319dbb230Saartbik /// Conversion pattern for a vector.scatter.
45419dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern {
45519dbb230Saartbik public:
45619dbb230Saartbik   explicit VectorScatterOpConversion(MLIRContext *context,
45719dbb230Saartbik                                      LLVMTypeConverter &typeConverter)
45819dbb230Saartbik       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
45919dbb230Saartbik                              typeConverter) {}
46019dbb230Saartbik 
46119dbb230Saartbik   LogicalResult
46219dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
46319dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
46419dbb230Saartbik     auto loc = op->getLoc();
46519dbb230Saartbik     auto scatter = cast<vector::ScatterOp>(op);
46619dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
46719dbb230Saartbik 
46819dbb230Saartbik     // Resolve alignment.
46919dbb230Saartbik     unsigned align;
47019dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, scatter, align)))
47119dbb230Saartbik       return failure();
47219dbb230Saartbik 
47319dbb230Saartbik     // Get index ptrs.
47419dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
47519dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
47619dbb230Saartbik     Value ptrs;
477e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
478e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
47919dbb230Saartbik       return failure();
48019dbb230Saartbik 
48119dbb230Saartbik     // Replace with the scatter intrinsic.
48219dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
48319dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
48419dbb230Saartbik         rewriter.getI32IntegerAttr(align));
48519dbb230Saartbik     return success();
48619dbb230Saartbik   }
48719dbb230Saartbik };
48819dbb230Saartbik 
489e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
490e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
491e8dcf5f8Saartbik public:
492e8dcf5f8Saartbik   explicit VectorExpandLoadOpConversion(MLIRContext *context,
493e8dcf5f8Saartbik                                         LLVMTypeConverter &typeConverter)
494e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
495e8dcf5f8Saartbik                              typeConverter) {}
496e8dcf5f8Saartbik 
497e8dcf5f8Saartbik   LogicalResult
498e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
499e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
500e8dcf5f8Saartbik     auto loc = op->getLoc();
501e8dcf5f8Saartbik     auto expand = cast<vector::ExpandLoadOp>(op);
502e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
503e8dcf5f8Saartbik 
504e8dcf5f8Saartbik     Value ptr;
505e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
506e8dcf5f8Saartbik                           ptr)))
507e8dcf5f8Saartbik       return failure();
508e8dcf5f8Saartbik 
509e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
510e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
511e8dcf5f8Saartbik         op, typeConverter.convertType(vType), ptr, adaptor.mask(),
512e8dcf5f8Saartbik         adaptor.pass_thru());
513e8dcf5f8Saartbik     return success();
514e8dcf5f8Saartbik   }
515e8dcf5f8Saartbik };
516e8dcf5f8Saartbik 
517e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
518e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
519e8dcf5f8Saartbik public:
520e8dcf5f8Saartbik   explicit VectorCompressStoreOpConversion(MLIRContext *context,
521e8dcf5f8Saartbik                                            LLVMTypeConverter &typeConverter)
522e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
523e8dcf5f8Saartbik                              context, typeConverter) {}
524e8dcf5f8Saartbik 
525e8dcf5f8Saartbik   LogicalResult
526e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
527e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
528e8dcf5f8Saartbik     auto loc = op->getLoc();
529e8dcf5f8Saartbik     auto compress = cast<vector::CompressStoreOp>(op);
530e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
531e8dcf5f8Saartbik 
532e8dcf5f8Saartbik     Value ptr;
533e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
534e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
535e8dcf5f8Saartbik       return failure();
536e8dcf5f8Saartbik 
537e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
538e8dcf5f8Saartbik         op, adaptor.value(), ptr, adaptor.mask());
539e8dcf5f8Saartbik     return success();
540e8dcf5f8Saartbik   }
541e8dcf5f8Saartbik };
542e8dcf5f8Saartbik 
54319dbb230Saartbik /// Conversion pattern for all vector reductions.
544870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern {
545e83b7b99Saartbik public:
546e83b7b99Saartbik   explicit VectorReductionOpConversion(MLIRContext *context,
547ceb1b327Saartbik                                        LLVMTypeConverter &typeConverter,
548060c9dd1Saartbik                                        bool reassociateFPRed)
549870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
550ceb1b327Saartbik                              typeConverter),
551060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
552e83b7b99Saartbik 
5533145427dSRiver Riddle   LogicalResult
554e83b7b99Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
555e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
556e83b7b99Saartbik     auto reductionOp = cast<vector::ReductionOp>(op);
557e83b7b99Saartbik     auto kind = reductionOp.kind();
558e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
5590f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(eltType);
560e9628955SAart Bik     if (eltType.isIntOrIndex()) {
561e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
562e83b7b99Saartbik       if (kind == "add")
563322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
564e83b7b99Saartbik             op, llvmType, operands[0]);
565e83b7b99Saartbik       else if (kind == "mul")
566322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
567e83b7b99Saartbik             op, llvmType, operands[0]);
568e9628955SAart Bik       else if (kind == "min" &&
569e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
570322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
571e9628955SAart Bik             op, llvmType, operands[0]);
572e83b7b99Saartbik       else if (kind == "min")
573322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
574e83b7b99Saartbik             op, llvmType, operands[0]);
575e9628955SAart Bik       else if (kind == "max" &&
576e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
577322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
578e9628955SAart Bik             op, llvmType, operands[0]);
579e83b7b99Saartbik       else if (kind == "max")
580322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
581e83b7b99Saartbik             op, llvmType, operands[0]);
582e83b7b99Saartbik       else if (kind == "and")
583322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
584e83b7b99Saartbik             op, llvmType, operands[0]);
585e83b7b99Saartbik       else if (kind == "or")
586322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
587e83b7b99Saartbik             op, llvmType, operands[0]);
588e83b7b99Saartbik       else if (kind == "xor")
589322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
590e83b7b99Saartbik             op, llvmType, operands[0]);
591e83b7b99Saartbik       else
5923145427dSRiver Riddle         return failure();
5933145427dSRiver Riddle       return success();
594e83b7b99Saartbik 
5952d76274bSBenjamin Kramer     } else if (eltType.isa<FloatType>()) {
596e83b7b99Saartbik       // Floating-point reductions: add/mul/min/max
597e83b7b99Saartbik       if (kind == "add") {
5980d924700Saartbik         // Optional accumulator (or zero).
5990d924700Saartbik         Value acc = operands.size() > 1 ? operands[1]
6000d924700Saartbik                                         : rewriter.create<LLVM::ConstantOp>(
6010d924700Saartbik                                               op->getLoc(), llvmType,
6020d924700Saartbik                                               rewriter.getZeroAttr(eltType));
603322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
604ceb1b327Saartbik             op, llvmType, acc, operands[0],
605ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
606e83b7b99Saartbik       } else if (kind == "mul") {
6070d924700Saartbik         // Optional accumulator (or one).
6080d924700Saartbik         Value acc = operands.size() > 1
6090d924700Saartbik                         ? operands[1]
6100d924700Saartbik                         : rewriter.create<LLVM::ConstantOp>(
6110d924700Saartbik                               op->getLoc(), llvmType,
6120d924700Saartbik                               rewriter.getFloatAttr(eltType, 1.0));
613322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
614ceb1b327Saartbik             op, llvmType, acc, operands[0],
615ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
616e83b7b99Saartbik       } else if (kind == "min")
617322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
618e83b7b99Saartbik             op, llvmType, operands[0]);
619e83b7b99Saartbik       else if (kind == "max")
620322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
621e83b7b99Saartbik             op, llvmType, operands[0]);
622e83b7b99Saartbik       else
6233145427dSRiver Riddle         return failure();
6243145427dSRiver Riddle       return success();
625e83b7b99Saartbik     }
6263145427dSRiver Riddle     return failure();
627e83b7b99Saartbik   }
628ceb1b327Saartbik 
629ceb1b327Saartbik private:
630ceb1b327Saartbik   const bool reassociateFPReductions;
631e83b7b99Saartbik };
632e83b7b99Saartbik 
633060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
634060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
635060c9dd1Saartbik public:
636060c9dd1Saartbik   explicit VectorCreateMaskOpConversion(MLIRContext *context,
637060c9dd1Saartbik                                         LLVMTypeConverter &typeConverter,
638060c9dd1Saartbik                                         bool enableIndexOpt)
639060c9dd1Saartbik       : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
640060c9dd1Saartbik                              typeConverter),
641060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
642060c9dd1Saartbik 
643060c9dd1Saartbik   LogicalResult
644060c9dd1Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
645060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
646060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
647060c9dd1Saartbik     int64_t rank = dstType.getRank();
648060c9dd1Saartbik     if (rank == 1) {
649060c9dd1Saartbik       rewriter.replaceOp(
650060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
651060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
652060c9dd1Saartbik       return success();
653060c9dd1Saartbik     }
654060c9dd1Saartbik     return failure();
655060c9dd1Saartbik   }
656060c9dd1Saartbik 
657060c9dd1Saartbik private:
658060c9dd1Saartbik   const bool enableIndexOptimizations;
659060c9dd1Saartbik };
660060c9dd1Saartbik 
661870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern {
6621c81adf3SAart Bik public:
6631c81adf3SAart Bik   explicit VectorShuffleOpConversion(MLIRContext *context,
6641c81adf3SAart Bik                                      LLVMTypeConverter &typeConverter)
665870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
6661c81adf3SAart Bik                              typeConverter) {}
6671c81adf3SAart Bik 
6683145427dSRiver Riddle   LogicalResult
669e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6701c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6711c81adf3SAart Bik     auto loc = op->getLoc();
6722d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6731c81adf3SAart Bik     auto shuffleOp = cast<vector::ShuffleOp>(op);
6741c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6751c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6761c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
6770f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(vectorType);
6781c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6791c81adf3SAart Bik 
6801c81adf3SAart Bik     // Bail if result type cannot be lowered.
6811c81adf3SAart Bik     if (!llvmType)
6823145427dSRiver Riddle       return failure();
6831c81adf3SAart Bik 
6841c81adf3SAart Bik     // Get rank and dimension sizes.
6851c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6861c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6871c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6881c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6891c81adf3SAart Bik 
6901c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6911c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6921c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
693e62a6956SRiver Riddle       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
6941c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
6951c81adf3SAart Bik       rewriter.replaceOp(op, shuffle);
6963145427dSRiver Riddle       return success();
697b36aaeafSAart Bik     }
698b36aaeafSAart Bik 
6991c81adf3SAart Bik     // For all other cases, insert the individual values individually.
700e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
7011c81adf3SAart Bik     int64_t insPos = 0;
7021c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
7031c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
704e62a6956SRiver Riddle       Value value = adaptor.v1();
7051c81adf3SAart Bik       if (extPos >= v1Dim) {
7061c81adf3SAart Bik         extPos -= v1Dim;
7071c81adf3SAart Bik         value = adaptor.v2();
708b36aaeafSAart Bik       }
7090f04384dSAlex Zinenko       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
7100f04384dSAlex Zinenko                                  rank, extPos);
7110f04384dSAlex Zinenko       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
7120f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7131c81adf3SAart Bik     }
7141c81adf3SAart Bik     rewriter.replaceOp(op, insert);
7153145427dSRiver Riddle     return success();
716b36aaeafSAart Bik   }
717b36aaeafSAart Bik };
718b36aaeafSAart Bik 
719870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
720cd5dab8aSAart Bik public:
721cd5dab8aSAart Bik   explicit VectorExtractElementOpConversion(MLIRContext *context,
722cd5dab8aSAart Bik                                             LLVMTypeConverter &typeConverter)
723870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
724870c1fd4SAlex Zinenko                              context, typeConverter) {}
725cd5dab8aSAart Bik 
7263145427dSRiver Riddle   LogicalResult
727e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
728cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7292d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
730cd5dab8aSAart Bik     auto extractEltOp = cast<vector::ExtractElementOp>(op);
731cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
7320f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType.getElementType());
733cd5dab8aSAart Bik 
734cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
735cd5dab8aSAart Bik     if (!llvmType)
7363145427dSRiver Riddle       return failure();
737cd5dab8aSAart Bik 
738cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
739cd5dab8aSAart Bik         op, llvmType, adaptor.vector(), adaptor.position());
7403145427dSRiver Riddle     return success();
741cd5dab8aSAart Bik   }
742cd5dab8aSAart Bik };
743cd5dab8aSAart Bik 
744870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern {
7455c0c51a9SNicolas Vasilache public:
7469826fe5cSAart Bik   explicit VectorExtractOpConversion(MLIRContext *context,
7475c0c51a9SNicolas Vasilache                                      LLVMTypeConverter &typeConverter)
748870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
7495c0c51a9SNicolas Vasilache                              typeConverter) {}
7505c0c51a9SNicolas Vasilache 
7513145427dSRiver Riddle   LogicalResult
752e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7535c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7545c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
7552d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
756d37f2725SAart Bik     auto extractOp = cast<vector::ExtractOp>(op);
7579826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7582bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
7590f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(resultType);
7605c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7619826fe5cSAart Bik 
7629826fe5cSAart Bik     // Bail if result type cannot be lowered.
7639826fe5cSAart Bik     if (!llvmResultType)
7643145427dSRiver Riddle       return failure();
7659826fe5cSAart Bik 
7665c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7675c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
768e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7695c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
7705c0c51a9SNicolas Vasilache       rewriter.replaceOp(op, extracted);
7713145427dSRiver Riddle       return success();
7725c0c51a9SNicolas Vasilache     }
7735c0c51a9SNicolas Vasilache 
7749826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
7755c0c51a9SNicolas Vasilache     auto *context = op->getContext();
776e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7775c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7785c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7799826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7805c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7815c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7825c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
7830f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
7845c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7855c0c51a9SNicolas Vasilache     }
7865c0c51a9SNicolas Vasilache 
7875c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7885c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7895446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7901d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7915c0c51a9SNicolas Vasilache     extracted =
7925c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
7935c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, extracted);
7945c0c51a9SNicolas Vasilache 
7953145427dSRiver Riddle     return success();
7965c0c51a9SNicolas Vasilache   }
7975c0c51a9SNicolas Vasilache };
7985c0c51a9SNicolas Vasilache 
799681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
800681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
801681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
802681f929fSNicolas Vasilache ///
803681f929fSNicolas Vasilache /// Example:
804681f929fSNicolas Vasilache /// ```
805681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
806681f929fSNicolas Vasilache /// ```
807681f929fSNicolas Vasilache /// is converted to:
808681f929fSNicolas Vasilache /// ```
8093bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
810681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
811681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
812681f929fSNicolas Vasilache /// ```
813870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
814681f929fSNicolas Vasilache public:
815681f929fSNicolas Vasilache   explicit VectorFMAOp1DConversion(MLIRContext *context,
816681f929fSNicolas Vasilache                                    LLVMTypeConverter &typeConverter)
817870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
818681f929fSNicolas Vasilache                              typeConverter) {}
819681f929fSNicolas Vasilache 
8203145427dSRiver Riddle   LogicalResult
821681f929fSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
822681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8232d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
824681f929fSNicolas Vasilache     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
825681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
826681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8273145427dSRiver Riddle       return failure();
8283bffe602SBenjamin Kramer     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
8293bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8303145427dSRiver Riddle     return success();
831681f929fSNicolas Vasilache   }
832681f929fSNicolas Vasilache };
833681f929fSNicolas Vasilache 
834870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
835cd5dab8aSAart Bik public:
836cd5dab8aSAart Bik   explicit VectorInsertElementOpConversion(MLIRContext *context,
837cd5dab8aSAart Bik                                            LLVMTypeConverter &typeConverter)
838870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
839870c1fd4SAlex Zinenko                              context, typeConverter) {}
840cd5dab8aSAart Bik 
8413145427dSRiver Riddle   LogicalResult
842e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
843cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8442d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
845cd5dab8aSAart Bik     auto insertEltOp = cast<vector::InsertElementOp>(op);
846cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
8470f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType);
848cd5dab8aSAart Bik 
849cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
850cd5dab8aSAart Bik     if (!llvmType)
8513145427dSRiver Riddle       return failure();
852cd5dab8aSAart Bik 
853cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
854cd5dab8aSAart Bik         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
8553145427dSRiver Riddle     return success();
856cd5dab8aSAart Bik   }
857cd5dab8aSAart Bik };
858cd5dab8aSAart Bik 
859870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern {
8609826fe5cSAart Bik public:
8619826fe5cSAart Bik   explicit VectorInsertOpConversion(MLIRContext *context,
8629826fe5cSAart Bik                                     LLVMTypeConverter &typeConverter)
863870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
8649826fe5cSAart Bik                              typeConverter) {}
8659826fe5cSAart Bik 
8663145427dSRiver Riddle   LogicalResult
867e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
8689826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8699826fe5cSAart Bik     auto loc = op->getLoc();
8702d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8719826fe5cSAart Bik     auto insertOp = cast<vector::InsertOp>(op);
8729826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8739826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
8740f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(destVectorType);
8759826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8769826fe5cSAart Bik 
8779826fe5cSAart Bik     // Bail if result type cannot be lowered.
8789826fe5cSAart Bik     if (!llvmResultType)
8793145427dSRiver Riddle       return failure();
8809826fe5cSAart Bik 
8819826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8829826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
883e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8849826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8859826fe5cSAart Bik           positionArrayAttr);
8869826fe5cSAart Bik       rewriter.replaceOp(op, inserted);
8873145427dSRiver Riddle       return success();
8889826fe5cSAart Bik     }
8899826fe5cSAart Bik 
8909826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
8919826fe5cSAart Bik     auto *context = op->getContext();
892e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8939826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8949826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8959826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8969826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8979826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8989826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8999826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9009826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
9010f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
9029826fe5cSAart Bik           nMinusOnePositionAttrs);
9039826fe5cSAart Bik     }
9049826fe5cSAart Bik 
9059826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
9065446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
9071d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
908e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
9090f04384dSAlex Zinenko         loc, typeConverter.convertType(oneDVectorType), extracted,
9100f04384dSAlex Zinenko         adaptor.source(), constant);
9119826fe5cSAart Bik 
9129826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
9139826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9149826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9159826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9169826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
9179826fe5cSAart Bik                                                       adaptor.dest(), inserted,
9189826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
9199826fe5cSAart Bik     }
9209826fe5cSAart Bik 
9219826fe5cSAart Bik     rewriter.replaceOp(op, inserted);
9223145427dSRiver Riddle     return success();
9239826fe5cSAart Bik   }
9249826fe5cSAart Bik };
9259826fe5cSAart Bik 
926681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
927681f929fSNicolas Vasilache ///
928681f929fSNicolas Vasilache /// Example:
929681f929fSNicolas Vasilache /// ```
930681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
931681f929fSNicolas Vasilache /// ```
932681f929fSNicolas Vasilache /// is rewritten into:
933681f929fSNicolas Vasilache /// ```
934681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
935681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
936681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
937681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
938681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
939681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
940681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
941681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
942681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
943681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
944681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
945681f929fSNicolas Vasilache ///  // %r3 holds the final value.
946681f929fSNicolas Vasilache /// ```
947681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
948681f929fSNicolas Vasilache public:
949681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
950681f929fSNicolas Vasilache 
9513145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
952681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
953681f929fSNicolas Vasilache     auto vType = op.getVectorType();
954681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9553145427dSRiver Riddle       return failure();
956681f929fSNicolas Vasilache 
957681f929fSNicolas Vasilache     auto loc = op.getLoc();
958681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
959681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
960681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
961681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
962681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
963681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
964681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
965681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
966681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
967681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
968681f929fSNicolas Vasilache     }
969681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9703145427dSRiver Riddle     return success();
971681f929fSNicolas Vasilache   }
972681f929fSNicolas Vasilache };
973681f929fSNicolas Vasilache 
9742d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9752d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9762d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9772d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9782d515e49SNicolas Vasilache // rank.
9792d515e49SNicolas Vasilache //
9802d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9812d515e49SNicolas Vasilache // have different ranks. In this case:
9822d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9832d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9842d515e49SNicolas Vasilache //   destination subvector
9852d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9862d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9872d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9882d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9892d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9902d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9912d515e49SNicolas Vasilache public:
9922d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9932d515e49SNicolas Vasilache 
9943145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9952d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9962d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9972d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9982d515e49SNicolas Vasilache 
9992d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10003145427dSRiver Riddle       return failure();
10012d515e49SNicolas Vasilache 
10022d515e49SNicolas Vasilache     auto loc = op.getLoc();
10032d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10042d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10052d515e49SNicolas Vasilache     if (rankDiff == 0)
10063145427dSRiver Riddle       return failure();
10072d515e49SNicolas Vasilache 
10082d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
10092d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
10102d515e49SNicolas Vasilache     // on it.
10112d515e49SNicolas Vasilache     Value extracted =
10122d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
10132d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
10142d515e49SNicolas Vasilache                                                   /*dropFront=*/rankRest));
10152d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
10162d515e49SNicolas Vasilache     // ranks.
10172d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
10182d515e49SNicolas Vasilache         loc, op.source(), extracted,
10192d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
1020c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
10212d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
10222d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
10232d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
10242d515e49SNicolas Vasilache                        /*dropFront=*/rankRest));
10253145427dSRiver Riddle     return success();
10262d515e49SNicolas Vasilache   }
10272d515e49SNicolas Vasilache };
10282d515e49SNicolas Vasilache 
10292d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10302d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10312d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10322d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10332d515e49SNicolas Vasilache //   destination subvector
10342d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10352d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10362d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10372d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10382d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10392d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10402d515e49SNicolas Vasilache public:
1041b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1042b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1043b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1044b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1045b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1046b99bd771SRiver Riddle   }
10472d515e49SNicolas Vasilache 
10483145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10492d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10502d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10512d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10522d515e49SNicolas Vasilache 
10532d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10543145427dSRiver Riddle       return failure();
10552d515e49SNicolas Vasilache 
10562d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10572d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10582d515e49SNicolas Vasilache     if (rankDiff != 0)
10593145427dSRiver Riddle       return failure();
10602d515e49SNicolas Vasilache 
10612d515e49SNicolas Vasilache     if (srcType == dstType) {
10622d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10633145427dSRiver Riddle       return success();
10642d515e49SNicolas Vasilache     }
10652d515e49SNicolas Vasilache 
10662d515e49SNicolas Vasilache     int64_t offset =
10672d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10682d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10692d515e49SNicolas Vasilache     int64_t stride =
10702d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10712d515e49SNicolas Vasilache 
10722d515e49SNicolas Vasilache     auto loc = op.getLoc();
10732d515e49SNicolas Vasilache     Value res = op.dest();
10742d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10752d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10762d515e49SNicolas Vasilache          off += stride, ++idx) {
10772d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10782d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10792d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10802d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10812d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10822d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10832d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10842d515e49SNicolas Vasilache         // smaller rank.
1085bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10862d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10872d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10882d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10892d515e49SNicolas Vasilache       }
10902d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10912d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10922d515e49SNicolas Vasilache     }
10932d515e49SNicolas Vasilache 
10942d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10953145427dSRiver Riddle     return success();
10962d515e49SNicolas Vasilache   }
10972d515e49SNicolas Vasilache };
10982d515e49SNicolas Vasilache 
109930e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
110030e6033bSNicolas Vasilache /// static layout.
110130e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
110230e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
11032bf491c7SBenjamin Kramer   int64_t offset;
110430e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
110530e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
110630e6033bSNicolas Vasilache     return None;
110730e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
110830e6033bSNicolas Vasilache     return None;
110930e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
111030e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
111130e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
111230e6033bSNicolas Vasilache     return strides;
111330e6033bSNicolas Vasilache 
111430e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
111530e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
111630e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
111730e6033bSNicolas Vasilache   // layout.
11182bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
11192bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
112030e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
112130e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
112230e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
112330e6033bSNicolas Vasilache       return None;
112430e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
112530e6033bSNicolas Vasilache       return None;
11262bf491c7SBenjamin Kramer   }
112730e6033bSNicolas Vasilache   return strides;
11282bf491c7SBenjamin Kramer }
11292bf491c7SBenjamin Kramer 
1130870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
11315c0c51a9SNicolas Vasilache public:
11325c0c51a9SNicolas Vasilache   explicit VectorTypeCastOpConversion(MLIRContext *context,
11335c0c51a9SNicolas Vasilache                                       LLVMTypeConverter &typeConverter)
1134870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
11355c0c51a9SNicolas Vasilache                              typeConverter) {}
11365c0c51a9SNicolas Vasilache 
11373145427dSRiver Riddle   LogicalResult
1138e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11395c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11405c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
11415c0c51a9SNicolas Vasilache     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
11425c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11432bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11445c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
11452bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
11465c0c51a9SNicolas Vasilache 
11475c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11485c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11495c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11503145427dSRiver Riddle       return failure();
11515c0c51a9SNicolas Vasilache 
11525c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11532bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
11545c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
11553145427dSRiver Riddle       return failure();
11565c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11575c0c51a9SNicolas Vasilache 
11580f04384dSAlex Zinenko     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
11595c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11605c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11613145427dSRiver Riddle       return failure();
11625c0c51a9SNicolas Vasilache 
116330e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
116430e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
116530e6033bSNicolas Vasilache     if (!sourceStrides)
116630e6033bSNicolas Vasilache       return failure();
116730e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
116830e6033bSNicolas Vasilache     if (!targetStrides)
116930e6033bSNicolas Vasilache       return failure();
117030e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
117130e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
117230e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
117330e6033bSNicolas Vasilache         }))
11743145427dSRiver Riddle       return failure();
11755c0c51a9SNicolas Vasilache 
11765446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11775c0c51a9SNicolas Vasilache 
11785c0c51a9SNicolas Vasilache     // Create descriptor.
11795c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11803a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11815c0c51a9SNicolas Vasilache     // Set allocated ptr.
1182e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11835c0c51a9SNicolas Vasilache     allocated =
11845c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11855c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11865c0c51a9SNicolas Vasilache     // Set aligned ptr.
1187e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11885c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11895c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11905c0c51a9SNicolas Vasilache     // Fill offset 0.
11915c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11925c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11935c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11945c0c51a9SNicolas Vasilache 
11955c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11965c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11975c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11985c0c51a9SNicolas Vasilache       auto sizeAttr =
11995c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
12005c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
12015c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
120230e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
120330e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
12045c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
12055c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
12065c0c51a9SNicolas Vasilache     }
12075c0c51a9SNicolas Vasilache 
12085c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, {desc});
12093145427dSRiver Riddle     return success();
12105c0c51a9SNicolas Vasilache   }
12115c0c51a9SNicolas Vasilache };
12125c0c51a9SNicolas Vasilache 
12138345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
12148345b86dSNicolas Vasilache /// sequence of:
1215060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1216060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1217060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1218060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1219060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
12208345b86dSNicolas Vasilache template <typename ConcreteOp>
12218345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern {
12228345b86dSNicolas Vasilache public:
12238345b86dSNicolas Vasilache   explicit VectorTransferConversion(MLIRContext *context,
1224060c9dd1Saartbik                                     LLVMTypeConverter &typeConv,
1225060c9dd1Saartbik                                     bool enableIndexOpt)
1226060c9dd1Saartbik       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1227060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
12288345b86dSNicolas Vasilache 
12298345b86dSNicolas Vasilache   LogicalResult
12308345b86dSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
12318345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12328345b86dSNicolas Vasilache     auto xferOp = cast<ConcreteOp>(op);
12338345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1234b2c79c50SNicolas Vasilache 
1235b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1236b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12378345b86dSNicolas Vasilache       return failure();
12385f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12395f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12405f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
12415f9e0466SNicolas Vasilache                                        op->getContext()))
12428345b86dSNicolas Vasilache       return failure();
12432bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
124430e6033bSNicolas Vasilache     auto strides = computeContiguousStrides(xferOp.getMemRefType());
124530e6033bSNicolas Vasilache     if (!strides)
12462bf491c7SBenjamin Kramer       return failure();
12478345b86dSNicolas Vasilache 
12488345b86dSNicolas Vasilache     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
12498345b86dSNicolas Vasilache 
12508345b86dSNicolas Vasilache     Location loc = op->getLoc();
12518345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
12528345b86dSNicolas Vasilache 
125368330ee0SThomas Raoux     if (auto memrefVectorElementType =
125468330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
125568330ee0SThomas Raoux       // Memref has vector element type.
125668330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
125768330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
125868330ee0SThomas Raoux         return failure();
12590de60b55SThomas Raoux #ifndef NDEBUG
126068330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
126168330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
126268330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
126368330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
126468330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
126568330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
126668330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
126768330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
126868330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
126968330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
127068330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
127168330ee0SThomas Raoux                "result shape.");
12720de60b55SThomas Raoux #endif // ifndef NDEBUG
127368330ee0SThomas Raoux     }
127468330ee0SThomas Raoux 
12758345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1276be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1277be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1278be16075bSWen-Heng (Jack) Chung     //    address space 0.
12798345b86dSNicolas Vasilache     // TODO: support alignment when possible.
1280*8b97e17dSChristian Sigg     Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
1281d3a98076SAlex Zinenko                                          adaptor.indices(), rewriter);
12828345b86dSNicolas Vasilache     auto vecTy =
12838345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1284be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1285be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1286be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12878345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1288be16075bSWen-Heng (Jack) Chung     else
1289be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1290be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12918345b86dSNicolas Vasilache 
12921870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
12931870e787SNicolas Vasilache       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
12941870e787SNicolas Vasilache                                               xferOp, operands, vectorDataPtr);
12951870e787SNicolas Vasilache 
12968345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12978345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12988345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12998345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1300060c9dd1Saartbik     //
1301060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1302060c9dd1Saartbik     //       dimensions here.
1303060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1304060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
13050c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
1306b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1307060c9dd1Saartbik     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1308060c9dd1Saartbik                                        vecWidth, dim, &off);
13098345b86dSNicolas Vasilache 
13108345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
13111870e787SNicolas Vasilache     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1312a99f62c4SAlex Zinenko                                        operands, vectorDataPtr, mask);
13138345b86dSNicolas Vasilache   }
1314060c9dd1Saartbik 
1315060c9dd1Saartbik private:
1316060c9dd1Saartbik   const bool enableIndexOptimizations;
13178345b86dSNicolas Vasilache };
13188345b86dSNicolas Vasilache 
1319870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern {
1320d9b500d3SAart Bik public:
1321d9b500d3SAart Bik   explicit VectorPrintOpConversion(MLIRContext *context,
1322d9b500d3SAart Bik                                    LLVMTypeConverter &typeConverter)
1323870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1324d9b500d3SAart Bik                              typeConverter) {}
1325d9b500d3SAart Bik 
1326d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1327d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1328d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1329d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1330d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1331d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1332d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1333d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1334d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1335d9b500d3SAart Bik   //
13369db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1337d9b500d3SAart Bik   //
13383145427dSRiver Riddle   LogicalResult
1339e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1340d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1341d9b500d3SAart Bik     auto printOp = cast<vector::PrintOp>(op);
13422d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1343d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1344d9b500d3SAart Bik 
13450f04384dSAlex Zinenko     if (typeConverter.convertType(printType) == nullptr)
13463145427dSRiver Riddle       return failure();
1347d9b500d3SAart Bik 
1348b8880f5fSAart Bik     // Make sure element type has runtime support.
1349b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1350d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1351d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1352d9b500d3SAart Bik     Operation *printer;
1353b8880f5fSAart Bik     if (eltType.isF32()) {
1354d9b500d3SAart Bik       printer = getPrintFloat(op);
1355b8880f5fSAart Bik     } else if (eltType.isF64()) {
1356d9b500d3SAart Bik       printer = getPrintDouble(op);
135754759cefSAart Bik     } else if (eltType.isIndex()) {
135854759cefSAart Bik       printer = getPrintU64(op);
1359b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1360b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1361b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1362b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1363b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1364b8880f5fSAart Bik       if (intTy.isUnsigned()) {
136554759cefSAart Bik         if (width <= 64) {
1366b8880f5fSAart Bik           if (width < 64)
1367b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1368b8880f5fSAart Bik           printer = getPrintU64(op);
1369b8880f5fSAart Bik         } else {
13703145427dSRiver Riddle           return failure();
1371b8880f5fSAart Bik         }
1372b8880f5fSAart Bik       } else {
1373b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
137454759cefSAart Bik         if (width <= 64) {
1375b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1376b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1377b8880f5fSAart Bik           if (width == 1)
137854759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
137954759cefSAart Bik           else if (width < 64)
1380b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1381b8880f5fSAart Bik           printer = getPrintI64(op);
1382b8880f5fSAart Bik         } else {
1383b8880f5fSAart Bik           return failure();
1384b8880f5fSAart Bik         }
1385b8880f5fSAart Bik       }
1386b8880f5fSAart Bik     } else {
1387b8880f5fSAart Bik       return failure();
1388b8880f5fSAart Bik     }
1389d9b500d3SAart Bik 
1390d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1391b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1392b8880f5fSAart Bik     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1393b8880f5fSAart Bik               conversion);
1394d9b500d3SAart Bik     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1395d9b500d3SAart Bik     rewriter.eraseOp(op);
13963145427dSRiver Riddle     return success();
1397d9b500d3SAart Bik   }
1398d9b500d3SAart Bik 
1399d9b500d3SAart Bik private:
1400b8880f5fSAart Bik   enum class PrintConversion {
140130e6033bSNicolas Vasilache     // clang-format off
1402b8880f5fSAart Bik     None,
1403b8880f5fSAart Bik     ZeroExt64,
1404b8880f5fSAart Bik     SignExt64
140530e6033bSNicolas Vasilache     // clang-format on
1406b8880f5fSAart Bik   };
1407b8880f5fSAart Bik 
1408d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1409e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1410b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1411d9b500d3SAart Bik     Location loc = op->getLoc();
1412d9b500d3SAart Bik     if (rank == 0) {
1413b8880f5fSAart Bik       switch (conversion) {
1414b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1415b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1416b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1417b8880f5fSAart Bik         break;
1418b8880f5fSAart Bik       case PrintConversion::SignExt64:
1419b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1420b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1421b8880f5fSAart Bik         break;
1422b8880f5fSAart Bik       case PrintConversion::None:
1423b8880f5fSAart Bik         break;
1424c9eeeb38Saartbik       }
1425d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1426d9b500d3SAart Bik       return;
1427d9b500d3SAart Bik     }
1428d9b500d3SAart Bik 
1429d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1430d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1431d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1432d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1433d9b500d3SAart Bik       auto reducedType =
1434d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
14350f04384dSAlex Zinenko       auto llvmType = typeConverter.convertType(
1436d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1437e62a6956SRiver Riddle       Value nestedVal =
14380f04384dSAlex Zinenko           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1439b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1440b8880f5fSAart Bik                 conversion);
1441d9b500d3SAart Bik       if (d != dim - 1)
1442d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1443d9b500d3SAart Bik     }
1444d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1445d9b500d3SAart Bik   }
1446d9b500d3SAart Bik 
1447d9b500d3SAart Bik   // Helper to emit a call.
1448d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1449d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
145008e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1451d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1452d9b500d3SAart Bik   }
1453d9b500d3SAart Bik 
1454d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14555446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
14565446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1457d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1458d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1459d9b500d3SAart Bik     if (func)
1460d9b500d3SAart Bik       return func;
1461d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1462d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1463d9b500d3SAart Bik         op->getLoc(), name,
14645446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14655446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14665446ec85SAlex Zinenko             /*isVarArg=*/false));
1467d9b500d3SAart Bik   }
1468d9b500d3SAart Bik 
1469d9b500d3SAart Bik   // Helpers for method names.
1470e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
147154759cefSAart Bik     return getPrint(op, "printI64",
14725446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1473e52414b1Saartbik   }
1474b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1475b8880f5fSAart Bik     return getPrint(op, "printU64",
1476b8880f5fSAart Bik                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1477b8880f5fSAart Bik   }
1478d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
147954759cefSAart Bik     return getPrint(op, "printF32",
14805446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1481d9b500d3SAart Bik   }
1482d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
148354759cefSAart Bik     return getPrint(op, "printF64",
14845446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1485d9b500d3SAart Bik   }
1486d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
148754759cefSAart Bik     return getPrint(op, "printOpen", {});
1488d9b500d3SAart Bik   }
1489d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
149054759cefSAart Bik     return getPrint(op, "printClose", {});
1491d9b500d3SAart Bik   }
1492d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
149354759cefSAart Bik     return getPrint(op, "printComma", {});
1494d9b500d3SAart Bik   }
1495d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
149654759cefSAart Bik     return getPrint(op, "printNewline", {});
1497d9b500d3SAart Bik   }
1498d9b500d3SAart Bik };
1499d9b500d3SAart Bik 
1500334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1501c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1502c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1503c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1504334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
150565678d93SNicolas Vasilache public:
1506b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1507b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1508b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1509b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1510b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1511b99bd771SRiver Riddle   }
151265678d93SNicolas Vasilache 
1513334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
151465678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
151565678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
151665678d93SNicolas Vasilache 
151765678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
151865678d93SNicolas Vasilache 
151965678d93SNicolas Vasilache     int64_t offset =
152065678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
152165678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
152265678d93SNicolas Vasilache     int64_t stride =
152365678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
152465678d93SNicolas Vasilache 
152565678d93SNicolas Vasilache     auto loc = op.getLoc();
152665678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
152735b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1528c3c95b9cSaartbik 
1529c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1530c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1531c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1532c3c95b9cSaartbik       offsets.reserve(size);
1533c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1534c3c95b9cSaartbik            off += stride)
1535c3c95b9cSaartbik         offsets.push_back(off);
1536c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1537c3c95b9cSaartbik                                              op.vector(),
1538c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1539c3c95b9cSaartbik       return success();
1540c3c95b9cSaartbik     }
1541c3c95b9cSaartbik 
1542c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
154365678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
154465678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
154565678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
154665678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
154765678d93SNicolas Vasilache          off += stride, ++idx) {
1548c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1549c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1550c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
155165678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
155265678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
155365678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
155465678d93SNicolas Vasilache     }
1555c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15563145427dSRiver Riddle     return success();
155765678d93SNicolas Vasilache   }
155865678d93SNicolas Vasilache };
155965678d93SNicolas Vasilache 
1560df186507SBenjamin Kramer } // namespace
1561df186507SBenjamin Kramer 
15625c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15635c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1564ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1565060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
156665678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15678345b86dSNicolas Vasilache   // clang-format off
1568681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1569681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15702d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1571c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1572ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1573ceb1b327Saartbik       ctx, converter, reassociateFPReductions);
1574060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1575060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1576060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1577060c9dd1Saartbik       ctx, converter, enableIndexOptimizations);
15788345b86dSNicolas Vasilache   patterns
1579ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15808345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15818345b86dSNicolas Vasilache               VectorExtractOpConversion,
15828345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15838345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15848345b86dSNicolas Vasilache               VectorInsertOpConversion,
15858345b86dSNicolas Vasilache               VectorPrintOpConversion,
158619dbb230Saartbik               VectorTypeCastOpConversion,
158739379916Saartbik               VectorMaskedLoadOpConversion,
158839379916Saartbik               VectorMaskedStoreOpConversion,
158919dbb230Saartbik               VectorGatherOpConversion,
1590e8dcf5f8Saartbik               VectorScatterOpConversion,
1591e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1592e8dcf5f8Saartbik               VectorCompressStoreOpConversion>(ctx, converter);
15938345b86dSNicolas Vasilache   // clang-format on
15945c0c51a9SNicolas Vasilache }
15955c0c51a9SNicolas Vasilache 
159663b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
159763b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
159863b683a8SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
159963b683a8SNicolas Vasilache   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1600c295a65dSaartbik   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
160163b683a8SNicolas Vasilache }
160263b683a8SNicolas Vasilache 
16035c0c51a9SNicolas Vasilache namespace {
1604722f909fSRiver Riddle struct LowerVectorToLLVMPass
16051834ad4aSRiver Riddle     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
16061bfdf7c7Saartbik   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
16071bfdf7c7Saartbik     this->reassociateFPReductions = options.reassociateFPReductions;
1608060c9dd1Saartbik     this->enableIndexOptimizations = options.enableIndexOptimizations;
16091bfdf7c7Saartbik   }
1610722f909fSRiver Riddle   void runOnOperation() override;
16115c0c51a9SNicolas Vasilache };
16125c0c51a9SNicolas Vasilache } // namespace
16135c0c51a9SNicolas Vasilache 
1614722f909fSRiver Riddle void LowerVectorToLLVMPass::runOnOperation() {
1615078776a6Saartbik   // Perform progressive lowering of operations on slices and
1616b21c7999Saartbik   // all contraction operations. Also applies folding and DCE.
1617459cf6e5Saartbik   {
16185c0c51a9SNicolas Vasilache     OwningRewritePatternList patterns;
1619b1c688dbSaartbik     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1620459cf6e5Saartbik     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1621b21c7999Saartbik     populateVectorContractLoweringPatterns(patterns, &getContext());
16223fffffa8SRiver Riddle     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1623459cf6e5Saartbik   }
1624459cf6e5Saartbik 
1625459cf6e5Saartbik   // Convert to the LLVM IR dialect.
16265c0c51a9SNicolas Vasilache   LLVMTypeConverter converter(&getContext());
1627459cf6e5Saartbik   OwningRewritePatternList patterns;
162863b683a8SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1629060c9dd1Saartbik   populateVectorToLLVMConversionPatterns(
1630060c9dd1Saartbik       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
1631bbf3ef85SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
16325c0c51a9SNicolas Vasilache   populateStdToLLVMConversionPatterns(converter, patterns);
16335c0c51a9SNicolas Vasilache 
16342a00ae39STim Shen   LLVMConversionTarget target(getContext());
16353fffffa8SRiver Riddle   if (failed(
16363fffffa8SRiver Riddle           applyPartialConversion(getOperation(), target, std::move(patterns))))
16375c0c51a9SNicolas Vasilache     signalPassFailure();
16385c0c51a9SNicolas Vasilache }
16395c0c51a9SNicolas Vasilache 
16401bfdf7c7Saartbik std::unique_ptr<OperationPass<ModuleOp>>
16411bfdf7c7Saartbik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
16421bfdf7c7Saartbik   return std::make_unique<LowerVectorToLLVMPass>(options);
16435c0c51a9SNicolas Vasilache }
1644