15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
14519dbb230Saartbik template <typename T>
14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
14719dbb230Saartbik                                  unsigned &align) {
1485f9e0466SNicolas Vasilache   Type elementTy =
14919dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1505f9e0466SNicolas Vasilache   if (!elementTy)
1515f9e0466SNicolas Vasilache     return failure();
1525f9e0466SNicolas Vasilache 
153b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
154b2ab375dSAlex Zinenko   // stop depending on translation.
15587a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15687a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
157b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
158168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1595f9e0466SNicolas Vasilache   return success();
1605f9e0466SNicolas Vasilache }
1615f9e0466SNicolas Vasilache 
162e8dcf5f8Saartbik // Helper that returns the base address of a memref.
163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
164e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16519dbb230Saartbik   // Inspect stride and offset structure.
16619dbb230Saartbik   //
16719dbb230Saartbik   // TODO: flat memory only for now, generalize
16819dbb230Saartbik   //
16919dbb230Saartbik   int64_t offset;
17019dbb230Saartbik   SmallVector<int64_t, 4> strides;
17119dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
17219dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17319dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17419dbb230Saartbik     return failure();
175e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
176e8dcf5f8Saartbik   return success();
177e8dcf5f8Saartbik }
17819dbb230Saartbik 
179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
181b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
182b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
183e8dcf5f8Saartbik   Value base;
184e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
185e8dcf5f8Saartbik     return failure();
1863a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
187e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
188e8dcf5f8Saartbik   return success();
189e8dcf5f8Saartbik }
190e8dcf5f8Saartbik 
19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
193b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
194b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
19539379916Saartbik   Value base;
19639379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
19739379916Saartbik     return failure();
19839379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
19939379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
20039379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
20139379916Saartbik   return success();
20239379916Saartbik }
20339379916Saartbik 
204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
205e8dcf5f8Saartbik // vector.
206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
207b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
208b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
209b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
210e8dcf5f8Saartbik   Value base;
211e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
212e8dcf5f8Saartbik     return failure();
2133a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
214e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2151485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
21619dbb230Saartbik   return success();
21719dbb230Saartbik }
21819dbb230Saartbik 
2195f9e0466SNicolas Vasilache static LogicalResult
2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2215f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2225f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2235f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
224affbc0cdSNicolas Vasilache   unsigned align;
22519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
226affbc0cdSNicolas Vasilache     return failure();
227affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2285f9e0466SNicolas Vasilache   return success();
2295f9e0466SNicolas Vasilache }
2305f9e0466SNicolas Vasilache 
2315f9e0466SNicolas Vasilache static LogicalResult
2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2335f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2345f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2355f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2365f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2375f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2385f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2395f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2405f9e0466SNicolas Vasilache 
2415f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2425f9e0466SNicolas Vasilache   if (!vecTy)
2435f9e0466SNicolas Vasilache     return failure();
2445f9e0466SNicolas Vasilache 
2455f9e0466SNicolas Vasilache   unsigned align;
24619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2475f9e0466SNicolas Vasilache     return failure();
2485f9e0466SNicolas Vasilache 
2495f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2505f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2515f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2525f9e0466SNicolas Vasilache   return success();
2535f9e0466SNicolas Vasilache }
2545f9e0466SNicolas Vasilache 
2555f9e0466SNicolas Vasilache static LogicalResult
2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2575f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2585f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2595f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
260affbc0cdSNicolas Vasilache   unsigned align;
26119dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
262affbc0cdSNicolas Vasilache     return failure();
2632d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
264affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
265affbc0cdSNicolas Vasilache                                              align);
2665f9e0466SNicolas Vasilache   return success();
2675f9e0466SNicolas Vasilache }
2685f9e0466SNicolas Vasilache 
2695f9e0466SNicolas Vasilache static LogicalResult
2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2715f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2725f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2735f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2745f9e0466SNicolas Vasilache   unsigned align;
27519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2765f9e0466SNicolas Vasilache     return failure();
2775f9e0466SNicolas Vasilache 
2782d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2795f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2805f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2815f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2825f9e0466SNicolas Vasilache   return success();
2835f9e0466SNicolas Vasilache }
2845f9e0466SNicolas Vasilache 
2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2862d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2872d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2885f9e0466SNicolas Vasilache }
2895f9e0466SNicolas Vasilache 
2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2912d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2922d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2935f9e0466SNicolas Vasilache }
2945f9e0466SNicolas Vasilache 
29590c01357SBenjamin Kramer namespace {
296e83b7b99Saartbik 
29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
29963b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern {
30063b683a8SNicolas Vasilache public:
30163b683a8SNicolas Vasilache   explicit VectorMatmulOpConversion(MLIRContext *context,
30263b683a8SNicolas Vasilache                                     LLVMTypeConverter &typeConverter)
30363b683a8SNicolas Vasilache       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
30463b683a8SNicolas Vasilache                              typeConverter) {}
30563b683a8SNicolas Vasilache 
3063145427dSRiver Riddle   LogicalResult
30763b683a8SNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
30863b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
30963b683a8SNicolas Vasilache     auto matmulOp = cast<vector::MatmulOp>(op);
3102d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
31163b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
312*dcec2ca5SChristian Sigg         op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(),
31363b683a8SNicolas Vasilache         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
31463b683a8SNicolas Vasilache         matmulOp.rhs_columns());
3153145427dSRiver Riddle     return success();
31663b683a8SNicolas Vasilache   }
31763b683a8SNicolas Vasilache };
31863b683a8SNicolas Vasilache 
319c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
320c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
321c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
322c295a65dSaartbik public:
323c295a65dSaartbik   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
324c295a65dSaartbik                                            LLVMTypeConverter &typeConverter)
325c295a65dSaartbik       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
326c295a65dSaartbik                              context, typeConverter) {}
327c295a65dSaartbik 
328c295a65dSaartbik   LogicalResult
329c295a65dSaartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
330c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
331c295a65dSaartbik     auto transOp = cast<vector::FlatTransposeOp>(op);
3322d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
333c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
334*dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
335c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
336c295a65dSaartbik     return success();
337c295a65dSaartbik   }
338c295a65dSaartbik };
339c295a65dSaartbik 
34039379916Saartbik /// Conversion pattern for a vector.maskedload.
34139379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
34239379916Saartbik public:
34339379916Saartbik   explicit VectorMaskedLoadOpConversion(MLIRContext *context,
34439379916Saartbik                                         LLVMTypeConverter &typeConverter)
34539379916Saartbik       : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
34639379916Saartbik                              typeConverter) {}
34739379916Saartbik 
34839379916Saartbik   LogicalResult
34939379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
35039379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
35139379916Saartbik     auto loc = op->getLoc();
35239379916Saartbik     auto load = cast<vector::MaskedLoadOp>(op);
35339379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
35439379916Saartbik 
35539379916Saartbik     // Resolve alignment.
35639379916Saartbik     unsigned align;
357*dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
35839379916Saartbik       return failure();
35939379916Saartbik 
360*dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
36139379916Saartbik     Value ptr;
36239379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
36339379916Saartbik                           vtype, ptr)))
36439379916Saartbik       return failure();
36539379916Saartbik 
36639379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
36739379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
36839379916Saartbik         rewriter.getI32IntegerAttr(align));
36939379916Saartbik     return success();
37039379916Saartbik   }
37139379916Saartbik };
37239379916Saartbik 
37339379916Saartbik /// Conversion pattern for a vector.maskedstore.
37439379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
37539379916Saartbik public:
37639379916Saartbik   explicit VectorMaskedStoreOpConversion(MLIRContext *context,
37739379916Saartbik                                          LLVMTypeConverter &typeConverter)
37839379916Saartbik       : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
37939379916Saartbik                              typeConverter) {}
38039379916Saartbik 
38139379916Saartbik   LogicalResult
38239379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
38339379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
38439379916Saartbik     auto loc = op->getLoc();
38539379916Saartbik     auto store = cast<vector::MaskedStoreOp>(op);
38639379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
38739379916Saartbik 
38839379916Saartbik     // Resolve alignment.
38939379916Saartbik     unsigned align;
390*dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
39139379916Saartbik       return failure();
39239379916Saartbik 
393*dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
39439379916Saartbik     Value ptr;
39539379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
39639379916Saartbik                           vtype, ptr)))
39739379916Saartbik       return failure();
39839379916Saartbik 
39939379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
40039379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
40139379916Saartbik         rewriter.getI32IntegerAttr(align));
40239379916Saartbik     return success();
40339379916Saartbik   }
40439379916Saartbik };
40539379916Saartbik 
40619dbb230Saartbik /// Conversion pattern for a vector.gather.
40719dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern {
40819dbb230Saartbik public:
40919dbb230Saartbik   explicit VectorGatherOpConversion(MLIRContext *context,
41019dbb230Saartbik                                     LLVMTypeConverter &typeConverter)
41119dbb230Saartbik       : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
41219dbb230Saartbik                              typeConverter) {}
41319dbb230Saartbik 
41419dbb230Saartbik   LogicalResult
41519dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
41619dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
41719dbb230Saartbik     auto loc = op->getLoc();
41819dbb230Saartbik     auto gather = cast<vector::GatherOp>(op);
41919dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
42019dbb230Saartbik 
42119dbb230Saartbik     // Resolve alignment.
42219dbb230Saartbik     unsigned align;
423*dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
42419dbb230Saartbik       return failure();
42519dbb230Saartbik 
42619dbb230Saartbik     // Get index ptrs.
42719dbb230Saartbik     VectorType vType = gather.getResultVectorType();
42819dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
42919dbb230Saartbik     Value ptrs;
430e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
431e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
43219dbb230Saartbik       return failure();
43319dbb230Saartbik 
43419dbb230Saartbik     // Replace with the gather intrinsic.
43519dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
436*dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4370c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
43819dbb230Saartbik     return success();
43919dbb230Saartbik   }
44019dbb230Saartbik };
44119dbb230Saartbik 
44219dbb230Saartbik /// Conversion pattern for a vector.scatter.
44319dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern {
44419dbb230Saartbik public:
44519dbb230Saartbik   explicit VectorScatterOpConversion(MLIRContext *context,
44619dbb230Saartbik                                      LLVMTypeConverter &typeConverter)
44719dbb230Saartbik       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
44819dbb230Saartbik                              typeConverter) {}
44919dbb230Saartbik 
45019dbb230Saartbik   LogicalResult
45119dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
45219dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
45319dbb230Saartbik     auto loc = op->getLoc();
45419dbb230Saartbik     auto scatter = cast<vector::ScatterOp>(op);
45519dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
45619dbb230Saartbik 
45719dbb230Saartbik     // Resolve alignment.
45819dbb230Saartbik     unsigned align;
459*dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
46019dbb230Saartbik       return failure();
46119dbb230Saartbik 
46219dbb230Saartbik     // Get index ptrs.
46319dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
46419dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
46519dbb230Saartbik     Value ptrs;
466e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
467e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
46819dbb230Saartbik       return failure();
46919dbb230Saartbik 
47019dbb230Saartbik     // Replace with the scatter intrinsic.
47119dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
47219dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
47319dbb230Saartbik         rewriter.getI32IntegerAttr(align));
47419dbb230Saartbik     return success();
47519dbb230Saartbik   }
47619dbb230Saartbik };
47719dbb230Saartbik 
478e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
479e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
480e8dcf5f8Saartbik public:
481e8dcf5f8Saartbik   explicit VectorExpandLoadOpConversion(MLIRContext *context,
482e8dcf5f8Saartbik                                         LLVMTypeConverter &typeConverter)
483e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
484e8dcf5f8Saartbik                              typeConverter) {}
485e8dcf5f8Saartbik 
486e8dcf5f8Saartbik   LogicalResult
487e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
488e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
489e8dcf5f8Saartbik     auto loc = op->getLoc();
490e8dcf5f8Saartbik     auto expand = cast<vector::ExpandLoadOp>(op);
491e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
492e8dcf5f8Saartbik 
493e8dcf5f8Saartbik     Value ptr;
494e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
495e8dcf5f8Saartbik                           ptr)))
496e8dcf5f8Saartbik       return failure();
497e8dcf5f8Saartbik 
498e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
499e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
500*dcec2ca5SChristian Sigg         op, typeConverter->convertType(vType), ptr, adaptor.mask(),
501e8dcf5f8Saartbik         adaptor.pass_thru());
502e8dcf5f8Saartbik     return success();
503e8dcf5f8Saartbik   }
504e8dcf5f8Saartbik };
505e8dcf5f8Saartbik 
506e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
507e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
508e8dcf5f8Saartbik public:
509e8dcf5f8Saartbik   explicit VectorCompressStoreOpConversion(MLIRContext *context,
510e8dcf5f8Saartbik                                            LLVMTypeConverter &typeConverter)
511e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
512e8dcf5f8Saartbik                              context, typeConverter) {}
513e8dcf5f8Saartbik 
514e8dcf5f8Saartbik   LogicalResult
515e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
516e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
517e8dcf5f8Saartbik     auto loc = op->getLoc();
518e8dcf5f8Saartbik     auto compress = cast<vector::CompressStoreOp>(op);
519e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
520e8dcf5f8Saartbik 
521e8dcf5f8Saartbik     Value ptr;
522e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
523e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
524e8dcf5f8Saartbik       return failure();
525e8dcf5f8Saartbik 
526e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
527e8dcf5f8Saartbik         op, adaptor.value(), ptr, adaptor.mask());
528e8dcf5f8Saartbik     return success();
529e8dcf5f8Saartbik   }
530e8dcf5f8Saartbik };
531e8dcf5f8Saartbik 
53219dbb230Saartbik /// Conversion pattern for all vector reductions.
533870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern {
534e83b7b99Saartbik public:
535e83b7b99Saartbik   explicit VectorReductionOpConversion(MLIRContext *context,
536ceb1b327Saartbik                                        LLVMTypeConverter &typeConverter,
537060c9dd1Saartbik                                        bool reassociateFPRed)
538870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
539ceb1b327Saartbik                              typeConverter),
540060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
541e83b7b99Saartbik 
5423145427dSRiver Riddle   LogicalResult
543e83b7b99Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
544e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
545e83b7b99Saartbik     auto reductionOp = cast<vector::ReductionOp>(op);
546e83b7b99Saartbik     auto kind = reductionOp.kind();
547e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
548*dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
549e9628955SAart Bik     if (eltType.isIntOrIndex()) {
550e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
551e83b7b99Saartbik       if (kind == "add")
552322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
553e83b7b99Saartbik             op, llvmType, operands[0]);
554e83b7b99Saartbik       else if (kind == "mul")
555322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
556e83b7b99Saartbik             op, llvmType, operands[0]);
557e9628955SAart Bik       else if (kind == "min" &&
558e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
559322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
560e9628955SAart Bik             op, llvmType, operands[0]);
561e83b7b99Saartbik       else if (kind == "min")
562322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
563e83b7b99Saartbik             op, llvmType, operands[0]);
564e9628955SAart Bik       else if (kind == "max" &&
565e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
566322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
567e9628955SAart Bik             op, llvmType, operands[0]);
568e83b7b99Saartbik       else if (kind == "max")
569322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
570e83b7b99Saartbik             op, llvmType, operands[0]);
571e83b7b99Saartbik       else if (kind == "and")
572322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
573e83b7b99Saartbik             op, llvmType, operands[0]);
574e83b7b99Saartbik       else if (kind == "or")
575322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
576e83b7b99Saartbik             op, llvmType, operands[0]);
577e83b7b99Saartbik       else if (kind == "xor")
578322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
579e83b7b99Saartbik             op, llvmType, operands[0]);
580e83b7b99Saartbik       else
5813145427dSRiver Riddle         return failure();
5823145427dSRiver Riddle       return success();
583*dcec2ca5SChristian Sigg     }
584e83b7b99Saartbik 
585*dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
586*dcec2ca5SChristian Sigg       return failure();
587*dcec2ca5SChristian Sigg 
588e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
589e83b7b99Saartbik     if (kind == "add") {
5900d924700Saartbik       // Optional accumulator (or zero).
5910d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5920d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
5930d924700Saartbik                                             op->getLoc(), llvmType,
5940d924700Saartbik                                             rewriter.getZeroAttr(eltType));
595322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
596ceb1b327Saartbik           op, llvmType, acc, operands[0],
597ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
598e83b7b99Saartbik     } else if (kind == "mul") {
5990d924700Saartbik       // Optional accumulator (or one).
6000d924700Saartbik       Value acc = operands.size() > 1
6010d924700Saartbik                       ? operands[1]
6020d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
6030d924700Saartbik                             op->getLoc(), llvmType,
6040d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
605322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
606ceb1b327Saartbik           op, llvmType, acc, operands[0],
607ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
608e83b7b99Saartbik     } else if (kind == "min")
609*dcec2ca5SChristian Sigg       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType,
610*dcec2ca5SChristian Sigg                                                             operands[0]);
611e83b7b99Saartbik     else if (kind == "max")
612*dcec2ca5SChristian Sigg       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType,
613*dcec2ca5SChristian Sigg                                                             operands[0]);
614e83b7b99Saartbik     else
6153145427dSRiver Riddle       return failure();
6163145427dSRiver Riddle     return success();
617e83b7b99Saartbik   }
618ceb1b327Saartbik 
619ceb1b327Saartbik private:
620ceb1b327Saartbik   const bool reassociateFPReductions;
621e83b7b99Saartbik };
622e83b7b99Saartbik 
623060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
624060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
625060c9dd1Saartbik public:
626060c9dd1Saartbik   explicit VectorCreateMaskOpConversion(MLIRContext *context,
627060c9dd1Saartbik                                         LLVMTypeConverter &typeConverter,
628060c9dd1Saartbik                                         bool enableIndexOpt)
629060c9dd1Saartbik       : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
630060c9dd1Saartbik                              typeConverter),
631060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
632060c9dd1Saartbik 
633060c9dd1Saartbik   LogicalResult
634060c9dd1Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
635060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
636060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
637060c9dd1Saartbik     int64_t rank = dstType.getRank();
638060c9dd1Saartbik     if (rank == 1) {
639060c9dd1Saartbik       rewriter.replaceOp(
640060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
641060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
642060c9dd1Saartbik       return success();
643060c9dd1Saartbik     }
644060c9dd1Saartbik     return failure();
645060c9dd1Saartbik   }
646060c9dd1Saartbik 
647060c9dd1Saartbik private:
648060c9dd1Saartbik   const bool enableIndexOptimizations;
649060c9dd1Saartbik };
650060c9dd1Saartbik 
651870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern {
6521c81adf3SAart Bik public:
6531c81adf3SAart Bik   explicit VectorShuffleOpConversion(MLIRContext *context,
6541c81adf3SAart Bik                                      LLVMTypeConverter &typeConverter)
655870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
6561c81adf3SAart Bik                              typeConverter) {}
6571c81adf3SAart Bik 
6583145427dSRiver Riddle   LogicalResult
659e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6601c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6611c81adf3SAart Bik     auto loc = op->getLoc();
6622d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6631c81adf3SAart Bik     auto shuffleOp = cast<vector::ShuffleOp>(op);
6641c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6651c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6661c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
667*dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6681c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6691c81adf3SAart Bik 
6701c81adf3SAart Bik     // Bail if result type cannot be lowered.
6711c81adf3SAart Bik     if (!llvmType)
6723145427dSRiver Riddle       return failure();
6731c81adf3SAart Bik 
6741c81adf3SAart Bik     // Get rank and dimension sizes.
6751c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6761c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6771c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6781c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6791c81adf3SAart Bik 
6801c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6811c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6821c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
683e62a6956SRiver Riddle       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
6841c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
6851c81adf3SAart Bik       rewriter.replaceOp(op, shuffle);
6863145427dSRiver Riddle       return success();
687b36aaeafSAart Bik     }
688b36aaeafSAart Bik 
6891c81adf3SAart Bik     // For all other cases, insert the individual values individually.
690e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6911c81adf3SAart Bik     int64_t insPos = 0;
6921c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6931c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
694e62a6956SRiver Riddle       Value value = adaptor.v1();
6951c81adf3SAart Bik       if (extPos >= v1Dim) {
6961c81adf3SAart Bik         extPos -= v1Dim;
6971c81adf3SAart Bik         value = adaptor.v2();
698b36aaeafSAart Bik       }
699*dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
700*dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
701*dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
7020f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7031c81adf3SAart Bik     }
7041c81adf3SAart Bik     rewriter.replaceOp(op, insert);
7053145427dSRiver Riddle     return success();
706b36aaeafSAart Bik   }
707b36aaeafSAart Bik };
708b36aaeafSAart Bik 
709870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
710cd5dab8aSAart Bik public:
711cd5dab8aSAart Bik   explicit VectorExtractElementOpConversion(MLIRContext *context,
712cd5dab8aSAart Bik                                             LLVMTypeConverter &typeConverter)
713870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
714870c1fd4SAlex Zinenko                              context, typeConverter) {}
715cd5dab8aSAart Bik 
7163145427dSRiver Riddle   LogicalResult
717e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
718cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7192d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
720cd5dab8aSAart Bik     auto extractEltOp = cast<vector::ExtractElementOp>(op);
721cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
722*dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
723cd5dab8aSAart Bik 
724cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
725cd5dab8aSAart Bik     if (!llvmType)
7263145427dSRiver Riddle       return failure();
727cd5dab8aSAart Bik 
728cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
729cd5dab8aSAart Bik         op, llvmType, adaptor.vector(), adaptor.position());
7303145427dSRiver Riddle     return success();
731cd5dab8aSAart Bik   }
732cd5dab8aSAart Bik };
733cd5dab8aSAart Bik 
734870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern {
7355c0c51a9SNicolas Vasilache public:
7369826fe5cSAart Bik   explicit VectorExtractOpConversion(MLIRContext *context,
7375c0c51a9SNicolas Vasilache                                      LLVMTypeConverter &typeConverter)
738870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
7395c0c51a9SNicolas Vasilache                              typeConverter) {}
7405c0c51a9SNicolas Vasilache 
7413145427dSRiver Riddle   LogicalResult
742e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7435c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7445c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
7452d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
746d37f2725SAart Bik     auto extractOp = cast<vector::ExtractOp>(op);
7479826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7482bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
749*dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7505c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7519826fe5cSAart Bik 
7529826fe5cSAart Bik     // Bail if result type cannot be lowered.
7539826fe5cSAart Bik     if (!llvmResultType)
7543145427dSRiver Riddle       return failure();
7559826fe5cSAart Bik 
7565c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7575c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
758e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7595c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
7605c0c51a9SNicolas Vasilache       rewriter.replaceOp(op, extracted);
7613145427dSRiver Riddle       return success();
7625c0c51a9SNicolas Vasilache     }
7635c0c51a9SNicolas Vasilache 
7649826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
7655c0c51a9SNicolas Vasilache     auto *context = op->getContext();
766e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7675c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7685c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7699826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7705c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7715c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7725c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
773*dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7745c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7755c0c51a9SNicolas Vasilache     }
7765c0c51a9SNicolas Vasilache 
7775c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7785c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7795446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7801d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7815c0c51a9SNicolas Vasilache     extracted =
7825c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
7835c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, extracted);
7845c0c51a9SNicolas Vasilache 
7853145427dSRiver Riddle     return success();
7865c0c51a9SNicolas Vasilache   }
7875c0c51a9SNicolas Vasilache };
7885c0c51a9SNicolas Vasilache 
789681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
790681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
791681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
792681f929fSNicolas Vasilache ///
793681f929fSNicolas Vasilache /// Example:
794681f929fSNicolas Vasilache /// ```
795681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
796681f929fSNicolas Vasilache /// ```
797681f929fSNicolas Vasilache /// is converted to:
798681f929fSNicolas Vasilache /// ```
7993bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
800681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
801681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
802681f929fSNicolas Vasilache /// ```
803870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
804681f929fSNicolas Vasilache public:
805681f929fSNicolas Vasilache   explicit VectorFMAOp1DConversion(MLIRContext *context,
806681f929fSNicolas Vasilache                                    LLVMTypeConverter &typeConverter)
807870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
808681f929fSNicolas Vasilache                              typeConverter) {}
809681f929fSNicolas Vasilache 
8103145427dSRiver Riddle   LogicalResult
811681f929fSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
812681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8132d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
814681f929fSNicolas Vasilache     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
815681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
816681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8173145427dSRiver Riddle       return failure();
8183bffe602SBenjamin Kramer     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
8193bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8203145427dSRiver Riddle     return success();
821681f929fSNicolas Vasilache   }
822681f929fSNicolas Vasilache };
823681f929fSNicolas Vasilache 
824870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
825cd5dab8aSAart Bik public:
826cd5dab8aSAart Bik   explicit VectorInsertElementOpConversion(MLIRContext *context,
827cd5dab8aSAart Bik                                            LLVMTypeConverter &typeConverter)
828870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
829870c1fd4SAlex Zinenko                              context, typeConverter) {}
830cd5dab8aSAart Bik 
8313145427dSRiver Riddle   LogicalResult
832e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
833cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8342d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
835cd5dab8aSAart Bik     auto insertEltOp = cast<vector::InsertElementOp>(op);
836cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
837*dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
838cd5dab8aSAart Bik 
839cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
840cd5dab8aSAart Bik     if (!llvmType)
8413145427dSRiver Riddle       return failure();
842cd5dab8aSAart Bik 
843cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
844cd5dab8aSAart Bik         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
8453145427dSRiver Riddle     return success();
846cd5dab8aSAart Bik   }
847cd5dab8aSAart Bik };
848cd5dab8aSAart Bik 
849870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern {
8509826fe5cSAart Bik public:
8519826fe5cSAart Bik   explicit VectorInsertOpConversion(MLIRContext *context,
8529826fe5cSAart Bik                                     LLVMTypeConverter &typeConverter)
853870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
8549826fe5cSAart Bik                              typeConverter) {}
8559826fe5cSAart Bik 
8563145427dSRiver Riddle   LogicalResult
857e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
8589826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8599826fe5cSAart Bik     auto loc = op->getLoc();
8602d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8619826fe5cSAart Bik     auto insertOp = cast<vector::InsertOp>(op);
8629826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8639826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
864*dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8659826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8669826fe5cSAart Bik 
8679826fe5cSAart Bik     // Bail if result type cannot be lowered.
8689826fe5cSAart Bik     if (!llvmResultType)
8693145427dSRiver Riddle       return failure();
8709826fe5cSAart Bik 
8719826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8729826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
873e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8749826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8759826fe5cSAart Bik           positionArrayAttr);
8769826fe5cSAart Bik       rewriter.replaceOp(op, inserted);
8773145427dSRiver Riddle       return success();
8789826fe5cSAart Bik     }
8799826fe5cSAart Bik 
8809826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
8819826fe5cSAart Bik     auto *context = op->getContext();
882e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8839826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8849826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8859826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8869826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8879826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8889826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8899826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8909826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
891*dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8929826fe5cSAart Bik           nMinusOnePositionAttrs);
8939826fe5cSAart Bik     }
8949826fe5cSAart Bik 
8959826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8965446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
8971d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
898e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
899*dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
9000f04384dSAlex Zinenko         adaptor.source(), constant);
9019826fe5cSAart Bik 
9029826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
9039826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9049826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9059826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9069826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
9079826fe5cSAart Bik                                                       adaptor.dest(), inserted,
9089826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
9099826fe5cSAart Bik     }
9109826fe5cSAart Bik 
9119826fe5cSAart Bik     rewriter.replaceOp(op, inserted);
9123145427dSRiver Riddle     return success();
9139826fe5cSAart Bik   }
9149826fe5cSAart Bik };
9159826fe5cSAart Bik 
916681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
917681f929fSNicolas Vasilache ///
918681f929fSNicolas Vasilache /// Example:
919681f929fSNicolas Vasilache /// ```
920681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
921681f929fSNicolas Vasilache /// ```
922681f929fSNicolas Vasilache /// is rewritten into:
923681f929fSNicolas Vasilache /// ```
924681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
925681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
926681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
927681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
928681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
929681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
930681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
931681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
932681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
933681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
934681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
935681f929fSNicolas Vasilache ///  // %r3 holds the final value.
936681f929fSNicolas Vasilache /// ```
937681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
938681f929fSNicolas Vasilache public:
939681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
940681f929fSNicolas Vasilache 
9413145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
942681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
943681f929fSNicolas Vasilache     auto vType = op.getVectorType();
944681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9453145427dSRiver Riddle       return failure();
946681f929fSNicolas Vasilache 
947681f929fSNicolas Vasilache     auto loc = op.getLoc();
948681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
949681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
950681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
951681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
952681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
953681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
954681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
955681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
956681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
957681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
958681f929fSNicolas Vasilache     }
959681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9603145427dSRiver Riddle     return success();
961681f929fSNicolas Vasilache   }
962681f929fSNicolas Vasilache };
963681f929fSNicolas Vasilache 
9642d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9652d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9662d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9672d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9682d515e49SNicolas Vasilache // rank.
9692d515e49SNicolas Vasilache //
9702d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9712d515e49SNicolas Vasilache // have different ranks. In this case:
9722d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9732d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9742d515e49SNicolas Vasilache //   destination subvector
9752d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9762d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9772d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9782d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9792d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9802d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9812d515e49SNicolas Vasilache public:
9822d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9832d515e49SNicolas Vasilache 
9843145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9852d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9862d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9872d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9882d515e49SNicolas Vasilache 
9892d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9903145427dSRiver Riddle       return failure();
9912d515e49SNicolas Vasilache 
9922d515e49SNicolas Vasilache     auto loc = op.getLoc();
9932d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9942d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9952d515e49SNicolas Vasilache     if (rankDiff == 0)
9963145427dSRiver Riddle       return failure();
9972d515e49SNicolas Vasilache 
9982d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9992d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
10002d515e49SNicolas Vasilache     // on it.
10012d515e49SNicolas Vasilache     Value extracted =
10022d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
10032d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
1004*dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
10052d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
10062d515e49SNicolas Vasilache     // ranks.
10072d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
10082d515e49SNicolas Vasilache         loc, op.source(), extracted,
10092d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
1010c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
10112d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
10122d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
10132d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
1014*dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
10153145427dSRiver Riddle     return success();
10162d515e49SNicolas Vasilache   }
10172d515e49SNicolas Vasilache };
10182d515e49SNicolas Vasilache 
10192d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10202d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10212d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10222d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10232d515e49SNicolas Vasilache //   destination subvector
10242d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10252d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10262d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10272d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10282d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10292d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10302d515e49SNicolas Vasilache public:
1031b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1032b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1033b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1034b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1035b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1036b99bd771SRiver Riddle   }
10372d515e49SNicolas Vasilache 
10383145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10392d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10402d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10412d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10422d515e49SNicolas Vasilache 
10432d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10443145427dSRiver Riddle       return failure();
10452d515e49SNicolas Vasilache 
10462d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10472d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10482d515e49SNicolas Vasilache     if (rankDiff != 0)
10493145427dSRiver Riddle       return failure();
10502d515e49SNicolas Vasilache 
10512d515e49SNicolas Vasilache     if (srcType == dstType) {
10522d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10533145427dSRiver Riddle       return success();
10542d515e49SNicolas Vasilache     }
10552d515e49SNicolas Vasilache 
10562d515e49SNicolas Vasilache     int64_t offset =
10572d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10582d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10592d515e49SNicolas Vasilache     int64_t stride =
10602d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10612d515e49SNicolas Vasilache 
10622d515e49SNicolas Vasilache     auto loc = op.getLoc();
10632d515e49SNicolas Vasilache     Value res = op.dest();
10642d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10652d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10662d515e49SNicolas Vasilache          off += stride, ++idx) {
10672d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10682d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10692d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10702d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10712d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10722d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10732d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10742d515e49SNicolas Vasilache         // smaller rank.
1075bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10762d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10772d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10782d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10792d515e49SNicolas Vasilache       }
10802d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10812d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10822d515e49SNicolas Vasilache     }
10832d515e49SNicolas Vasilache 
10842d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10853145427dSRiver Riddle     return success();
10862d515e49SNicolas Vasilache   }
10872d515e49SNicolas Vasilache };
10882d515e49SNicolas Vasilache 
108930e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
109030e6033bSNicolas Vasilache /// static layout.
109130e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
109230e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10932bf491c7SBenjamin Kramer   int64_t offset;
109430e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
109530e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
109630e6033bSNicolas Vasilache     return None;
109730e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
109830e6033bSNicolas Vasilache     return None;
109930e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
110030e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
110130e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
110230e6033bSNicolas Vasilache     return strides;
110330e6033bSNicolas Vasilache 
110430e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
110530e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
110630e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
110730e6033bSNicolas Vasilache   // layout.
11082bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
11092bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
111030e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
111130e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
111230e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
111330e6033bSNicolas Vasilache       return None;
111430e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
111530e6033bSNicolas Vasilache       return None;
11162bf491c7SBenjamin Kramer   }
111730e6033bSNicolas Vasilache   return strides;
11182bf491c7SBenjamin Kramer }
11192bf491c7SBenjamin Kramer 
1120870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
11215c0c51a9SNicolas Vasilache public:
11225c0c51a9SNicolas Vasilache   explicit VectorTypeCastOpConversion(MLIRContext *context,
11235c0c51a9SNicolas Vasilache                                       LLVMTypeConverter &typeConverter)
1124870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
11255c0c51a9SNicolas Vasilache                              typeConverter) {}
11265c0c51a9SNicolas Vasilache 
11273145427dSRiver Riddle   LogicalResult
1128e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11295c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11305c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
11315c0c51a9SNicolas Vasilache     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
11325c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11332bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11345c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
11352bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
11365c0c51a9SNicolas Vasilache 
11375c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11385c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11395c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11403145427dSRiver Riddle       return failure();
11415c0c51a9SNicolas Vasilache 
11425c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11432bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
11445c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
11453145427dSRiver Riddle       return failure();
11465c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11475c0c51a9SNicolas Vasilache 
1148*dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11495c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11505c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11513145427dSRiver Riddle       return failure();
11525c0c51a9SNicolas Vasilache 
115330e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
115430e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
115530e6033bSNicolas Vasilache     if (!sourceStrides)
115630e6033bSNicolas Vasilache       return failure();
115730e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
115830e6033bSNicolas Vasilache     if (!targetStrides)
115930e6033bSNicolas Vasilache       return failure();
116030e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
116130e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
116230e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
116330e6033bSNicolas Vasilache         }))
11643145427dSRiver Riddle       return failure();
11655c0c51a9SNicolas Vasilache 
11665446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11675c0c51a9SNicolas Vasilache 
11685c0c51a9SNicolas Vasilache     // Create descriptor.
11695c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11703a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11715c0c51a9SNicolas Vasilache     // Set allocated ptr.
1172e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11735c0c51a9SNicolas Vasilache     allocated =
11745c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11755c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11765c0c51a9SNicolas Vasilache     // Set aligned ptr.
1177e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11785c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11795c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11805c0c51a9SNicolas Vasilache     // Fill offset 0.
11815c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11825c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11835c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11845c0c51a9SNicolas Vasilache 
11855c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11865c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11875c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11885c0c51a9SNicolas Vasilache       auto sizeAttr =
11895c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11905c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11915c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
119230e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
119330e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11945c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11955c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11965c0c51a9SNicolas Vasilache     }
11975c0c51a9SNicolas Vasilache 
11985c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, {desc});
11993145427dSRiver Riddle     return success();
12005c0c51a9SNicolas Vasilache   }
12015c0c51a9SNicolas Vasilache };
12025c0c51a9SNicolas Vasilache 
12038345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
12048345b86dSNicolas Vasilache /// sequence of:
1205060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1206060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1207060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1208060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1209060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
12108345b86dSNicolas Vasilache template <typename ConcreteOp>
12118345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern {
12128345b86dSNicolas Vasilache public:
12138345b86dSNicolas Vasilache   explicit VectorTransferConversion(MLIRContext *context,
1214060c9dd1Saartbik                                     LLVMTypeConverter &typeConv,
1215060c9dd1Saartbik                                     bool enableIndexOpt)
1216060c9dd1Saartbik       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1217060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
12188345b86dSNicolas Vasilache 
12198345b86dSNicolas Vasilache   LogicalResult
12208345b86dSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
12218345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12228345b86dSNicolas Vasilache     auto xferOp = cast<ConcreteOp>(op);
12238345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1224b2c79c50SNicolas Vasilache 
1225b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1226b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12278345b86dSNicolas Vasilache       return failure();
12285f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12295f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12305f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
12315f9e0466SNicolas Vasilache                                        op->getContext()))
12328345b86dSNicolas Vasilache       return failure();
12332bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
123430e6033bSNicolas Vasilache     auto strides = computeContiguousStrides(xferOp.getMemRefType());
123530e6033bSNicolas Vasilache     if (!strides)
12362bf491c7SBenjamin Kramer       return failure();
12378345b86dSNicolas Vasilache 
1238*dcec2ca5SChristian Sigg     auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); };
12398345b86dSNicolas Vasilache 
12408345b86dSNicolas Vasilache     Location loc = op->getLoc();
12418345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
12428345b86dSNicolas Vasilache 
124368330ee0SThomas Raoux     if (auto memrefVectorElementType =
124468330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
124568330ee0SThomas Raoux       // Memref has vector element type.
124668330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
124768330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
124868330ee0SThomas Raoux         return failure();
12490de60b55SThomas Raoux #ifndef NDEBUG
125068330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
125168330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
125268330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
125368330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
125468330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
125568330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
125668330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
125768330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
125868330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
125968330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
126068330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
126168330ee0SThomas Raoux                "result shape.");
12620de60b55SThomas Raoux #endif // ifndef NDEBUG
126368330ee0SThomas Raoux     }
126468330ee0SThomas Raoux 
12658345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1266be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1267be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1268be16075bSWen-Heng (Jack) Chung     //    address space 0.
12698345b86dSNicolas Vasilache     // TODO: support alignment when possible.
12708b97e17dSChristian Sigg     Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
1271d3a98076SAlex Zinenko                                          adaptor.indices(), rewriter);
12728345b86dSNicolas Vasilache     auto vecTy =
12738345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1274be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1275be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1276be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12778345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1278be16075bSWen-Heng (Jack) Chung     else
1279be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1280be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12818345b86dSNicolas Vasilache 
12821870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1283*dcec2ca5SChristian Sigg       return replaceTransferOpWithLoadOrStore(
1284*dcec2ca5SChristian Sigg           rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr);
12851870e787SNicolas Vasilache 
12868345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12878345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12888345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12898345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1290060c9dd1Saartbik     //
1291060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1292060c9dd1Saartbik     //       dimensions here.
1293060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1294060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12950c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
1296b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1297060c9dd1Saartbik     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1298060c9dd1Saartbik                                        vecWidth, dim, &off);
12998345b86dSNicolas Vasilache 
13008345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1301*dcec2ca5SChristian Sigg     return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc,
1302*dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
13038345b86dSNicolas Vasilache   }
1304060c9dd1Saartbik 
1305060c9dd1Saartbik private:
1306060c9dd1Saartbik   const bool enableIndexOptimizations;
13078345b86dSNicolas Vasilache };
13088345b86dSNicolas Vasilache 
1309870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern {
1310d9b500d3SAart Bik public:
1311d9b500d3SAart Bik   explicit VectorPrintOpConversion(MLIRContext *context,
1312d9b500d3SAart Bik                                    LLVMTypeConverter &typeConverter)
1313870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1314d9b500d3SAart Bik                              typeConverter) {}
1315d9b500d3SAart Bik 
1316d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1317d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1318d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1319d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1320d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1321d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1322d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1323d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1324d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1325d9b500d3SAart Bik   //
13269db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1327d9b500d3SAart Bik   //
13283145427dSRiver Riddle   LogicalResult
1329e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1330d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1331d9b500d3SAart Bik     auto printOp = cast<vector::PrintOp>(op);
13322d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1333d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1334d9b500d3SAart Bik 
1335*dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
13363145427dSRiver Riddle       return failure();
1337d9b500d3SAart Bik 
1338b8880f5fSAart Bik     // Make sure element type has runtime support.
1339b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1340d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1341d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1342d9b500d3SAart Bik     Operation *printer;
1343b8880f5fSAart Bik     if (eltType.isF32()) {
1344d9b500d3SAart Bik       printer = getPrintFloat(op);
1345b8880f5fSAart Bik     } else if (eltType.isF64()) {
1346d9b500d3SAart Bik       printer = getPrintDouble(op);
134754759cefSAart Bik     } else if (eltType.isIndex()) {
134854759cefSAart Bik       printer = getPrintU64(op);
1349b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1350b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1351b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1352b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1353b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1354b8880f5fSAart Bik       if (intTy.isUnsigned()) {
135554759cefSAart Bik         if (width <= 64) {
1356b8880f5fSAart Bik           if (width < 64)
1357b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1358b8880f5fSAart Bik           printer = getPrintU64(op);
1359b8880f5fSAart Bik         } else {
13603145427dSRiver Riddle           return failure();
1361b8880f5fSAart Bik         }
1362b8880f5fSAart Bik       } else {
1363b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
136454759cefSAart Bik         if (width <= 64) {
1365b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1366b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1367b8880f5fSAart Bik           if (width == 1)
136854759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
136954759cefSAart Bik           else if (width < 64)
1370b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1371b8880f5fSAart Bik           printer = getPrintI64(op);
1372b8880f5fSAart Bik         } else {
1373b8880f5fSAart Bik           return failure();
1374b8880f5fSAart Bik         }
1375b8880f5fSAart Bik       }
1376b8880f5fSAart Bik     } else {
1377b8880f5fSAart Bik       return failure();
1378b8880f5fSAart Bik     }
1379d9b500d3SAart Bik 
1380d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1381b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1382b8880f5fSAart Bik     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1383b8880f5fSAart Bik               conversion);
1384d9b500d3SAart Bik     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1385d9b500d3SAart Bik     rewriter.eraseOp(op);
13863145427dSRiver Riddle     return success();
1387d9b500d3SAart Bik   }
1388d9b500d3SAart Bik 
1389d9b500d3SAart Bik private:
1390b8880f5fSAart Bik   enum class PrintConversion {
139130e6033bSNicolas Vasilache     // clang-format off
1392b8880f5fSAart Bik     None,
1393b8880f5fSAart Bik     ZeroExt64,
1394b8880f5fSAart Bik     SignExt64
139530e6033bSNicolas Vasilache     // clang-format on
1396b8880f5fSAart Bik   };
1397b8880f5fSAart Bik 
1398d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1399e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1400b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1401d9b500d3SAart Bik     Location loc = op->getLoc();
1402d9b500d3SAart Bik     if (rank == 0) {
1403b8880f5fSAart Bik       switch (conversion) {
1404b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1405b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1406b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1407b8880f5fSAart Bik         break;
1408b8880f5fSAart Bik       case PrintConversion::SignExt64:
1409b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1410b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1411b8880f5fSAart Bik         break;
1412b8880f5fSAart Bik       case PrintConversion::None:
1413b8880f5fSAart Bik         break;
1414c9eeeb38Saartbik       }
1415d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1416d9b500d3SAart Bik       return;
1417d9b500d3SAart Bik     }
1418d9b500d3SAart Bik 
1419d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1420d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1421d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1422d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1423d9b500d3SAart Bik       auto reducedType =
1424d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1425*dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1426d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1427*dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1428*dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1429b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1430b8880f5fSAart Bik                 conversion);
1431d9b500d3SAart Bik       if (d != dim - 1)
1432d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1433d9b500d3SAart Bik     }
1434d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1435d9b500d3SAart Bik   }
1436d9b500d3SAart Bik 
1437d9b500d3SAart Bik   // Helper to emit a call.
1438d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1439d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
144008e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1441d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1442d9b500d3SAart Bik   }
1443d9b500d3SAart Bik 
1444d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14455446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
14465446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1447d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1448d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1449d9b500d3SAart Bik     if (func)
1450d9b500d3SAart Bik       return func;
1451d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1452d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1453d9b500d3SAart Bik         op->getLoc(), name,
14545446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14555446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14565446ec85SAlex Zinenko             /*isVarArg=*/false));
1457d9b500d3SAart Bik   }
1458d9b500d3SAart Bik 
1459d9b500d3SAart Bik   // Helpers for method names.
1460e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
146154759cefSAart Bik     return getPrint(op, "printI64",
14625446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1463e52414b1Saartbik   }
1464b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1465b8880f5fSAart Bik     return getPrint(op, "printU64",
1466b8880f5fSAart Bik                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1467b8880f5fSAart Bik   }
1468d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
146954759cefSAart Bik     return getPrint(op, "printF32",
14705446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1471d9b500d3SAart Bik   }
1472d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
147354759cefSAart Bik     return getPrint(op, "printF64",
14745446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1475d9b500d3SAart Bik   }
1476d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
147754759cefSAart Bik     return getPrint(op, "printOpen", {});
1478d9b500d3SAart Bik   }
1479d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
148054759cefSAart Bik     return getPrint(op, "printClose", {});
1481d9b500d3SAart Bik   }
1482d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
148354759cefSAart Bik     return getPrint(op, "printComma", {});
1484d9b500d3SAart Bik   }
1485d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
148654759cefSAart Bik     return getPrint(op, "printNewline", {});
1487d9b500d3SAart Bik   }
1488d9b500d3SAart Bik };
1489d9b500d3SAart Bik 
1490334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1491c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1492c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1493c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1494334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
149565678d93SNicolas Vasilache public:
1496b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1497b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1498b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1499b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1500b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1501b99bd771SRiver Riddle   }
150265678d93SNicolas Vasilache 
1503334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
150465678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
150565678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
150665678d93SNicolas Vasilache 
150765678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
150865678d93SNicolas Vasilache 
150965678d93SNicolas Vasilache     int64_t offset =
151065678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
151165678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
151265678d93SNicolas Vasilache     int64_t stride =
151365678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
151465678d93SNicolas Vasilache 
151565678d93SNicolas Vasilache     auto loc = op.getLoc();
151665678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
151735b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1518c3c95b9cSaartbik 
1519c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1520c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1521c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1522c3c95b9cSaartbik       offsets.reserve(size);
1523c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1524c3c95b9cSaartbik            off += stride)
1525c3c95b9cSaartbik         offsets.push_back(off);
1526c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1527c3c95b9cSaartbik                                              op.vector(),
1528c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1529c3c95b9cSaartbik       return success();
1530c3c95b9cSaartbik     }
1531c3c95b9cSaartbik 
1532c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
153365678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
153465678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
153565678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
153665678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
153765678d93SNicolas Vasilache          off += stride, ++idx) {
1538c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1539c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1540c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
154165678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
154265678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
154365678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
154465678d93SNicolas Vasilache     }
1545c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15463145427dSRiver Riddle     return success();
154765678d93SNicolas Vasilache   }
154865678d93SNicolas Vasilache };
154965678d93SNicolas Vasilache 
1550df186507SBenjamin Kramer } // namespace
1551df186507SBenjamin Kramer 
15525c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15535c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1554ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1555060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
155665678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15578345b86dSNicolas Vasilache   // clang-format off
1558681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1559681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15602d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1561c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1562ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1563ceb1b327Saartbik       ctx, converter, reassociateFPReductions);
1564060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1565060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1566060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1567060c9dd1Saartbik       ctx, converter, enableIndexOptimizations);
15688345b86dSNicolas Vasilache   patterns
1569ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15708345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15718345b86dSNicolas Vasilache               VectorExtractOpConversion,
15728345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15738345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15748345b86dSNicolas Vasilache               VectorInsertOpConversion,
15758345b86dSNicolas Vasilache               VectorPrintOpConversion,
157619dbb230Saartbik               VectorTypeCastOpConversion,
157739379916Saartbik               VectorMaskedLoadOpConversion,
157839379916Saartbik               VectorMaskedStoreOpConversion,
157919dbb230Saartbik               VectorGatherOpConversion,
1580e8dcf5f8Saartbik               VectorScatterOpConversion,
1581e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1582e8dcf5f8Saartbik               VectorCompressStoreOpConversion>(ctx, converter);
15838345b86dSNicolas Vasilache   // clang-format on
15845c0c51a9SNicolas Vasilache }
15855c0c51a9SNicolas Vasilache 
158663b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
158763b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
158863b683a8SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
158963b683a8SNicolas Vasilache   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1590c295a65dSaartbik   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
159163b683a8SNicolas Vasilache }
1592