15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
16*09f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
14519dbb230Saartbik template <typename T>
14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
14719dbb230Saartbik                                  unsigned &align) {
1485f9e0466SNicolas Vasilache   Type elementTy =
14919dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1505f9e0466SNicolas Vasilache   if (!elementTy)
1515f9e0466SNicolas Vasilache     return failure();
1525f9e0466SNicolas Vasilache 
153b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
154b2ab375dSAlex Zinenko   // stop depending on translation.
15587a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15687a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
157b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
158168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1595f9e0466SNicolas Vasilache   return success();
1605f9e0466SNicolas Vasilache }
1615f9e0466SNicolas Vasilache 
162e8dcf5f8Saartbik // Helper that returns the base address of a memref.
163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
164e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16519dbb230Saartbik   // Inspect stride and offset structure.
16619dbb230Saartbik   //
16719dbb230Saartbik   // TODO: flat memory only for now, generalize
16819dbb230Saartbik   //
16919dbb230Saartbik   int64_t offset;
17019dbb230Saartbik   SmallVector<int64_t, 4> strides;
17119dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
17219dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17319dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17419dbb230Saartbik     return failure();
175e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
176e8dcf5f8Saartbik   return success();
177e8dcf5f8Saartbik }
17819dbb230Saartbik 
179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
181b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
182b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
183e8dcf5f8Saartbik   Value base;
184e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
185e8dcf5f8Saartbik     return failure();
1863a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
187e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
188e8dcf5f8Saartbik   return success();
189e8dcf5f8Saartbik }
190e8dcf5f8Saartbik 
19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
193b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
194b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
19539379916Saartbik   Value base;
19639379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
19739379916Saartbik     return failure();
19839379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
19939379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
20039379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
20139379916Saartbik   return success();
20239379916Saartbik }
20339379916Saartbik 
204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
205e8dcf5f8Saartbik // vector.
206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
207b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
208b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
209b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
210e8dcf5f8Saartbik   Value base;
211e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
212e8dcf5f8Saartbik     return failure();
2133a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
214e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2151485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
21619dbb230Saartbik   return success();
21719dbb230Saartbik }
21819dbb230Saartbik 
2195f9e0466SNicolas Vasilache static LogicalResult
2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2215f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2225f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2235f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
224affbc0cdSNicolas Vasilache   unsigned align;
22519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
226affbc0cdSNicolas Vasilache     return failure();
227affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2285f9e0466SNicolas Vasilache   return success();
2295f9e0466SNicolas Vasilache }
2305f9e0466SNicolas Vasilache 
2315f9e0466SNicolas Vasilache static LogicalResult
2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2335f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2345f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2355f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2365f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2375f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2385f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2395f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2405f9e0466SNicolas Vasilache 
2415f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2425f9e0466SNicolas Vasilache   if (!vecTy)
2435f9e0466SNicolas Vasilache     return failure();
2445f9e0466SNicolas Vasilache 
2455f9e0466SNicolas Vasilache   unsigned align;
24619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2475f9e0466SNicolas Vasilache     return failure();
2485f9e0466SNicolas Vasilache 
2495f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2505f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2515f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2525f9e0466SNicolas Vasilache   return success();
2535f9e0466SNicolas Vasilache }
2545f9e0466SNicolas Vasilache 
2555f9e0466SNicolas Vasilache static LogicalResult
2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2575f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2585f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2595f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
260affbc0cdSNicolas Vasilache   unsigned align;
26119dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
262affbc0cdSNicolas Vasilache     return failure();
2632d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
264affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
265affbc0cdSNicolas Vasilache                                              align);
2665f9e0466SNicolas Vasilache   return success();
2675f9e0466SNicolas Vasilache }
2685f9e0466SNicolas Vasilache 
2695f9e0466SNicolas Vasilache static LogicalResult
2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2715f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2725f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2735f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2745f9e0466SNicolas Vasilache   unsigned align;
27519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2765f9e0466SNicolas Vasilache     return failure();
2775f9e0466SNicolas Vasilache 
2782d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2795f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2805f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2815f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2825f9e0466SNicolas Vasilache   return success();
2835f9e0466SNicolas Vasilache }
2845f9e0466SNicolas Vasilache 
2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2862d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2872d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2885f9e0466SNicolas Vasilache }
2895f9e0466SNicolas Vasilache 
2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2912d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2922d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2935f9e0466SNicolas Vasilache }
2945f9e0466SNicolas Vasilache 
29590c01357SBenjamin Kramer namespace {
296e83b7b99Saartbik 
29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
29963b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern {
30063b683a8SNicolas Vasilache public:
30163b683a8SNicolas Vasilache   explicit VectorMatmulOpConversion(MLIRContext *context,
30263b683a8SNicolas Vasilache                                     LLVMTypeConverter &typeConverter)
30363b683a8SNicolas Vasilache       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
30463b683a8SNicolas Vasilache                              typeConverter) {}
30563b683a8SNicolas Vasilache 
3063145427dSRiver Riddle   LogicalResult
30763b683a8SNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
30863b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
30963b683a8SNicolas Vasilache     auto matmulOp = cast<vector::MatmulOp>(op);
3102d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
31163b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
31263b683a8SNicolas Vasilache         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
31363b683a8SNicolas Vasilache         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
31463b683a8SNicolas Vasilache         matmulOp.rhs_columns());
3153145427dSRiver Riddle     return success();
31663b683a8SNicolas Vasilache   }
31763b683a8SNicolas Vasilache };
31863b683a8SNicolas Vasilache 
319c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
320c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
321c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
322c295a65dSaartbik public:
323c295a65dSaartbik   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
324c295a65dSaartbik                                            LLVMTypeConverter &typeConverter)
325c295a65dSaartbik       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
326c295a65dSaartbik                              context, typeConverter) {}
327c295a65dSaartbik 
328c295a65dSaartbik   LogicalResult
329c295a65dSaartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
330c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
331c295a65dSaartbik     auto transOp = cast<vector::FlatTransposeOp>(op);
3322d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
333c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
334c295a65dSaartbik         transOp, typeConverter.convertType(transOp.res().getType()),
335c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
336c295a65dSaartbik     return success();
337c295a65dSaartbik   }
338c295a65dSaartbik };
339c295a65dSaartbik 
34039379916Saartbik /// Conversion pattern for a vector.maskedload.
34139379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
34239379916Saartbik public:
34339379916Saartbik   explicit VectorMaskedLoadOpConversion(MLIRContext *context,
34439379916Saartbik                                         LLVMTypeConverter &typeConverter)
34539379916Saartbik       : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
34639379916Saartbik                              typeConverter) {}
34739379916Saartbik 
34839379916Saartbik   LogicalResult
34939379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
35039379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
35139379916Saartbik     auto loc = op->getLoc();
35239379916Saartbik     auto load = cast<vector::MaskedLoadOp>(op);
35339379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
35439379916Saartbik 
35539379916Saartbik     // Resolve alignment.
35639379916Saartbik     unsigned align;
35739379916Saartbik     if (failed(getMemRefAlignment(typeConverter, load, align)))
35839379916Saartbik       return failure();
35939379916Saartbik 
36039379916Saartbik     auto vtype = typeConverter.convertType(load.getResultVectorType());
36139379916Saartbik     Value ptr;
36239379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
36339379916Saartbik                           vtype, ptr)))
36439379916Saartbik       return failure();
36539379916Saartbik 
36639379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
36739379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
36839379916Saartbik         rewriter.getI32IntegerAttr(align));
36939379916Saartbik     return success();
37039379916Saartbik   }
37139379916Saartbik };
37239379916Saartbik 
37339379916Saartbik /// Conversion pattern for a vector.maskedstore.
37439379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
37539379916Saartbik public:
37639379916Saartbik   explicit VectorMaskedStoreOpConversion(MLIRContext *context,
37739379916Saartbik                                          LLVMTypeConverter &typeConverter)
37839379916Saartbik       : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
37939379916Saartbik                              typeConverter) {}
38039379916Saartbik 
38139379916Saartbik   LogicalResult
38239379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
38339379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
38439379916Saartbik     auto loc = op->getLoc();
38539379916Saartbik     auto store = cast<vector::MaskedStoreOp>(op);
38639379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
38739379916Saartbik 
38839379916Saartbik     // Resolve alignment.
38939379916Saartbik     unsigned align;
39039379916Saartbik     if (failed(getMemRefAlignment(typeConverter, store, align)))
39139379916Saartbik       return failure();
39239379916Saartbik 
39339379916Saartbik     auto vtype = typeConverter.convertType(store.getValueVectorType());
39439379916Saartbik     Value ptr;
39539379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
39639379916Saartbik                           vtype, ptr)))
39739379916Saartbik       return failure();
39839379916Saartbik 
39939379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
40039379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
40139379916Saartbik         rewriter.getI32IntegerAttr(align));
40239379916Saartbik     return success();
40339379916Saartbik   }
40439379916Saartbik };
40539379916Saartbik 
40619dbb230Saartbik /// Conversion pattern for a vector.gather.
40719dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern {
40819dbb230Saartbik public:
40919dbb230Saartbik   explicit VectorGatherOpConversion(MLIRContext *context,
41019dbb230Saartbik                                     LLVMTypeConverter &typeConverter)
41119dbb230Saartbik       : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
41219dbb230Saartbik                              typeConverter) {}
41319dbb230Saartbik 
41419dbb230Saartbik   LogicalResult
41519dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
41619dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
41719dbb230Saartbik     auto loc = op->getLoc();
41819dbb230Saartbik     auto gather = cast<vector::GatherOp>(op);
41919dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
42019dbb230Saartbik 
42119dbb230Saartbik     // Resolve alignment.
42219dbb230Saartbik     unsigned align;
42319dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, gather, align)))
42419dbb230Saartbik       return failure();
42519dbb230Saartbik 
42619dbb230Saartbik     // Get index ptrs.
42719dbb230Saartbik     VectorType vType = gather.getResultVectorType();
42819dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
42919dbb230Saartbik     Value ptrs;
430e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
431e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
43219dbb230Saartbik       return failure();
43319dbb230Saartbik 
43419dbb230Saartbik     // Replace with the gather intrinsic.
43519dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
4360c2a4d3cSBenjamin Kramer         gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
4370c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
43819dbb230Saartbik     return success();
43919dbb230Saartbik   }
44019dbb230Saartbik };
44119dbb230Saartbik 
44219dbb230Saartbik /// Conversion pattern for a vector.scatter.
44319dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern {
44419dbb230Saartbik public:
44519dbb230Saartbik   explicit VectorScatterOpConversion(MLIRContext *context,
44619dbb230Saartbik                                      LLVMTypeConverter &typeConverter)
44719dbb230Saartbik       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
44819dbb230Saartbik                              typeConverter) {}
44919dbb230Saartbik 
45019dbb230Saartbik   LogicalResult
45119dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
45219dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
45319dbb230Saartbik     auto loc = op->getLoc();
45419dbb230Saartbik     auto scatter = cast<vector::ScatterOp>(op);
45519dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
45619dbb230Saartbik 
45719dbb230Saartbik     // Resolve alignment.
45819dbb230Saartbik     unsigned align;
45919dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, scatter, align)))
46019dbb230Saartbik       return failure();
46119dbb230Saartbik 
46219dbb230Saartbik     // Get index ptrs.
46319dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
46419dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
46519dbb230Saartbik     Value ptrs;
466e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
467e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
46819dbb230Saartbik       return failure();
46919dbb230Saartbik 
47019dbb230Saartbik     // Replace with the scatter intrinsic.
47119dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
47219dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
47319dbb230Saartbik         rewriter.getI32IntegerAttr(align));
47419dbb230Saartbik     return success();
47519dbb230Saartbik   }
47619dbb230Saartbik };
47719dbb230Saartbik 
478e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
479e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
480e8dcf5f8Saartbik public:
481e8dcf5f8Saartbik   explicit VectorExpandLoadOpConversion(MLIRContext *context,
482e8dcf5f8Saartbik                                         LLVMTypeConverter &typeConverter)
483e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
484e8dcf5f8Saartbik                              typeConverter) {}
485e8dcf5f8Saartbik 
486e8dcf5f8Saartbik   LogicalResult
487e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
488e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
489e8dcf5f8Saartbik     auto loc = op->getLoc();
490e8dcf5f8Saartbik     auto expand = cast<vector::ExpandLoadOp>(op);
491e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
492e8dcf5f8Saartbik 
493e8dcf5f8Saartbik     Value ptr;
494e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
495e8dcf5f8Saartbik                           ptr)))
496e8dcf5f8Saartbik       return failure();
497e8dcf5f8Saartbik 
498e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
499e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
500e8dcf5f8Saartbik         op, typeConverter.convertType(vType), ptr, adaptor.mask(),
501e8dcf5f8Saartbik         adaptor.pass_thru());
502e8dcf5f8Saartbik     return success();
503e8dcf5f8Saartbik   }
504e8dcf5f8Saartbik };
505e8dcf5f8Saartbik 
506e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
507e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
508e8dcf5f8Saartbik public:
509e8dcf5f8Saartbik   explicit VectorCompressStoreOpConversion(MLIRContext *context,
510e8dcf5f8Saartbik                                            LLVMTypeConverter &typeConverter)
511e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
512e8dcf5f8Saartbik                              context, typeConverter) {}
513e8dcf5f8Saartbik 
514e8dcf5f8Saartbik   LogicalResult
515e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
516e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
517e8dcf5f8Saartbik     auto loc = op->getLoc();
518e8dcf5f8Saartbik     auto compress = cast<vector::CompressStoreOp>(op);
519e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
520e8dcf5f8Saartbik 
521e8dcf5f8Saartbik     Value ptr;
522e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
523e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
524e8dcf5f8Saartbik       return failure();
525e8dcf5f8Saartbik 
526e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
527e8dcf5f8Saartbik         op, adaptor.value(), ptr, adaptor.mask());
528e8dcf5f8Saartbik     return success();
529e8dcf5f8Saartbik   }
530e8dcf5f8Saartbik };
531e8dcf5f8Saartbik 
53219dbb230Saartbik /// Conversion pattern for all vector reductions.
533870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern {
534e83b7b99Saartbik public:
535e83b7b99Saartbik   explicit VectorReductionOpConversion(MLIRContext *context,
536ceb1b327Saartbik                                        LLVMTypeConverter &typeConverter,
537060c9dd1Saartbik                                        bool reassociateFPRed)
538870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
539ceb1b327Saartbik                              typeConverter),
540060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
541e83b7b99Saartbik 
5423145427dSRiver Riddle   LogicalResult
543e83b7b99Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
544e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
545e83b7b99Saartbik     auto reductionOp = cast<vector::ReductionOp>(op);
546e83b7b99Saartbik     auto kind = reductionOp.kind();
547e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
5480f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(eltType);
549e9628955SAart Bik     if (eltType.isIntOrIndex()) {
550e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
551e83b7b99Saartbik       if (kind == "add")
552322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
553e83b7b99Saartbik             op, llvmType, operands[0]);
554e83b7b99Saartbik       else if (kind == "mul")
555322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
556e83b7b99Saartbik             op, llvmType, operands[0]);
557e9628955SAart Bik       else if (kind == "min" &&
558e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
559322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
560e9628955SAart Bik             op, llvmType, operands[0]);
561e83b7b99Saartbik       else if (kind == "min")
562322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
563e83b7b99Saartbik             op, llvmType, operands[0]);
564e9628955SAart Bik       else if (kind == "max" &&
565e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
566322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
567e9628955SAart Bik             op, llvmType, operands[0]);
568e83b7b99Saartbik       else if (kind == "max")
569322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
570e83b7b99Saartbik             op, llvmType, operands[0]);
571e83b7b99Saartbik       else if (kind == "and")
572322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
573e83b7b99Saartbik             op, llvmType, operands[0]);
574e83b7b99Saartbik       else if (kind == "or")
575322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
576e83b7b99Saartbik             op, llvmType, operands[0]);
577e83b7b99Saartbik       else if (kind == "xor")
578322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
579e83b7b99Saartbik             op, llvmType, operands[0]);
580e83b7b99Saartbik       else
5813145427dSRiver Riddle         return failure();
5823145427dSRiver Riddle       return success();
583e83b7b99Saartbik 
5842d76274bSBenjamin Kramer     } else if (eltType.isa<FloatType>()) {
585e83b7b99Saartbik       // Floating-point reductions: add/mul/min/max
586e83b7b99Saartbik       if (kind == "add") {
5870d924700Saartbik         // Optional accumulator (or zero).
5880d924700Saartbik         Value acc = operands.size() > 1 ? operands[1]
5890d924700Saartbik                                         : rewriter.create<LLVM::ConstantOp>(
5900d924700Saartbik                                               op->getLoc(), llvmType,
5910d924700Saartbik                                               rewriter.getZeroAttr(eltType));
592322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
593ceb1b327Saartbik             op, llvmType, acc, operands[0],
594ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
595e83b7b99Saartbik       } else if (kind == "mul") {
5960d924700Saartbik         // Optional accumulator (or one).
5970d924700Saartbik         Value acc = operands.size() > 1
5980d924700Saartbik                         ? operands[1]
5990d924700Saartbik                         : rewriter.create<LLVM::ConstantOp>(
6000d924700Saartbik                               op->getLoc(), llvmType,
6010d924700Saartbik                               rewriter.getFloatAttr(eltType, 1.0));
602322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
603ceb1b327Saartbik             op, llvmType, acc, operands[0],
604ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
605e83b7b99Saartbik       } else if (kind == "min")
606322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
607e83b7b99Saartbik             op, llvmType, operands[0]);
608e83b7b99Saartbik       else if (kind == "max")
609322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
610e83b7b99Saartbik             op, llvmType, operands[0]);
611e83b7b99Saartbik       else
6123145427dSRiver Riddle         return failure();
6133145427dSRiver Riddle       return success();
614e83b7b99Saartbik     }
6153145427dSRiver Riddle     return failure();
616e83b7b99Saartbik   }
617ceb1b327Saartbik 
618ceb1b327Saartbik private:
619ceb1b327Saartbik   const bool reassociateFPReductions;
620e83b7b99Saartbik };
621e83b7b99Saartbik 
622060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
623060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
624060c9dd1Saartbik public:
625060c9dd1Saartbik   explicit VectorCreateMaskOpConversion(MLIRContext *context,
626060c9dd1Saartbik                                         LLVMTypeConverter &typeConverter,
627060c9dd1Saartbik                                         bool enableIndexOpt)
628060c9dd1Saartbik       : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
629060c9dd1Saartbik                              typeConverter),
630060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
631060c9dd1Saartbik 
632060c9dd1Saartbik   LogicalResult
633060c9dd1Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
634060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
635060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
636060c9dd1Saartbik     int64_t rank = dstType.getRank();
637060c9dd1Saartbik     if (rank == 1) {
638060c9dd1Saartbik       rewriter.replaceOp(
639060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
640060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
641060c9dd1Saartbik       return success();
642060c9dd1Saartbik     }
643060c9dd1Saartbik     return failure();
644060c9dd1Saartbik   }
645060c9dd1Saartbik 
646060c9dd1Saartbik private:
647060c9dd1Saartbik   const bool enableIndexOptimizations;
648060c9dd1Saartbik };
649060c9dd1Saartbik 
650870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern {
6511c81adf3SAart Bik public:
6521c81adf3SAart Bik   explicit VectorShuffleOpConversion(MLIRContext *context,
6531c81adf3SAart Bik                                      LLVMTypeConverter &typeConverter)
654870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
6551c81adf3SAart Bik                              typeConverter) {}
6561c81adf3SAart Bik 
6573145427dSRiver Riddle   LogicalResult
658e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6591c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6601c81adf3SAart Bik     auto loc = op->getLoc();
6612d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6621c81adf3SAart Bik     auto shuffleOp = cast<vector::ShuffleOp>(op);
6631c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6641c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6651c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
6660f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(vectorType);
6671c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6681c81adf3SAart Bik 
6691c81adf3SAart Bik     // Bail if result type cannot be lowered.
6701c81adf3SAart Bik     if (!llvmType)
6713145427dSRiver Riddle       return failure();
6721c81adf3SAart Bik 
6731c81adf3SAart Bik     // Get rank and dimension sizes.
6741c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6751c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6761c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6771c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6781c81adf3SAart Bik 
6791c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6801c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6811c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
682e62a6956SRiver Riddle       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
6831c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
6841c81adf3SAart Bik       rewriter.replaceOp(op, shuffle);
6853145427dSRiver Riddle       return success();
686b36aaeafSAart Bik     }
687b36aaeafSAart Bik 
6881c81adf3SAart Bik     // For all other cases, insert the individual values individually.
689e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6901c81adf3SAart Bik     int64_t insPos = 0;
6911c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6921c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
693e62a6956SRiver Riddle       Value value = adaptor.v1();
6941c81adf3SAart Bik       if (extPos >= v1Dim) {
6951c81adf3SAart Bik         extPos -= v1Dim;
6961c81adf3SAart Bik         value = adaptor.v2();
697b36aaeafSAart Bik       }
6980f04384dSAlex Zinenko       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
6990f04384dSAlex Zinenko                                  rank, extPos);
7000f04384dSAlex Zinenko       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
7010f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7021c81adf3SAart Bik     }
7031c81adf3SAart Bik     rewriter.replaceOp(op, insert);
7043145427dSRiver Riddle     return success();
705b36aaeafSAart Bik   }
706b36aaeafSAart Bik };
707b36aaeafSAart Bik 
708870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
709cd5dab8aSAart Bik public:
710cd5dab8aSAart Bik   explicit VectorExtractElementOpConversion(MLIRContext *context,
711cd5dab8aSAart Bik                                             LLVMTypeConverter &typeConverter)
712870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
713870c1fd4SAlex Zinenko                              context, typeConverter) {}
714cd5dab8aSAart Bik 
7153145427dSRiver Riddle   LogicalResult
716e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
717cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7182d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
719cd5dab8aSAart Bik     auto extractEltOp = cast<vector::ExtractElementOp>(op);
720cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
7210f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType.getElementType());
722cd5dab8aSAart Bik 
723cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
724cd5dab8aSAart Bik     if (!llvmType)
7253145427dSRiver Riddle       return failure();
726cd5dab8aSAart Bik 
727cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
728cd5dab8aSAart Bik         op, llvmType, adaptor.vector(), adaptor.position());
7293145427dSRiver Riddle     return success();
730cd5dab8aSAart Bik   }
731cd5dab8aSAart Bik };
732cd5dab8aSAart Bik 
733870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern {
7345c0c51a9SNicolas Vasilache public:
7359826fe5cSAart Bik   explicit VectorExtractOpConversion(MLIRContext *context,
7365c0c51a9SNicolas Vasilache                                      LLVMTypeConverter &typeConverter)
737870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
7385c0c51a9SNicolas Vasilache                              typeConverter) {}
7395c0c51a9SNicolas Vasilache 
7403145427dSRiver Riddle   LogicalResult
741e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7425c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7435c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
7442d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
745d37f2725SAart Bik     auto extractOp = cast<vector::ExtractOp>(op);
7469826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7472bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
7480f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(resultType);
7495c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7509826fe5cSAart Bik 
7519826fe5cSAart Bik     // Bail if result type cannot be lowered.
7529826fe5cSAart Bik     if (!llvmResultType)
7533145427dSRiver Riddle       return failure();
7549826fe5cSAart Bik 
7555c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7565c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
757e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7585c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
7595c0c51a9SNicolas Vasilache       rewriter.replaceOp(op, extracted);
7603145427dSRiver Riddle       return success();
7615c0c51a9SNicolas Vasilache     }
7625c0c51a9SNicolas Vasilache 
7639826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
7645c0c51a9SNicolas Vasilache     auto *context = op->getContext();
765e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7665c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7675c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7689826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7695c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7705c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7715c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
7720f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
7735c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7745c0c51a9SNicolas Vasilache     }
7755c0c51a9SNicolas Vasilache 
7765c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7775c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7785446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7791d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7805c0c51a9SNicolas Vasilache     extracted =
7815c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
7825c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, extracted);
7835c0c51a9SNicolas Vasilache 
7843145427dSRiver Riddle     return success();
7855c0c51a9SNicolas Vasilache   }
7865c0c51a9SNicolas Vasilache };
7875c0c51a9SNicolas Vasilache 
788681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
789681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
790681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
791681f929fSNicolas Vasilache ///
792681f929fSNicolas Vasilache /// Example:
793681f929fSNicolas Vasilache /// ```
794681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
795681f929fSNicolas Vasilache /// ```
796681f929fSNicolas Vasilache /// is converted to:
797681f929fSNicolas Vasilache /// ```
7983bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
799681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
800681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
801681f929fSNicolas Vasilache /// ```
802870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
803681f929fSNicolas Vasilache public:
804681f929fSNicolas Vasilache   explicit VectorFMAOp1DConversion(MLIRContext *context,
805681f929fSNicolas Vasilache                                    LLVMTypeConverter &typeConverter)
806870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
807681f929fSNicolas Vasilache                              typeConverter) {}
808681f929fSNicolas Vasilache 
8093145427dSRiver Riddle   LogicalResult
810681f929fSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
811681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8122d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
813681f929fSNicolas Vasilache     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
814681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
815681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8163145427dSRiver Riddle       return failure();
8173bffe602SBenjamin Kramer     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
8183bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8193145427dSRiver Riddle     return success();
820681f929fSNicolas Vasilache   }
821681f929fSNicolas Vasilache };
822681f929fSNicolas Vasilache 
823870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
824cd5dab8aSAart Bik public:
825cd5dab8aSAart Bik   explicit VectorInsertElementOpConversion(MLIRContext *context,
826cd5dab8aSAart Bik                                            LLVMTypeConverter &typeConverter)
827870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
828870c1fd4SAlex Zinenko                              context, typeConverter) {}
829cd5dab8aSAart Bik 
8303145427dSRiver Riddle   LogicalResult
831e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
832cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8332d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
834cd5dab8aSAart Bik     auto insertEltOp = cast<vector::InsertElementOp>(op);
835cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
8360f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType);
837cd5dab8aSAart Bik 
838cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
839cd5dab8aSAart Bik     if (!llvmType)
8403145427dSRiver Riddle       return failure();
841cd5dab8aSAart Bik 
842cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
843cd5dab8aSAart Bik         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
8443145427dSRiver Riddle     return success();
845cd5dab8aSAart Bik   }
846cd5dab8aSAart Bik };
847cd5dab8aSAart Bik 
848870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern {
8499826fe5cSAart Bik public:
8509826fe5cSAart Bik   explicit VectorInsertOpConversion(MLIRContext *context,
8519826fe5cSAart Bik                                     LLVMTypeConverter &typeConverter)
852870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
8539826fe5cSAart Bik                              typeConverter) {}
8549826fe5cSAart Bik 
8553145427dSRiver Riddle   LogicalResult
856e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
8579826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8589826fe5cSAart Bik     auto loc = op->getLoc();
8592d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8609826fe5cSAart Bik     auto insertOp = cast<vector::InsertOp>(op);
8619826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8629826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
8630f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(destVectorType);
8649826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8659826fe5cSAart Bik 
8669826fe5cSAart Bik     // Bail if result type cannot be lowered.
8679826fe5cSAart Bik     if (!llvmResultType)
8683145427dSRiver Riddle       return failure();
8699826fe5cSAart Bik 
8709826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8719826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
872e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8739826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8749826fe5cSAart Bik           positionArrayAttr);
8759826fe5cSAart Bik       rewriter.replaceOp(op, inserted);
8763145427dSRiver Riddle       return success();
8779826fe5cSAart Bik     }
8789826fe5cSAart Bik 
8799826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
8809826fe5cSAart Bik     auto *context = op->getContext();
881e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8829826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8839826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8849826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8859826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8869826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8879826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8889826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8899826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
8900f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
8919826fe5cSAart Bik           nMinusOnePositionAttrs);
8929826fe5cSAart Bik     }
8939826fe5cSAart Bik 
8949826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8955446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
8961d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
897e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
8980f04384dSAlex Zinenko         loc, typeConverter.convertType(oneDVectorType), extracted,
8990f04384dSAlex Zinenko         adaptor.source(), constant);
9009826fe5cSAart Bik 
9019826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
9029826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9039826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9049826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9059826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
9069826fe5cSAart Bik                                                       adaptor.dest(), inserted,
9079826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
9089826fe5cSAart Bik     }
9099826fe5cSAart Bik 
9109826fe5cSAart Bik     rewriter.replaceOp(op, inserted);
9113145427dSRiver Riddle     return success();
9129826fe5cSAart Bik   }
9139826fe5cSAart Bik };
9149826fe5cSAart Bik 
915681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
916681f929fSNicolas Vasilache ///
917681f929fSNicolas Vasilache /// Example:
918681f929fSNicolas Vasilache /// ```
919681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
920681f929fSNicolas Vasilache /// ```
921681f929fSNicolas Vasilache /// is rewritten into:
922681f929fSNicolas Vasilache /// ```
923681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
924681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
925681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
926681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
927681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
928681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
929681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
930681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
931681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
932681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
933681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
934681f929fSNicolas Vasilache ///  // %r3 holds the final value.
935681f929fSNicolas Vasilache /// ```
936681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
937681f929fSNicolas Vasilache public:
938681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
939681f929fSNicolas Vasilache 
9403145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
941681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
942681f929fSNicolas Vasilache     auto vType = op.getVectorType();
943681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9443145427dSRiver Riddle       return failure();
945681f929fSNicolas Vasilache 
946681f929fSNicolas Vasilache     auto loc = op.getLoc();
947681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
948681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
949681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
950681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
951681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
952681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
953681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
954681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
955681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
956681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
957681f929fSNicolas Vasilache     }
958681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9593145427dSRiver Riddle     return success();
960681f929fSNicolas Vasilache   }
961681f929fSNicolas Vasilache };
962681f929fSNicolas Vasilache 
9632d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9642d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9652d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9662d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9672d515e49SNicolas Vasilache // rank.
9682d515e49SNicolas Vasilache //
9692d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9702d515e49SNicolas Vasilache // have different ranks. In this case:
9712d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9722d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9732d515e49SNicolas Vasilache //   destination subvector
9742d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9752d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9762d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9772d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9782d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9792d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9802d515e49SNicolas Vasilache public:
9812d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9822d515e49SNicolas Vasilache 
9833145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9842d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9852d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9862d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9872d515e49SNicolas Vasilache 
9882d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9893145427dSRiver Riddle       return failure();
9902d515e49SNicolas Vasilache 
9912d515e49SNicolas Vasilache     auto loc = op.getLoc();
9922d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9932d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9942d515e49SNicolas Vasilache     if (rankDiff == 0)
9953145427dSRiver Riddle       return failure();
9962d515e49SNicolas Vasilache 
9972d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9982d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9992d515e49SNicolas Vasilache     // on it.
10002d515e49SNicolas Vasilache     Value extracted =
10012d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
10022d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
10032d515e49SNicolas Vasilache                                                   /*dropFront=*/rankRest));
10042d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
10052d515e49SNicolas Vasilache     // ranks.
10062d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
10072d515e49SNicolas Vasilache         loc, op.source(), extracted,
10082d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
1009c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
10102d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
10112d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
10122d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
10132d515e49SNicolas Vasilache                        /*dropFront=*/rankRest));
10143145427dSRiver Riddle     return success();
10152d515e49SNicolas Vasilache   }
10162d515e49SNicolas Vasilache };
10172d515e49SNicolas Vasilache 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, for each slice of the source vector along
// the most major dimension:
//   1. the proper subvector (or element) is extracted from the source vector
//   2. if it is a vector, the matching subvector is extracted from the
//   destination vector and a new InsertStridedSlice op of smaller rank is
//   created to insert the source subvector into it
//   3. the resulting subvector (or the element) is inserted in the proper
//   place of the result vector, which starts out as the destination vector
// The op is replaced by the result vector once all slices are inserted.
// The new InsertStridedSlice ops from step 2. have operands of equal rank and
// are picked up again by this same pattern.
10272d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10282d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10292d515e49SNicolas Vasilache public:
1030b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1031b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1032b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1033b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1034b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1035b99bd771SRiver Riddle   }
10362d515e49SNicolas Vasilache 
10373145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10382d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10392d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10402d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10412d515e49SNicolas Vasilache 
10422d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10433145427dSRiver Riddle       return failure();
10442d515e49SNicolas Vasilache 
10452d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10462d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10472d515e49SNicolas Vasilache     if (rankDiff != 0)
10483145427dSRiver Riddle       return failure();
10492d515e49SNicolas Vasilache 
10502d515e49SNicolas Vasilache     if (srcType == dstType) {
10512d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10523145427dSRiver Riddle       return success();
10532d515e49SNicolas Vasilache     }
10542d515e49SNicolas Vasilache 
10552d515e49SNicolas Vasilache     int64_t offset =
10562d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10572d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10582d515e49SNicolas Vasilache     int64_t stride =
10592d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10602d515e49SNicolas Vasilache 
10612d515e49SNicolas Vasilache     auto loc = op.getLoc();
10622d515e49SNicolas Vasilache     Value res = op.dest();
10632d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10642d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10652d515e49SNicolas Vasilache          off += stride, ++idx) {
10662d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10672d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10682d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10692d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10702d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10712d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10722d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10732d515e49SNicolas Vasilache         // smaller rank.
1074bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10752d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10762d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10772d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10782d515e49SNicolas Vasilache       }
10792d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10802d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10812d515e49SNicolas Vasilache     }
10822d515e49SNicolas Vasilache 
10832d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10843145427dSRiver Riddle     return success();
10852d515e49SNicolas Vasilache   }
10862d515e49SNicolas Vasilache };
10872d515e49SNicolas Vasilache 
108830e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
108930e6033bSNicolas Vasilache /// static layout.
109030e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
109130e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10922bf491c7SBenjamin Kramer   int64_t offset;
109330e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
109430e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
109530e6033bSNicolas Vasilache     return None;
109630e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
109730e6033bSNicolas Vasilache     return None;
109830e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
109930e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
110030e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
110130e6033bSNicolas Vasilache     return strides;
110230e6033bSNicolas Vasilache 
110330e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
110430e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
110530e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
110630e6033bSNicolas Vasilache   // layout.
11072bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
11082bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
110930e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
111030e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
111130e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
111230e6033bSNicolas Vasilache       return None;
111330e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
111430e6033bSNicolas Vasilache       return None;
11152bf491c7SBenjamin Kramer   }
111630e6033bSNicolas Vasilache   return strides;
11172bf491c7SBenjamin Kramer }
11182bf491c7SBenjamin Kramer 
1119870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
11205c0c51a9SNicolas Vasilache public:
11215c0c51a9SNicolas Vasilache   explicit VectorTypeCastOpConversion(MLIRContext *context,
11225c0c51a9SNicolas Vasilache                                       LLVMTypeConverter &typeConverter)
1123870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
11245c0c51a9SNicolas Vasilache                              typeConverter) {}
11255c0c51a9SNicolas Vasilache 
11263145427dSRiver Riddle   LogicalResult
1127e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11285c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11295c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
11305c0c51a9SNicolas Vasilache     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
11315c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11322bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11335c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
11342bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
11355c0c51a9SNicolas Vasilache 
11365c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11375c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11385c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11393145427dSRiver Riddle       return failure();
11405c0c51a9SNicolas Vasilache 
11415c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11422bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
11435c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
11443145427dSRiver Riddle       return failure();
11455c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11465c0c51a9SNicolas Vasilache 
11470f04384dSAlex Zinenko     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
11485c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11495c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11503145427dSRiver Riddle       return failure();
11515c0c51a9SNicolas Vasilache 
115230e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
115330e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
115430e6033bSNicolas Vasilache     if (!sourceStrides)
115530e6033bSNicolas Vasilache       return failure();
115630e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
115730e6033bSNicolas Vasilache     if (!targetStrides)
115830e6033bSNicolas Vasilache       return failure();
115930e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
116030e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
116130e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
116230e6033bSNicolas Vasilache         }))
11633145427dSRiver Riddle       return failure();
11645c0c51a9SNicolas Vasilache 
11655446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11665c0c51a9SNicolas Vasilache 
11675c0c51a9SNicolas Vasilache     // Create descriptor.
11685c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11693a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11705c0c51a9SNicolas Vasilache     // Set allocated ptr.
1171e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11725c0c51a9SNicolas Vasilache     allocated =
11735c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11745c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11755c0c51a9SNicolas Vasilache     // Set aligned ptr.
1176e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11775c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11785c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11795c0c51a9SNicolas Vasilache     // Fill offset 0.
11805c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11815c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11825c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11835c0c51a9SNicolas Vasilache 
11845c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11855c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11865c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11875c0c51a9SNicolas Vasilache       auto sizeAttr =
11885c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11895c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11905c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
119130e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
119230e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11935c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11945c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11955c0c51a9SNicolas Vasilache     }
11965c0c51a9SNicolas Vasilache 
11975c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, {desc});
11983145427dSRiver Riddle     return success();
11995c0c51a9SNicolas Vasilache   }
12005c0c51a9SNicolas Vasilache };
12015c0c51a9SNicolas Vasilache 
12028345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
12038345b86dSNicolas Vasilache /// sequence of:
1204060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1205060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1206060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1207060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1208060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
12098345b86dSNicolas Vasilache template <typename ConcreteOp>
12108345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern {
12118345b86dSNicolas Vasilache public:
12128345b86dSNicolas Vasilache   explicit VectorTransferConversion(MLIRContext *context,
1213060c9dd1Saartbik                                     LLVMTypeConverter &typeConv,
1214060c9dd1Saartbik                                     bool enableIndexOpt)
1215060c9dd1Saartbik       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1216060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
12178345b86dSNicolas Vasilache 
12188345b86dSNicolas Vasilache   LogicalResult
12198345b86dSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
12208345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12218345b86dSNicolas Vasilache     auto xferOp = cast<ConcreteOp>(op);
12228345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1223b2c79c50SNicolas Vasilache 
1224b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1225b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12268345b86dSNicolas Vasilache       return failure();
12275f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12285f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12295f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
12305f9e0466SNicolas Vasilache                                        op->getContext()))
12318345b86dSNicolas Vasilache       return failure();
12322bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
123330e6033bSNicolas Vasilache     auto strides = computeContiguousStrides(xferOp.getMemRefType());
123430e6033bSNicolas Vasilache     if (!strides)
12352bf491c7SBenjamin Kramer       return failure();
12368345b86dSNicolas Vasilache 
12378345b86dSNicolas Vasilache     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
12388345b86dSNicolas Vasilache 
12398345b86dSNicolas Vasilache     Location loc = op->getLoc();
12408345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
12418345b86dSNicolas Vasilache 
124268330ee0SThomas Raoux     if (auto memrefVectorElementType =
124368330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
124468330ee0SThomas Raoux       // Memref has vector element type.
124568330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
124668330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
124768330ee0SThomas Raoux         return failure();
12480de60b55SThomas Raoux #ifndef NDEBUG
124968330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
125068330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
125168330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
125268330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
125368330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
125468330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
125568330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
125668330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
125768330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
125868330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
125968330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
126068330ee0SThomas Raoux                "result shape.");
12610de60b55SThomas Raoux #endif // ifndef NDEBUG
126268330ee0SThomas Raoux     }
126368330ee0SThomas Raoux 
12648345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1265be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1266be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1267be16075bSWen-Heng (Jack) Chung     //    address space 0.
12688345b86dSNicolas Vasilache     // TODO: support alignment when possible.
12698b97e17dSChristian Sigg     Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
1270d3a98076SAlex Zinenko                                          adaptor.indices(), rewriter);
12718345b86dSNicolas Vasilache     auto vecTy =
12728345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1273be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1274be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1275be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12768345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1277be16075bSWen-Heng (Jack) Chung     else
1278be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1279be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12808345b86dSNicolas Vasilache 
12811870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
12821870e787SNicolas Vasilache       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
12831870e787SNicolas Vasilache                                               xferOp, operands, vectorDataPtr);
12841870e787SNicolas Vasilache 
12858345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12868345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12878345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12888345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1289060c9dd1Saartbik     //
1290060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1291060c9dd1Saartbik     //       dimensions here.
1292060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1293060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12940c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
1295b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1296060c9dd1Saartbik     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1297060c9dd1Saartbik                                        vecWidth, dim, &off);
12988345b86dSNicolas Vasilache 
12998345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
13001870e787SNicolas Vasilache     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1301a99f62c4SAlex Zinenko                                        operands, vectorDataPtr, mask);
13028345b86dSNicolas Vasilache   }
1303060c9dd1Saartbik 
1304060c9dd1Saartbik private:
1305060c9dd1Saartbik   const bool enableIndexOptimizations;
13068345b86dSNicolas Vasilache };
13078345b86dSNicolas Vasilache 
1308870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern {
1309d9b500d3SAart Bik public:
1310d9b500d3SAart Bik   explicit VectorPrintOpConversion(MLIRContext *context,
1311d9b500d3SAart Bik                                    LLVMTypeConverter &typeConverter)
1312870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1313d9b500d3SAart Bik                              typeConverter) {}
1314d9b500d3SAart Bik 
1315d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1316d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1317d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1318d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1319d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1320d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1321d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1322d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1323d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1324d9b500d3SAart Bik   //
13259db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1326d9b500d3SAart Bik   //
13273145427dSRiver Riddle   LogicalResult
1328e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1329d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1330d9b500d3SAart Bik     auto printOp = cast<vector::PrintOp>(op);
13312d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1332d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1333d9b500d3SAart Bik 
13340f04384dSAlex Zinenko     if (typeConverter.convertType(printType) == nullptr)
13353145427dSRiver Riddle       return failure();
1336d9b500d3SAart Bik 
1337b8880f5fSAart Bik     // Make sure element type has runtime support.
1338b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1339d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1340d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1341d9b500d3SAart Bik     Operation *printer;
1342b8880f5fSAart Bik     if (eltType.isF32()) {
1343d9b500d3SAart Bik       printer = getPrintFloat(op);
1344b8880f5fSAart Bik     } else if (eltType.isF64()) {
1345d9b500d3SAart Bik       printer = getPrintDouble(op);
134654759cefSAart Bik     } else if (eltType.isIndex()) {
134754759cefSAart Bik       printer = getPrintU64(op);
1348b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1349b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1350b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1351b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1352b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1353b8880f5fSAart Bik       if (intTy.isUnsigned()) {
135454759cefSAart Bik         if (width <= 64) {
1355b8880f5fSAart Bik           if (width < 64)
1356b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1357b8880f5fSAart Bik           printer = getPrintU64(op);
1358b8880f5fSAart Bik         } else {
13593145427dSRiver Riddle           return failure();
1360b8880f5fSAart Bik         }
1361b8880f5fSAart Bik       } else {
1362b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
136354759cefSAart Bik         if (width <= 64) {
1364b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1365b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1366b8880f5fSAart Bik           if (width == 1)
136754759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
136854759cefSAart Bik           else if (width < 64)
1369b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1370b8880f5fSAart Bik           printer = getPrintI64(op);
1371b8880f5fSAart Bik         } else {
1372b8880f5fSAart Bik           return failure();
1373b8880f5fSAart Bik         }
1374b8880f5fSAart Bik       }
1375b8880f5fSAart Bik     } else {
1376b8880f5fSAart Bik       return failure();
1377b8880f5fSAart Bik     }
1378d9b500d3SAart Bik 
1379d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1380b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1381b8880f5fSAart Bik     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1382b8880f5fSAart Bik               conversion);
1383d9b500d3SAart Bik     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1384d9b500d3SAart Bik     rewriter.eraseOp(op);
13853145427dSRiver Riddle     return success();
1386d9b500d3SAart Bik   }
1387d9b500d3SAart Bik 
1388d9b500d3SAart Bik private:
1389b8880f5fSAart Bik   enum class PrintConversion {
139030e6033bSNicolas Vasilache     // clang-format off
1391b8880f5fSAart Bik     None,
1392b8880f5fSAart Bik     ZeroExt64,
1393b8880f5fSAart Bik     SignExt64
139430e6033bSNicolas Vasilache     // clang-format on
1395b8880f5fSAart Bik   };
1396b8880f5fSAart Bik 
1397d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1398e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1399b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1400d9b500d3SAart Bik     Location loc = op->getLoc();
1401d9b500d3SAart Bik     if (rank == 0) {
1402b8880f5fSAart Bik       switch (conversion) {
1403b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1404b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1405b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1406b8880f5fSAart Bik         break;
1407b8880f5fSAart Bik       case PrintConversion::SignExt64:
1408b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1409b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1410b8880f5fSAart Bik         break;
1411b8880f5fSAart Bik       case PrintConversion::None:
1412b8880f5fSAart Bik         break;
1413c9eeeb38Saartbik       }
1414d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1415d9b500d3SAart Bik       return;
1416d9b500d3SAart Bik     }
1417d9b500d3SAart Bik 
1418d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1419d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1420d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1421d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1422d9b500d3SAart Bik       auto reducedType =
1423d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
14240f04384dSAlex Zinenko       auto llvmType = typeConverter.convertType(
1425d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1426e62a6956SRiver Riddle       Value nestedVal =
14270f04384dSAlex Zinenko           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1428b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1429b8880f5fSAart Bik                 conversion);
1430d9b500d3SAart Bik       if (d != dim - 1)
1431d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1432d9b500d3SAart Bik     }
1433d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1434d9b500d3SAart Bik   }
1435d9b500d3SAart Bik 
1436d9b500d3SAart Bik   // Helper to emit a call.
1437d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1438d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
143908e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1440d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1441d9b500d3SAart Bik   }
1442d9b500d3SAart Bik 
1443d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14445446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
14455446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1446d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1447d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1448d9b500d3SAart Bik     if (func)
1449d9b500d3SAart Bik       return func;
1450d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1451d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1452d9b500d3SAart Bik         op->getLoc(), name,
14535446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14545446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14555446ec85SAlex Zinenko             /*isVarArg=*/false));
1456d9b500d3SAart Bik   }
1457d9b500d3SAart Bik 
1458d9b500d3SAart Bik   // Helpers for method names.
1459e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
146054759cefSAart Bik     return getPrint(op, "printI64",
14615446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1462e52414b1Saartbik   }
1463b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1464b8880f5fSAart Bik     return getPrint(op, "printU64",
1465b8880f5fSAart Bik                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1466b8880f5fSAart Bik   }
1467d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
146854759cefSAart Bik     return getPrint(op, "printF32",
14695446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1470d9b500d3SAart Bik   }
1471d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
147254759cefSAart Bik     return getPrint(op, "printF64",
14735446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1474d9b500d3SAart Bik   }
1475d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
147654759cefSAart Bik     return getPrint(op, "printOpen", {});
1477d9b500d3SAart Bik   }
1478d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
147954759cefSAart Bik     return getPrint(op, "printClose", {});
1480d9b500d3SAart Bik   }
1481d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
148254759cefSAart Bik     return getPrint(op, "printComma", {});
1483d9b500d3SAart Bik   }
1484d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
148554759cefSAart Bik     return getPrint(op, "printNewline", {});
1486d9b500d3SAart Bik   }
1487d9b500d3SAart Bik };
1488d9b500d3SAart Bik 
1489334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1490c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1491c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1492c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1493334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
149465678d93SNicolas Vasilache public:
1495b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1496b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1497b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1498b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1499b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1500b99bd771SRiver Riddle   }
150165678d93SNicolas Vasilache 
1502334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
150365678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
150465678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
150565678d93SNicolas Vasilache 
150665678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
150765678d93SNicolas Vasilache 
150865678d93SNicolas Vasilache     int64_t offset =
150965678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
151065678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
151165678d93SNicolas Vasilache     int64_t stride =
151265678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
151365678d93SNicolas Vasilache 
151465678d93SNicolas Vasilache     auto loc = op.getLoc();
151565678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
151635b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1517c3c95b9cSaartbik 
1518c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1519c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1520c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1521c3c95b9cSaartbik       offsets.reserve(size);
1522c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1523c3c95b9cSaartbik            off += stride)
1524c3c95b9cSaartbik         offsets.push_back(off);
1525c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1526c3c95b9cSaartbik                                              op.vector(),
1527c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1528c3c95b9cSaartbik       return success();
1529c3c95b9cSaartbik     }
1530c3c95b9cSaartbik 
1531c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
153265678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
153365678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
153465678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
153565678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
153665678d93SNicolas Vasilache          off += stride, ++idx) {
1537c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1538c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1539c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
154065678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
154165678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
154265678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
154365678d93SNicolas Vasilache     }
1544c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15453145427dSRiver Riddle     return success();
154665678d93SNicolas Vasilache   }
154765678d93SNicolas Vasilache };
154865678d93SNicolas Vasilache 
1549df186507SBenjamin Kramer } // namespace
1550df186507SBenjamin Kramer 
15515c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15525c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1553ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1554060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
155565678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15568345b86dSNicolas Vasilache   // clang-format off
1557681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1558681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15592d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1560c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1561ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1562ceb1b327Saartbik       ctx, converter, reassociateFPReductions);
1563060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1564060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1565060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1566060c9dd1Saartbik       ctx, converter, enableIndexOptimizations);
15678345b86dSNicolas Vasilache   patterns
1568ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15698345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15708345b86dSNicolas Vasilache               VectorExtractOpConversion,
15718345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15728345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15738345b86dSNicolas Vasilache               VectorInsertOpConversion,
15748345b86dSNicolas Vasilache               VectorPrintOpConversion,
157519dbb230Saartbik               VectorTypeCastOpConversion,
157639379916Saartbik               VectorMaskedLoadOpConversion,
157739379916Saartbik               VectorMaskedStoreOpConversion,
157819dbb230Saartbik               VectorGatherOpConversion,
1579e8dcf5f8Saartbik               VectorScatterOpConversion,
1580e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1581e8dcf5f8Saartbik               VectorCompressStoreOpConversion>(ctx, converter);
15828345b86dSNicolas Vasilache   // clang-format on
15835c0c51a9SNicolas Vasilache }
15845c0c51a9SNicolas Vasilache 
158563b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
158663b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
158763b683a8SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
158863b683a8SNicolas Vasilache   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1589c295a65dSaartbik   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
159063b683a8SNicolas Vasilache }
1591