15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
14519dbb230Saartbik template <typename T>
14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
14719dbb230Saartbik                                  unsigned &align) {
1485f9e0466SNicolas Vasilache   Type elementTy =
14919dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1505f9e0466SNicolas Vasilache   if (!elementTy)
1515f9e0466SNicolas Vasilache     return failure();
1525f9e0466SNicolas Vasilache 
153b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
154b2ab375dSAlex Zinenko   // stop depending on translation.
15587a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15687a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
157b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
158168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1595f9e0466SNicolas Vasilache   return success();
1605f9e0466SNicolas Vasilache }
1615f9e0466SNicolas Vasilache 
162e8dcf5f8Saartbik // Helper that returns the base address of a memref.
163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
164e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16519dbb230Saartbik   // Inspect stride and offset structure.
16619dbb230Saartbik   //
16719dbb230Saartbik   // TODO: flat memory only for now, generalize
16819dbb230Saartbik   //
16919dbb230Saartbik   int64_t offset;
17019dbb230Saartbik   SmallVector<int64_t, 4> strides;
17119dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
17219dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17319dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17419dbb230Saartbik     return failure();
175e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
176e8dcf5f8Saartbik   return success();
177e8dcf5f8Saartbik }
17819dbb230Saartbik 
179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
181b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
182b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
183e8dcf5f8Saartbik   Value base;
184e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
185e8dcf5f8Saartbik     return failure();
1863a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
187e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
188e8dcf5f8Saartbik   return success();
189e8dcf5f8Saartbik }
190e8dcf5f8Saartbik 
19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
193b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
194b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
19539379916Saartbik   Value base;
19639379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
19739379916Saartbik     return failure();
19839379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
19939379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
20039379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
20139379916Saartbik   return success();
20239379916Saartbik }
20339379916Saartbik 
204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
205e8dcf5f8Saartbik // vector.
206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
207b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
208b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
209b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
210e8dcf5f8Saartbik   Value base;
211e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
212e8dcf5f8Saartbik     return failure();
2133a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
214e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2151485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
21619dbb230Saartbik   return success();
21719dbb230Saartbik }
21819dbb230Saartbik 
2195f9e0466SNicolas Vasilache static LogicalResult
2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2215f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2225f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2235f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
224affbc0cdSNicolas Vasilache   unsigned align;
22519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
226affbc0cdSNicolas Vasilache     return failure();
227affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2285f9e0466SNicolas Vasilache   return success();
2295f9e0466SNicolas Vasilache }
2305f9e0466SNicolas Vasilache 
2315f9e0466SNicolas Vasilache static LogicalResult
2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2335f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2345f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2355f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2365f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2375f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2385f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2395f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2405f9e0466SNicolas Vasilache 
2415f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2425f9e0466SNicolas Vasilache   if (!vecTy)
2435f9e0466SNicolas Vasilache     return failure();
2445f9e0466SNicolas Vasilache 
2455f9e0466SNicolas Vasilache   unsigned align;
24619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2475f9e0466SNicolas Vasilache     return failure();
2485f9e0466SNicolas Vasilache 
2495f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2505f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2515f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2525f9e0466SNicolas Vasilache   return success();
2535f9e0466SNicolas Vasilache }
2545f9e0466SNicolas Vasilache 
2555f9e0466SNicolas Vasilache static LogicalResult
2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2575f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2585f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2595f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
260affbc0cdSNicolas Vasilache   unsigned align;
26119dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
262affbc0cdSNicolas Vasilache     return failure();
2632d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
264affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
265affbc0cdSNicolas Vasilache                                              align);
2665f9e0466SNicolas Vasilache   return success();
2675f9e0466SNicolas Vasilache }
2685f9e0466SNicolas Vasilache 
2695f9e0466SNicolas Vasilache static LogicalResult
2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2715f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2725f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2735f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2745f9e0466SNicolas Vasilache   unsigned align;
27519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2765f9e0466SNicolas Vasilache     return failure();
2775f9e0466SNicolas Vasilache 
2782d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2795f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2805f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2815f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2825f9e0466SNicolas Vasilache   return success();
2835f9e0466SNicolas Vasilache }
2845f9e0466SNicolas Vasilache 
2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2862d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2872d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2885f9e0466SNicolas Vasilache }
2895f9e0466SNicolas Vasilache 
2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2912d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2922d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2935f9e0466SNicolas Vasilache }
2945f9e0466SNicolas Vasilache 
29590c01357SBenjamin Kramer namespace {
296e83b7b99Saartbik 
29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
299*563879b6SRahul Joshi class VectorMatmulOpConversion
300*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
30163b683a8SNicolas Vasilache public:
302*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
30363b683a8SNicolas Vasilache 
3043145427dSRiver Riddle   LogicalResult
305*563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
30663b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3072d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
30863b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
309*563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
310*563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
311*563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3123145427dSRiver Riddle     return success();
31363b683a8SNicolas Vasilache   }
31463b683a8SNicolas Vasilache };
31563b683a8SNicolas Vasilache 
316c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
317c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
318*563879b6SRahul Joshi class VectorFlatTransposeOpConversion
319*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
320c295a65dSaartbik public:
321*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
322c295a65dSaartbik 
323c295a65dSaartbik   LogicalResult
324*563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
325c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3262d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
327c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
328dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
329c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
330c295a65dSaartbik     return success();
331c295a65dSaartbik   }
332c295a65dSaartbik };
333c295a65dSaartbik 
33439379916Saartbik /// Conversion pattern for a vector.maskedload.
335*563879b6SRahul Joshi class VectorMaskedLoadOpConversion
336*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
33739379916Saartbik public:
338*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
33939379916Saartbik 
34039379916Saartbik   LogicalResult
341*563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
34239379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
343*563879b6SRahul Joshi     auto loc = load->getLoc();
34439379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
34539379916Saartbik 
34639379916Saartbik     // Resolve alignment.
34739379916Saartbik     unsigned align;
348dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
34939379916Saartbik       return failure();
35039379916Saartbik 
351dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
35239379916Saartbik     Value ptr;
35339379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
35439379916Saartbik                           vtype, ptr)))
35539379916Saartbik       return failure();
35639379916Saartbik 
35739379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
35839379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
35939379916Saartbik         rewriter.getI32IntegerAttr(align));
36039379916Saartbik     return success();
36139379916Saartbik   }
36239379916Saartbik };
36339379916Saartbik 
36439379916Saartbik /// Conversion pattern for a vector.maskedstore.
365*563879b6SRahul Joshi class VectorMaskedStoreOpConversion
366*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
36739379916Saartbik public:
368*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
36939379916Saartbik 
37039379916Saartbik   LogicalResult
371*563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
37239379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
373*563879b6SRahul Joshi     auto loc = store->getLoc();
37439379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
37539379916Saartbik 
37639379916Saartbik     // Resolve alignment.
37739379916Saartbik     unsigned align;
378dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
37939379916Saartbik       return failure();
38039379916Saartbik 
381dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
38239379916Saartbik     Value ptr;
38339379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
38439379916Saartbik                           vtype, ptr)))
38539379916Saartbik       return failure();
38639379916Saartbik 
38739379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
38839379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
38939379916Saartbik         rewriter.getI32IntegerAttr(align));
39039379916Saartbik     return success();
39139379916Saartbik   }
39239379916Saartbik };
39339379916Saartbik 
39419dbb230Saartbik /// Conversion pattern for a vector.gather.
395*563879b6SRahul Joshi class VectorGatherOpConversion
396*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
39719dbb230Saartbik public:
398*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
39919dbb230Saartbik 
40019dbb230Saartbik   LogicalResult
401*563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
40219dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
403*563879b6SRahul Joshi     auto loc = gather->getLoc();
40419dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
40519dbb230Saartbik 
40619dbb230Saartbik     // Resolve alignment.
40719dbb230Saartbik     unsigned align;
408dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
40919dbb230Saartbik       return failure();
41019dbb230Saartbik 
41119dbb230Saartbik     // Get index ptrs.
41219dbb230Saartbik     VectorType vType = gather.getResultVectorType();
41319dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
41419dbb230Saartbik     Value ptrs;
415e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
416e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
41719dbb230Saartbik       return failure();
41819dbb230Saartbik 
41919dbb230Saartbik     // Replace with the gather intrinsic.
42019dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
421dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4220c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
42319dbb230Saartbik     return success();
42419dbb230Saartbik   }
42519dbb230Saartbik };
42619dbb230Saartbik 
42719dbb230Saartbik /// Conversion pattern for a vector.scatter.
428*563879b6SRahul Joshi class VectorScatterOpConversion
429*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
43019dbb230Saartbik public:
431*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
43219dbb230Saartbik 
43319dbb230Saartbik   LogicalResult
434*563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
43519dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
436*563879b6SRahul Joshi     auto loc = scatter->getLoc();
43719dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
43819dbb230Saartbik 
43919dbb230Saartbik     // Resolve alignment.
44019dbb230Saartbik     unsigned align;
441dcec2ca5SChristian Sigg     if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
44219dbb230Saartbik       return failure();
44319dbb230Saartbik 
44419dbb230Saartbik     // Get index ptrs.
44519dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
44619dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
44719dbb230Saartbik     Value ptrs;
448e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
449e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
45019dbb230Saartbik       return failure();
45119dbb230Saartbik 
45219dbb230Saartbik     // Replace with the scatter intrinsic.
45319dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
45419dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
45519dbb230Saartbik         rewriter.getI32IntegerAttr(align));
45619dbb230Saartbik     return success();
45719dbb230Saartbik   }
45819dbb230Saartbik };
45919dbb230Saartbik 
460e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
461*563879b6SRahul Joshi class VectorExpandLoadOpConversion
462*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
463e8dcf5f8Saartbik public:
464*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
465e8dcf5f8Saartbik 
466e8dcf5f8Saartbik   LogicalResult
467*563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
468e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
469*563879b6SRahul Joshi     auto loc = expand->getLoc();
470e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
471e8dcf5f8Saartbik 
472e8dcf5f8Saartbik     Value ptr;
473e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
474e8dcf5f8Saartbik                           ptr)))
475e8dcf5f8Saartbik       return failure();
476e8dcf5f8Saartbik 
477e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
478e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
479*563879b6SRahul Joshi         expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
480e8dcf5f8Saartbik         adaptor.pass_thru());
481e8dcf5f8Saartbik     return success();
482e8dcf5f8Saartbik   }
483e8dcf5f8Saartbik };
484e8dcf5f8Saartbik 
485e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
486*563879b6SRahul Joshi class VectorCompressStoreOpConversion
487*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
488e8dcf5f8Saartbik public:
489*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
490e8dcf5f8Saartbik 
491e8dcf5f8Saartbik   LogicalResult
492*563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
493e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
494*563879b6SRahul Joshi     auto loc = compress->getLoc();
495e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
496e8dcf5f8Saartbik 
497e8dcf5f8Saartbik     Value ptr;
498e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
499e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
500e8dcf5f8Saartbik       return failure();
501e8dcf5f8Saartbik 
502e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
503*563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
504e8dcf5f8Saartbik     return success();
505e8dcf5f8Saartbik   }
506e8dcf5f8Saartbik };
507e8dcf5f8Saartbik 
50819dbb230Saartbik /// Conversion pattern for all vector reductions.
509*563879b6SRahul Joshi class VectorReductionOpConversion
510*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
511e83b7b99Saartbik public:
512*563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
513060c9dd1Saartbik                                        bool reassociateFPRed)
514*563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
515060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
516e83b7b99Saartbik 
5173145427dSRiver Riddle   LogicalResult
518*563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
519e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
520e83b7b99Saartbik     auto kind = reductionOp.kind();
521e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
522dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
523e9628955SAart Bik     if (eltType.isIntOrIndex()) {
524e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
525e83b7b99Saartbik       if (kind == "add")
526322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
527*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
528e83b7b99Saartbik       else if (kind == "mul")
529322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
530*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
531e9628955SAart Bik       else if (kind == "min" &&
532e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
533322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
534*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
535e83b7b99Saartbik       else if (kind == "min")
536322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
537*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
538e9628955SAart Bik       else if (kind == "max" &&
539e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
540322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
541*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
542e83b7b99Saartbik       else if (kind == "max")
543322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
544*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
545e83b7b99Saartbik       else if (kind == "and")
546322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
547*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
548e83b7b99Saartbik       else if (kind == "or")
549322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
550*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
551e83b7b99Saartbik       else if (kind == "xor")
552322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
553*563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
554e83b7b99Saartbik       else
5553145427dSRiver Riddle         return failure();
5563145427dSRiver Riddle       return success();
557dcec2ca5SChristian Sigg     }
558e83b7b99Saartbik 
559dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
560dcec2ca5SChristian Sigg       return failure();
561dcec2ca5SChristian Sigg 
562e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
563e83b7b99Saartbik     if (kind == "add") {
5640d924700Saartbik       // Optional accumulator (or zero).
5650d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5660d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
567*563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5680d924700Saartbik                                             rewriter.getZeroAttr(eltType));
569322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
570*563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
571ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
572e83b7b99Saartbik     } else if (kind == "mul") {
5730d924700Saartbik       // Optional accumulator (or one).
5740d924700Saartbik       Value acc = operands.size() > 1
5750d924700Saartbik                       ? operands[1]
5760d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
577*563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
5780d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
579322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
580*563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
581ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
582e83b7b99Saartbik     } else if (kind == "min")
583*563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
584*563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
585e83b7b99Saartbik     else if (kind == "max")
586*563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
587*563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
588e83b7b99Saartbik     else
5893145427dSRiver Riddle       return failure();
5903145427dSRiver Riddle     return success();
591e83b7b99Saartbik   }
592ceb1b327Saartbik 
593ceb1b327Saartbik private:
594ceb1b327Saartbik   const bool reassociateFPReductions;
595e83b7b99Saartbik };
596e83b7b99Saartbik 
597060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
598*563879b6SRahul Joshi class VectorCreateMaskOpConversion
599*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
600060c9dd1Saartbik public:
601*563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
602060c9dd1Saartbik                                         bool enableIndexOpt)
603*563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
604060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
605060c9dd1Saartbik 
606060c9dd1Saartbik   LogicalResult
607*563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
608060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
609060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
610060c9dd1Saartbik     int64_t rank = dstType.getRank();
611060c9dd1Saartbik     if (rank == 1) {
612060c9dd1Saartbik       rewriter.replaceOp(
613060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
614060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
615060c9dd1Saartbik       return success();
616060c9dd1Saartbik     }
617060c9dd1Saartbik     return failure();
618060c9dd1Saartbik   }
619060c9dd1Saartbik 
620060c9dd1Saartbik private:
621060c9dd1Saartbik   const bool enableIndexOptimizations;
622060c9dd1Saartbik };
623060c9dd1Saartbik 
624*563879b6SRahul Joshi class VectorShuffleOpConversion
625*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6261c81adf3SAart Bik public:
627*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6281c81adf3SAart Bik 
6293145427dSRiver Riddle   LogicalResult
630*563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6311c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
632*563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6332d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6341c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6351c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6361c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
637dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6381c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6391c81adf3SAart Bik 
6401c81adf3SAart Bik     // Bail if result type cannot be lowered.
6411c81adf3SAart Bik     if (!llvmType)
6423145427dSRiver Riddle       return failure();
6431c81adf3SAart Bik 
6441c81adf3SAart Bik     // Get rank and dimension sizes.
6451c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6461c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6471c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6481c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6491c81adf3SAart Bik 
6501c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6511c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6521c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
653*563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6541c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
655*563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6563145427dSRiver Riddle       return success();
657b36aaeafSAart Bik     }
658b36aaeafSAart Bik 
6591c81adf3SAart Bik     // For all other cases, insert the individual values individually.
660e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6611c81adf3SAart Bik     int64_t insPos = 0;
6621c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6631c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
664e62a6956SRiver Riddle       Value value = adaptor.v1();
6651c81adf3SAart Bik       if (extPos >= v1Dim) {
6661c81adf3SAart Bik         extPos -= v1Dim;
6671c81adf3SAart Bik         value = adaptor.v2();
668b36aaeafSAart Bik       }
669dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
670dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
671dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
6720f04384dSAlex Zinenko                          llvmType, rank, insPos++);
6731c81adf3SAart Bik     }
674*563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
6753145427dSRiver Riddle     return success();
676b36aaeafSAart Bik   }
677b36aaeafSAart Bik };
678b36aaeafSAart Bik 
679*563879b6SRahul Joshi class VectorExtractElementOpConversion
680*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
681cd5dab8aSAart Bik public:
682*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
683*563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
684cd5dab8aSAart Bik 
6853145427dSRiver Riddle   LogicalResult
686*563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
687*563879b6SRahul Joshi                   ArrayRef<Value> operands,
688cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
6892d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
690cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
691dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
692cd5dab8aSAart Bik 
693cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
694cd5dab8aSAart Bik     if (!llvmType)
6953145427dSRiver Riddle       return failure();
696cd5dab8aSAart Bik 
697cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
698*563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
6993145427dSRiver Riddle     return success();
700cd5dab8aSAart Bik   }
701cd5dab8aSAart Bik };
702cd5dab8aSAart Bik 
703*563879b6SRahul Joshi class VectorExtractOpConversion
704*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7055c0c51a9SNicolas Vasilache public:
706*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7075c0c51a9SNicolas Vasilache 
7083145427dSRiver Riddle   LogicalResult
709*563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7105c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
711*563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7122d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7139826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7142bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
715dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7165c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7179826fe5cSAart Bik 
7189826fe5cSAart Bik     // Bail if result type cannot be lowered.
7199826fe5cSAart Bik     if (!llvmResultType)
7203145427dSRiver Riddle       return failure();
7219826fe5cSAart Bik 
7225c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7235c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
724e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7255c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
726*563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7273145427dSRiver Riddle       return success();
7285c0c51a9SNicolas Vasilache     }
7295c0c51a9SNicolas Vasilache 
7309826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
731*563879b6SRahul Joshi     auto *context = extractOp->getContext();
732e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7335c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7345c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7359826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7365c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7375c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7385c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
739dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7405c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7415c0c51a9SNicolas Vasilache     }
7425c0c51a9SNicolas Vasilache 
7435c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7445c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7455446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7461d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7475c0c51a9SNicolas Vasilache     extracted =
7485c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
749*563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7505c0c51a9SNicolas Vasilache 
7513145427dSRiver Riddle     return success();
7525c0c51a9SNicolas Vasilache   }
7535c0c51a9SNicolas Vasilache };
7545c0c51a9SNicolas Vasilache 
755681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
756681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
757681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
758681f929fSNicolas Vasilache ///
759681f929fSNicolas Vasilache /// Example:
760681f929fSNicolas Vasilache /// ```
761681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
762681f929fSNicolas Vasilache /// ```
763681f929fSNicolas Vasilache /// is converted to:
764681f929fSNicolas Vasilache /// ```
7653bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
766681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
767681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
768681f929fSNicolas Vasilache /// ```
769*563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
770681f929fSNicolas Vasilache public:
771*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
772681f929fSNicolas Vasilache 
7733145427dSRiver Riddle   LogicalResult
774*563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
775681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7762d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
777681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
778681f929fSNicolas Vasilache     if (vType.getRank() != 1)
7793145427dSRiver Riddle       return failure();
780*563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
7813bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
7823145427dSRiver Riddle     return success();
783681f929fSNicolas Vasilache   }
784681f929fSNicolas Vasilache };
785681f929fSNicolas Vasilache 
786*563879b6SRahul Joshi class VectorInsertElementOpConversion
787*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
788cd5dab8aSAart Bik public:
789*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
790cd5dab8aSAart Bik 
7913145427dSRiver Riddle   LogicalResult
792*563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
793cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7942d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
795cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
796dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
797cd5dab8aSAart Bik 
798cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
799cd5dab8aSAart Bik     if (!llvmType)
8003145427dSRiver Riddle       return failure();
801cd5dab8aSAart Bik 
802cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
803*563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
804*563879b6SRahul Joshi         adaptor.position());
8053145427dSRiver Riddle     return success();
806cd5dab8aSAart Bik   }
807cd5dab8aSAart Bik };
808cd5dab8aSAart Bik 
809*563879b6SRahul Joshi class VectorInsertOpConversion
810*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8119826fe5cSAart Bik public:
812*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8139826fe5cSAart Bik 
8143145427dSRiver Riddle   LogicalResult
815*563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8169826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
817*563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8182d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8199826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8209826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
821dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8229826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8239826fe5cSAart Bik 
8249826fe5cSAart Bik     // Bail if result type cannot be lowered.
8259826fe5cSAart Bik     if (!llvmResultType)
8263145427dSRiver Riddle       return failure();
8279826fe5cSAart Bik 
8289826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8299826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
830e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8319826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8329826fe5cSAart Bik           positionArrayAttr);
833*563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8343145427dSRiver Riddle       return success();
8359826fe5cSAart Bik     }
8369826fe5cSAart Bik 
8379826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
838*563879b6SRahul Joshi     auto *context = insertOp->getContext();
839e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8409826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8419826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8429826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8439826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8449826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8459826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8469826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8479826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
848dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8499826fe5cSAart Bik           nMinusOnePositionAttrs);
8509826fe5cSAart Bik     }
8519826fe5cSAart Bik 
8529826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8535446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
8541d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
855e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
856dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8570f04384dSAlex Zinenko         adaptor.source(), constant);
8589826fe5cSAart Bik 
8599826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8609826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8619826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8629826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8639826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8649826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8659826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8669826fe5cSAart Bik     }
8679826fe5cSAart Bik 
868*563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8693145427dSRiver Riddle     return success();
8709826fe5cSAart Bik   }
8719826fe5cSAart Bik };
8729826fe5cSAart Bik 
873681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
874681f929fSNicolas Vasilache ///
875681f929fSNicolas Vasilache /// Example:
876681f929fSNicolas Vasilache /// ```
877681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
878681f929fSNicolas Vasilache /// ```
879681f929fSNicolas Vasilache /// is rewritten into:
880681f929fSNicolas Vasilache /// ```
881681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
882681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
883681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
884681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
885681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
886681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
887681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
888681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
889681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
890681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
891681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
892681f929fSNicolas Vasilache ///  // %r3 holds the final value.
893681f929fSNicolas Vasilache /// ```
894681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
895681f929fSNicolas Vasilache public:
896681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
897681f929fSNicolas Vasilache 
8983145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
899681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
900681f929fSNicolas Vasilache     auto vType = op.getVectorType();
901681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9023145427dSRiver Riddle       return failure();
903681f929fSNicolas Vasilache 
904681f929fSNicolas Vasilache     auto loc = op.getLoc();
905681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
906681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
907681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
908681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
909681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
910681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
911681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
912681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
913681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
914681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
915681f929fSNicolas Vasilache     }
916681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9173145427dSRiver Riddle     return success();
918681f929fSNicolas Vasilache   }
919681f929fSNicolas Vasilache };
920681f929fSNicolas Vasilache 
9212d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9222d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9232d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9242d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9252d515e49SNicolas Vasilache // rank.
9262d515e49SNicolas Vasilache //
9272d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9282d515e49SNicolas Vasilache // have different ranks. In this case:
9292d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9302d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9312d515e49SNicolas Vasilache //   destination subvector
9322d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9332d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9342d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9352d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9362d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9372d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9382d515e49SNicolas Vasilache public:
9392d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9402d515e49SNicolas Vasilache 
9413145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9422d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9432d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9442d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9452d515e49SNicolas Vasilache 
9462d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9473145427dSRiver Riddle       return failure();
9482d515e49SNicolas Vasilache 
9492d515e49SNicolas Vasilache     auto loc = op.getLoc();
9502d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9512d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9522d515e49SNicolas Vasilache     if (rankDiff == 0)
9533145427dSRiver Riddle       return failure();
9542d515e49SNicolas Vasilache 
9552d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9562d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9572d515e49SNicolas Vasilache     // on it.
9582d515e49SNicolas Vasilache     Value extracted =
9592d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9602d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
961dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9622d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9632d515e49SNicolas Vasilache     // ranks.
9642d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9652d515e49SNicolas Vasilache         loc, op.source(), extracted,
9662d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
967c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9682d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9692d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9702d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
971dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
9723145427dSRiver Riddle     return success();
9732d515e49SNicolas Vasilache   }
9742d515e49SNicolas Vasilache };
9752d515e49SNicolas Vasilache 
9762d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9772d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
9782d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9792d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9802d515e49SNicolas Vasilache //   destination subvector
9812d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9822d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9832d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9842d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9852d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
9862d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9872d515e49SNicolas Vasilache public:
988b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
989b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
990b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
991b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
992b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
993b99bd771SRiver Riddle   }
9942d515e49SNicolas Vasilache 
9953145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9962d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9972d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9982d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9992d515e49SNicolas Vasilache 
10002d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10013145427dSRiver Riddle       return failure();
10022d515e49SNicolas Vasilache 
10032d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10042d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10052d515e49SNicolas Vasilache     if (rankDiff != 0)
10063145427dSRiver Riddle       return failure();
10072d515e49SNicolas Vasilache 
10082d515e49SNicolas Vasilache     if (srcType == dstType) {
10092d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10103145427dSRiver Riddle       return success();
10112d515e49SNicolas Vasilache     }
10122d515e49SNicolas Vasilache 
10132d515e49SNicolas Vasilache     int64_t offset =
10142d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10152d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10162d515e49SNicolas Vasilache     int64_t stride =
10172d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10182d515e49SNicolas Vasilache 
10192d515e49SNicolas Vasilache     auto loc = op.getLoc();
10202d515e49SNicolas Vasilache     Value res = op.dest();
10212d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10222d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10232d515e49SNicolas Vasilache          off += stride, ++idx) {
10242d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10252d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10262d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10272d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10282d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10292d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10302d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10312d515e49SNicolas Vasilache         // smaller rank.
1032bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10332d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10342d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10352d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10362d515e49SNicolas Vasilache       }
10372d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10382d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10392d515e49SNicolas Vasilache     }
10402d515e49SNicolas Vasilache 
10412d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10423145427dSRiver Riddle     return success();
10432d515e49SNicolas Vasilache   }
10442d515e49SNicolas Vasilache };
10452d515e49SNicolas Vasilache 
104630e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
104730e6033bSNicolas Vasilache /// static layout.
104830e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
104930e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10502bf491c7SBenjamin Kramer   int64_t offset;
105130e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
105230e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
105330e6033bSNicolas Vasilache     return None;
105430e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
105530e6033bSNicolas Vasilache     return None;
105630e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
105730e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
105830e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
105930e6033bSNicolas Vasilache     return strides;
106030e6033bSNicolas Vasilache 
106130e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
106230e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
106330e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
106430e6033bSNicolas Vasilache   // layout.
10652bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10662bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
106730e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
106830e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
106930e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
107030e6033bSNicolas Vasilache       return None;
107130e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
107230e6033bSNicolas Vasilache       return None;
10732bf491c7SBenjamin Kramer   }
107430e6033bSNicolas Vasilache   return strides;
10752bf491c7SBenjamin Kramer }
10762bf491c7SBenjamin Kramer 
1077*563879b6SRahul Joshi class VectorTypeCastOpConversion
1078*563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
10795c0c51a9SNicolas Vasilache public:
1080*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
10815c0c51a9SNicolas Vasilache 
10823145427dSRiver Riddle   LogicalResult
1083*563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
10845c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1085*563879b6SRahul Joshi     auto loc = castOp->getLoc();
10865c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
10872bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
10885c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
10892bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
10905c0c51a9SNicolas Vasilache 
10915c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
10925c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
10935c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
10943145427dSRiver Riddle       return failure();
10955c0c51a9SNicolas Vasilache 
10965c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
10972bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
10985c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
10993145427dSRiver Riddle       return failure();
11005c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11015c0c51a9SNicolas Vasilache 
1102dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11035c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11045c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11053145427dSRiver Riddle       return failure();
11065c0c51a9SNicolas Vasilache 
110730e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
110830e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
110930e6033bSNicolas Vasilache     if (!sourceStrides)
111030e6033bSNicolas Vasilache       return failure();
111130e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
111230e6033bSNicolas Vasilache     if (!targetStrides)
111330e6033bSNicolas Vasilache       return failure();
111430e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
111530e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
111630e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
111730e6033bSNicolas Vasilache         }))
11183145427dSRiver Riddle       return failure();
11195c0c51a9SNicolas Vasilache 
11205446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11215c0c51a9SNicolas Vasilache 
11225c0c51a9SNicolas Vasilache     // Create descriptor.
11235c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11243a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11255c0c51a9SNicolas Vasilache     // Set allocated ptr.
1126e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11275c0c51a9SNicolas Vasilache     allocated =
11285c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11295c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11305c0c51a9SNicolas Vasilache     // Set aligned ptr.
1131e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11325c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11335c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11345c0c51a9SNicolas Vasilache     // Fill offset 0.
11355c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11365c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11375c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11385c0c51a9SNicolas Vasilache 
11395c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11405c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11415c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11425c0c51a9SNicolas Vasilache       auto sizeAttr =
11435c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11445c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11455c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
114630e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
114730e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11485c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11495c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11505c0c51a9SNicolas Vasilache     }
11515c0c51a9SNicolas Vasilache 
1152*563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11533145427dSRiver Riddle     return success();
11545c0c51a9SNicolas Vasilache   }
11555c0c51a9SNicolas Vasilache };
11565c0c51a9SNicolas Vasilache 
11578345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11588345b86dSNicolas Vasilache /// sequence of:
1159060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1160060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1161060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1162060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1163060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11648345b86dSNicolas Vasilache template <typename ConcreteOp>
1165*563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11668345b86dSNicolas Vasilache public:
1167*563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1168060c9dd1Saartbik                                     bool enableIndexOpt)
1169*563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1170060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11718345b86dSNicolas Vasilache 
11728345b86dSNicolas Vasilache   LogicalResult
1173*563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
11748345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11758345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1176b2c79c50SNicolas Vasilache 
1177b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1178b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
11798345b86dSNicolas Vasilache       return failure();
11805f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
11815f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
11825f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1183*563879b6SRahul Joshi                                        xferOp->getContext()))
11848345b86dSNicolas Vasilache       return failure();
11852bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
118630e6033bSNicolas Vasilache     auto strides = computeContiguousStrides(xferOp.getMemRefType());
118730e6033bSNicolas Vasilache     if (!strides)
11882bf491c7SBenjamin Kramer       return failure();
11898345b86dSNicolas Vasilache 
1190*563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1191*563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1192*563879b6SRahul Joshi     };
11938345b86dSNicolas Vasilache 
1194*563879b6SRahul Joshi     Location loc = xferOp->getLoc();
11958345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
11968345b86dSNicolas Vasilache 
119768330ee0SThomas Raoux     if (auto memrefVectorElementType =
119868330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
119968330ee0SThomas Raoux       // Memref has vector element type.
120068330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
120168330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
120268330ee0SThomas Raoux         return failure();
12030de60b55SThomas Raoux #ifndef NDEBUG
120468330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
120568330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
120668330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
120768330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
120868330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
120968330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
121068330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
121168330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
121268330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
121368330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
121468330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
121568330ee0SThomas Raoux                "result shape.");
12160de60b55SThomas Raoux #endif // ifndef NDEBUG
121768330ee0SThomas Raoux     }
121868330ee0SThomas Raoux 
12198345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1220be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1221be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1222be16075bSWen-Heng (Jack) Chung     //    address space 0.
12238345b86dSNicolas Vasilache     // TODO: support alignment when possible.
1224*563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
1225*563879b6SRahul Joshi         loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
12268345b86dSNicolas Vasilache     auto vecTy =
12278345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1228be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1229be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1230be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12318345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1232be16075bSWen-Heng (Jack) Chung     else
1233be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1234be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12358345b86dSNicolas Vasilache 
12361870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1237*563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1238*563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1239*563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12401870e787SNicolas Vasilache 
12418345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12428345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12438345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12448345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1245060c9dd1Saartbik     //
1246060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1247060c9dd1Saartbik     //       dimensions here.
1248060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1249060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12500c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
1251b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1252*563879b6SRahul Joshi     Value mask = buildVectorComparison(
1253*563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12548345b86dSNicolas Vasilache 
12558345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1256*563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1257dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12588345b86dSNicolas Vasilache   }
1259060c9dd1Saartbik 
1260060c9dd1Saartbik private:
1261060c9dd1Saartbik   const bool enableIndexOptimizations;
12628345b86dSNicolas Vasilache };
12638345b86dSNicolas Vasilache 
1264*563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1265d9b500d3SAart Bik public:
1266*563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1267d9b500d3SAart Bik 
1268d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1269d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1270d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1271d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1272d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1273d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1274d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1275d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1276d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1277d9b500d3SAart Bik   //
12789db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1279d9b500d3SAart Bik   //
12803145427dSRiver Riddle   LogicalResult
1281*563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1282d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
12832d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1284d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1285d9b500d3SAart Bik 
1286dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
12873145427dSRiver Riddle       return failure();
1288d9b500d3SAart Bik 
1289b8880f5fSAart Bik     // Make sure element type has runtime support.
1290b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1291d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1292d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1293d9b500d3SAart Bik     Operation *printer;
1294b8880f5fSAart Bik     if (eltType.isF32()) {
1295*563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1296b8880f5fSAart Bik     } else if (eltType.isF64()) {
1297*563879b6SRahul Joshi       printer = getPrintDouble(printOp);
129854759cefSAart Bik     } else if (eltType.isIndex()) {
1299*563879b6SRahul Joshi       printer = getPrintU64(printOp);
1300b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1301b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1302b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1303b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1304b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1305b8880f5fSAart Bik       if (intTy.isUnsigned()) {
130654759cefSAart Bik         if (width <= 64) {
1307b8880f5fSAart Bik           if (width < 64)
1308b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1309*563879b6SRahul Joshi           printer = getPrintU64(printOp);
1310b8880f5fSAart Bik         } else {
13113145427dSRiver Riddle           return failure();
1312b8880f5fSAart Bik         }
1313b8880f5fSAart Bik       } else {
1314b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
131554759cefSAart Bik         if (width <= 64) {
1316b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1317b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1318b8880f5fSAart Bik           if (width == 1)
131954759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
132054759cefSAart Bik           else if (width < 64)
1321b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1322*563879b6SRahul Joshi           printer = getPrintI64(printOp);
1323b8880f5fSAart Bik         } else {
1324b8880f5fSAart Bik           return failure();
1325b8880f5fSAart Bik         }
1326b8880f5fSAart Bik       }
1327b8880f5fSAart Bik     } else {
1328b8880f5fSAart Bik       return failure();
1329b8880f5fSAart Bik     }
1330d9b500d3SAart Bik 
1331d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1332b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1333*563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1334b8880f5fSAart Bik               conversion);
1335*563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1336*563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13373145427dSRiver Riddle     return success();
1338d9b500d3SAart Bik   }
1339d9b500d3SAart Bik 
1340d9b500d3SAart Bik private:
1341b8880f5fSAart Bik   enum class PrintConversion {
134230e6033bSNicolas Vasilache     // clang-format off
1343b8880f5fSAart Bik     None,
1344b8880f5fSAart Bik     ZeroExt64,
1345b8880f5fSAart Bik     SignExt64
134630e6033bSNicolas Vasilache     // clang-format on
1347b8880f5fSAart Bik   };
1348b8880f5fSAart Bik 
1349d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1350e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1351b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1352d9b500d3SAart Bik     Location loc = op->getLoc();
1353d9b500d3SAart Bik     if (rank == 0) {
1354b8880f5fSAart Bik       switch (conversion) {
1355b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1356b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1357b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1358b8880f5fSAart Bik         break;
1359b8880f5fSAart Bik       case PrintConversion::SignExt64:
1360b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1361b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1362b8880f5fSAart Bik         break;
1363b8880f5fSAart Bik       case PrintConversion::None:
1364b8880f5fSAart Bik         break;
1365c9eeeb38Saartbik       }
1366d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1367d9b500d3SAart Bik       return;
1368d9b500d3SAart Bik     }
1369d9b500d3SAart Bik 
1370d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1371d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1372d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1373d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1374d9b500d3SAart Bik       auto reducedType =
1375d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1376dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1377d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1378dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1379dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1380b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1381b8880f5fSAart Bik                 conversion);
1382d9b500d3SAart Bik       if (d != dim - 1)
1383d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1384d9b500d3SAart Bik     }
1385d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1386d9b500d3SAart Bik   }
1387d9b500d3SAart Bik 
1388d9b500d3SAart Bik   // Helper to emit a call.
1389d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1390d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
139108e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1392d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1393d9b500d3SAart Bik   }
1394d9b500d3SAart Bik 
1395d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
13965446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
13975446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1398d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1399d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1400d9b500d3SAart Bik     if (func)
1401d9b500d3SAart Bik       return func;
1402d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1403d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1404d9b500d3SAart Bik         op->getLoc(), name,
14055446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14065446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14075446ec85SAlex Zinenko             /*isVarArg=*/false));
1408d9b500d3SAart Bik   }
1409d9b500d3SAart Bik 
1410d9b500d3SAart Bik   // Helpers for method names.
1411e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
141254759cefSAart Bik     return getPrint(op, "printI64",
14135446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1414e52414b1Saartbik   }
1415b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1416b8880f5fSAart Bik     return getPrint(op, "printU64",
1417b8880f5fSAart Bik                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1418b8880f5fSAart Bik   }
1419d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
142054759cefSAart Bik     return getPrint(op, "printF32",
14215446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1422d9b500d3SAart Bik   }
1423d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
142454759cefSAart Bik     return getPrint(op, "printF64",
14255446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1426d9b500d3SAart Bik   }
1427d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
142854759cefSAart Bik     return getPrint(op, "printOpen", {});
1429d9b500d3SAart Bik   }
1430d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
143154759cefSAart Bik     return getPrint(op, "printClose", {});
1432d9b500d3SAart Bik   }
1433d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
143454759cefSAart Bik     return getPrint(op, "printComma", {});
1435d9b500d3SAart Bik   }
1436d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
143754759cefSAart Bik     return getPrint(op, "printNewline", {});
1438d9b500d3SAart Bik   }
1439d9b500d3SAart Bik };
1440d9b500d3SAart Bik 
1441334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1442c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1443c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1444c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1445334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
144665678d93SNicolas Vasilache public:
1447b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1448b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1449b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1450b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1451b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1452b99bd771SRiver Riddle   }
145365678d93SNicolas Vasilache 
1454334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
145565678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
145665678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
145765678d93SNicolas Vasilache 
145865678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
145965678d93SNicolas Vasilache 
146065678d93SNicolas Vasilache     int64_t offset =
146165678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
146265678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
146365678d93SNicolas Vasilache     int64_t stride =
146465678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
146565678d93SNicolas Vasilache 
146665678d93SNicolas Vasilache     auto loc = op.getLoc();
146765678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
146835b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1469c3c95b9cSaartbik 
1470c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1471c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1472c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1473c3c95b9cSaartbik       offsets.reserve(size);
1474c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1475c3c95b9cSaartbik            off += stride)
1476c3c95b9cSaartbik         offsets.push_back(off);
1477c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1478c3c95b9cSaartbik                                              op.vector(),
1479c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1480c3c95b9cSaartbik       return success();
1481c3c95b9cSaartbik     }
1482c3c95b9cSaartbik 
1483c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
148465678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
148565678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
148665678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
148765678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
148865678d93SNicolas Vasilache          off += stride, ++idx) {
1489c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1490c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1491c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
149265678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
149365678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
149465678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
149565678d93SNicolas Vasilache     }
1496c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14973145427dSRiver Riddle     return success();
149865678d93SNicolas Vasilache   }
149965678d93SNicolas Vasilache };
150065678d93SNicolas Vasilache 
1501df186507SBenjamin Kramer } // namespace
1502df186507SBenjamin Kramer 
15035c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15045c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1505ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1506060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
150765678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15088345b86dSNicolas Vasilache   // clang-format off
1509681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1510681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15112d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1512c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1513ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1514*563879b6SRahul Joshi       converter, reassociateFPReductions);
1515060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1516060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1517060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1518*563879b6SRahul Joshi       converter, enableIndexOptimizations);
15198345b86dSNicolas Vasilache   patterns
1520ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15218345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15228345b86dSNicolas Vasilache               VectorExtractOpConversion,
15238345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15248345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15258345b86dSNicolas Vasilache               VectorInsertOpConversion,
15268345b86dSNicolas Vasilache               VectorPrintOpConversion,
152719dbb230Saartbik               VectorTypeCastOpConversion,
152839379916Saartbik               VectorMaskedLoadOpConversion,
152939379916Saartbik               VectorMaskedStoreOpConversion,
153019dbb230Saartbik               VectorGatherOpConversion,
1531e8dcf5f8Saartbik               VectorScatterOpConversion,
1532e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1533*563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15348345b86dSNicolas Vasilache   // clang-format on
15355c0c51a9SNicolas Vasilache }
15365c0c51a9SNicolas Vasilache 
153763b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
153863b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1539*563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1540*563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
154163b683a8SNicolas Vasilache }
1542