15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
14626c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
14726c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1485f9e0466SNicolas Vasilache   if (!elementTy)
1495f9e0466SNicolas Vasilache     return failure();
1505f9e0466SNicolas Vasilache 
151b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
152b2ab375dSAlex Zinenko   // stop depending on translation.
15387a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15487a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
155b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
156168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1575f9e0466SNicolas Vasilache   return success();
1585f9e0466SNicolas Vasilache }
1595f9e0466SNicolas Vasilache 
160e8dcf5f8Saartbik // Helper that returns the base address of a memref.
161b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
162e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16319dbb230Saartbik   // Inspect stride and offset structure.
16419dbb230Saartbik   //
16519dbb230Saartbik   // TODO: flat memory only for now, generalize
16619dbb230Saartbik   //
16719dbb230Saartbik   int64_t offset;
16819dbb230Saartbik   SmallVector<int64_t, 4> strides;
16919dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
17019dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17119dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17219dbb230Saartbik     return failure();
173e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
174e8dcf5f8Saartbik   return success();
175e8dcf5f8Saartbik }
17619dbb230Saartbik 
177e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
178b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
179b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
180b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
181e8dcf5f8Saartbik   Value base;
182e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
183e8dcf5f8Saartbik     return failure();
1843a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
185e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
186e8dcf5f8Saartbik   return success();
187e8dcf5f8Saartbik }
188e8dcf5f8Saartbik 
18939379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
190b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
191b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
192b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
19339379916Saartbik   Value base;
19439379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
19539379916Saartbik     return failure();
1968de43b92SAlex Zinenko   auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
19739379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
19839379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
19939379916Saartbik   return success();
20039379916Saartbik }
20139379916Saartbik 
202e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
203e8dcf5f8Saartbik // vector.
204b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
205b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
206b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
207b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
208e8dcf5f8Saartbik   Value base;
209e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
210e8dcf5f8Saartbik     return failure();
2113a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
212*7ed9cfc7SAlex Zinenko   auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
2131485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
21419dbb230Saartbik   return success();
21519dbb230Saartbik }
21619dbb230Saartbik 
2175f9e0466SNicolas Vasilache static LogicalResult
2185f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2195f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2205f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2215f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
222affbc0cdSNicolas Vasilache   unsigned align;
22326c8f908SThomas Raoux   if (failed(getMemRefAlignment(
22426c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
225affbc0cdSNicolas Vasilache     return failure();
226affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2275f9e0466SNicolas Vasilache   return success();
2285f9e0466SNicolas Vasilache }
2295f9e0466SNicolas Vasilache 
2305f9e0466SNicolas Vasilache static LogicalResult
2315f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2325f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2335f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2345f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2355f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2365f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2375f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2385f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2395f9e0466SNicolas Vasilache 
2405f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2415f9e0466SNicolas Vasilache   if (!vecTy)
2425f9e0466SNicolas Vasilache     return failure();
2435f9e0466SNicolas Vasilache 
2445f9e0466SNicolas Vasilache   unsigned align;
24526c8f908SThomas Raoux   if (failed(getMemRefAlignment(
24626c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2475f9e0466SNicolas Vasilache     return failure();
2485f9e0466SNicolas Vasilache 
2495f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2505f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2515f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2525f9e0466SNicolas Vasilache   return success();
2535f9e0466SNicolas Vasilache }
2545f9e0466SNicolas Vasilache 
2555f9e0466SNicolas Vasilache static LogicalResult
2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2575f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2585f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2595f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
260affbc0cdSNicolas Vasilache   unsigned align;
26126c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26226c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
263affbc0cdSNicolas Vasilache     return failure();
2642d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
265affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
266affbc0cdSNicolas Vasilache                                              align);
2675f9e0466SNicolas Vasilache   return success();
2685f9e0466SNicolas Vasilache }
2695f9e0466SNicolas Vasilache 
2705f9e0466SNicolas Vasilache static LogicalResult
2715f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2725f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2735f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2745f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2755f9e0466SNicolas Vasilache   unsigned align;
27626c8f908SThomas Raoux   if (failed(getMemRefAlignment(
27726c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2785f9e0466SNicolas Vasilache     return failure();
2795f9e0466SNicolas Vasilache 
2802d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2815f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2825f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2835f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2845f9e0466SNicolas Vasilache   return success();
2855f9e0466SNicolas Vasilache }
2865f9e0466SNicolas Vasilache 
2872d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2882d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2892d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2905f9e0466SNicolas Vasilache }
2915f9e0466SNicolas Vasilache 
2922d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2932d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2942d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2955f9e0466SNicolas Vasilache }
2965f9e0466SNicolas Vasilache 
29790c01357SBenjamin Kramer namespace {
298e83b7b99Saartbik 
29963b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
30063b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
301563879b6SRahul Joshi class VectorMatmulOpConversion
302563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
30363b683a8SNicolas Vasilache public:
304563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
30563b683a8SNicolas Vasilache 
3063145427dSRiver Riddle   LogicalResult
307563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
30863b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3092d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
31063b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
311563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
312563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
313563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3143145427dSRiver Riddle     return success();
31563b683a8SNicolas Vasilache   }
31663b683a8SNicolas Vasilache };
31763b683a8SNicolas Vasilache 
318c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
319c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
320563879b6SRahul Joshi class VectorFlatTransposeOpConversion
321563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
322c295a65dSaartbik public:
323563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
324c295a65dSaartbik 
325c295a65dSaartbik   LogicalResult
326563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
327c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3282d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
329c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
330dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
331c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
332c295a65dSaartbik     return success();
333c295a65dSaartbik   }
334c295a65dSaartbik };
335c295a65dSaartbik 
33639379916Saartbik /// Conversion pattern for a vector.maskedload.
337563879b6SRahul Joshi class VectorMaskedLoadOpConversion
338563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
33939379916Saartbik public:
340563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
34139379916Saartbik 
34239379916Saartbik   LogicalResult
343563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
34439379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
345563879b6SRahul Joshi     auto loc = load->getLoc();
34639379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
34739379916Saartbik 
34839379916Saartbik     // Resolve alignment.
34939379916Saartbik     unsigned align;
35026c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
35126c8f908SThomas Raoux                                   align)))
35239379916Saartbik       return failure();
35339379916Saartbik 
354dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
35539379916Saartbik     Value ptr;
35639379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
35739379916Saartbik                           vtype, ptr)))
35839379916Saartbik       return failure();
35939379916Saartbik 
36039379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
36139379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
36239379916Saartbik         rewriter.getI32IntegerAttr(align));
36339379916Saartbik     return success();
36439379916Saartbik   }
36539379916Saartbik };
36639379916Saartbik 
36739379916Saartbik /// Conversion pattern for a vector.maskedstore.
368563879b6SRahul Joshi class VectorMaskedStoreOpConversion
369563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
37039379916Saartbik public:
371563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
37239379916Saartbik 
37339379916Saartbik   LogicalResult
374563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
37539379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
376563879b6SRahul Joshi     auto loc = store->getLoc();
37739379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
37839379916Saartbik 
37939379916Saartbik     // Resolve alignment.
38039379916Saartbik     unsigned align;
38126c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
38226c8f908SThomas Raoux                                   align)))
38339379916Saartbik       return failure();
38439379916Saartbik 
385dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
38639379916Saartbik     Value ptr;
38739379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
38839379916Saartbik                           vtype, ptr)))
38939379916Saartbik       return failure();
39039379916Saartbik 
39139379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
39239379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
39339379916Saartbik         rewriter.getI32IntegerAttr(align));
39439379916Saartbik     return success();
39539379916Saartbik   }
39639379916Saartbik };
39739379916Saartbik 
39819dbb230Saartbik /// Conversion pattern for a vector.gather.
399563879b6SRahul Joshi class VectorGatherOpConversion
400563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
40119dbb230Saartbik public:
402563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
40319dbb230Saartbik 
40419dbb230Saartbik   LogicalResult
405563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
40619dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
407563879b6SRahul Joshi     auto loc = gather->getLoc();
40819dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
40919dbb230Saartbik 
41019dbb230Saartbik     // Resolve alignment.
41119dbb230Saartbik     unsigned align;
41226c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
41326c8f908SThomas Raoux                                   align)))
41419dbb230Saartbik       return failure();
41519dbb230Saartbik 
41619dbb230Saartbik     // Get index ptrs.
41719dbb230Saartbik     VectorType vType = gather.getResultVectorType();
41819dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
41919dbb230Saartbik     Value ptrs;
420e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
421e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
42219dbb230Saartbik       return failure();
42319dbb230Saartbik 
42419dbb230Saartbik     // Replace with the gather intrinsic.
42519dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
426dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4270c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
42819dbb230Saartbik     return success();
42919dbb230Saartbik   }
43019dbb230Saartbik };
43119dbb230Saartbik 
43219dbb230Saartbik /// Conversion pattern for a vector.scatter.
433563879b6SRahul Joshi class VectorScatterOpConversion
434563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
43519dbb230Saartbik public:
436563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
43719dbb230Saartbik 
43819dbb230Saartbik   LogicalResult
439563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
44019dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
441563879b6SRahul Joshi     auto loc = scatter->getLoc();
44219dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
44319dbb230Saartbik 
44419dbb230Saartbik     // Resolve alignment.
44519dbb230Saartbik     unsigned align;
44626c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
44726c8f908SThomas Raoux                                   align)))
44819dbb230Saartbik       return failure();
44919dbb230Saartbik 
45019dbb230Saartbik     // Get index ptrs.
45119dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
45219dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
45319dbb230Saartbik     Value ptrs;
454e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
455e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
45619dbb230Saartbik       return failure();
45719dbb230Saartbik 
45819dbb230Saartbik     // Replace with the scatter intrinsic.
45919dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
46019dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
46119dbb230Saartbik         rewriter.getI32IntegerAttr(align));
46219dbb230Saartbik     return success();
46319dbb230Saartbik   }
46419dbb230Saartbik };
46519dbb230Saartbik 
466e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
467563879b6SRahul Joshi class VectorExpandLoadOpConversion
468563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
469e8dcf5f8Saartbik public:
470563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
471e8dcf5f8Saartbik 
472e8dcf5f8Saartbik   LogicalResult
473563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
474e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
475563879b6SRahul Joshi     auto loc = expand->getLoc();
476e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
477e8dcf5f8Saartbik 
478e8dcf5f8Saartbik     Value ptr;
479e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
480e8dcf5f8Saartbik                           ptr)))
481e8dcf5f8Saartbik       return failure();
482e8dcf5f8Saartbik 
483e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
484e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
485563879b6SRahul Joshi         expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
486e8dcf5f8Saartbik         adaptor.pass_thru());
487e8dcf5f8Saartbik     return success();
488e8dcf5f8Saartbik   }
489e8dcf5f8Saartbik };
490e8dcf5f8Saartbik 
491e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
492563879b6SRahul Joshi class VectorCompressStoreOpConversion
493563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
494e8dcf5f8Saartbik public:
495563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
496e8dcf5f8Saartbik 
497e8dcf5f8Saartbik   LogicalResult
498563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
499e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
500563879b6SRahul Joshi     auto loc = compress->getLoc();
501e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
502e8dcf5f8Saartbik 
503e8dcf5f8Saartbik     Value ptr;
504e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
505e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
506e8dcf5f8Saartbik       return failure();
507e8dcf5f8Saartbik 
508e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
509563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
510e8dcf5f8Saartbik     return success();
511e8dcf5f8Saartbik   }
512e8dcf5f8Saartbik };
513e8dcf5f8Saartbik 
51419dbb230Saartbik /// Conversion pattern for all vector reductions.
515563879b6SRahul Joshi class VectorReductionOpConversion
516563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
517e83b7b99Saartbik public:
518563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
519060c9dd1Saartbik                                        bool reassociateFPRed)
520563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
521060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
522e83b7b99Saartbik 
5233145427dSRiver Riddle   LogicalResult
524563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
525e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
526e83b7b99Saartbik     auto kind = reductionOp.kind();
527e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
528dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
529e9628955SAart Bik     if (eltType.isIntOrIndex()) {
530e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
531e83b7b99Saartbik       if (kind == "add")
532322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
533563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
534e83b7b99Saartbik       else if (kind == "mul")
535322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
536563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
537e9628955SAart Bik       else if (kind == "min" &&
538e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
539322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
540563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
541e83b7b99Saartbik       else if (kind == "min")
542322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
543563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
544e9628955SAart Bik       else if (kind == "max" &&
545e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
546322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
547563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
548e83b7b99Saartbik       else if (kind == "max")
549322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
550563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
551e83b7b99Saartbik       else if (kind == "and")
552322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
553563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
554e83b7b99Saartbik       else if (kind == "or")
555322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
556563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
557e83b7b99Saartbik       else if (kind == "xor")
558322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
559563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
560e83b7b99Saartbik       else
5613145427dSRiver Riddle         return failure();
5623145427dSRiver Riddle       return success();
563dcec2ca5SChristian Sigg     }
564e83b7b99Saartbik 
565dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
566dcec2ca5SChristian Sigg       return failure();
567dcec2ca5SChristian Sigg 
568e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
569e83b7b99Saartbik     if (kind == "add") {
5700d924700Saartbik       // Optional accumulator (or zero).
5710d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5720d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
573563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5740d924700Saartbik                                             rewriter.getZeroAttr(eltType));
575322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
576563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
577ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
578e83b7b99Saartbik     } else if (kind == "mul") {
5790d924700Saartbik       // Optional accumulator (or one).
5800d924700Saartbik       Value acc = operands.size() > 1
5810d924700Saartbik                       ? operands[1]
5820d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
583563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
5840d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
585322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
586563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
587ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
588e83b7b99Saartbik     } else if (kind == "min")
589563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
590563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
591e83b7b99Saartbik     else if (kind == "max")
592563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
593563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
594e83b7b99Saartbik     else
5953145427dSRiver Riddle       return failure();
5963145427dSRiver Riddle     return success();
597e83b7b99Saartbik   }
598ceb1b327Saartbik 
599ceb1b327Saartbik private:
600ceb1b327Saartbik   const bool reassociateFPReductions;
601e83b7b99Saartbik };
602e83b7b99Saartbik 
603060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
604563879b6SRahul Joshi class VectorCreateMaskOpConversion
605563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
606060c9dd1Saartbik public:
607563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
608060c9dd1Saartbik                                         bool enableIndexOpt)
609563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
610060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
611060c9dd1Saartbik 
612060c9dd1Saartbik   LogicalResult
613563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
614060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
615060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
616060c9dd1Saartbik     int64_t rank = dstType.getRank();
617060c9dd1Saartbik     if (rank == 1) {
618060c9dd1Saartbik       rewriter.replaceOp(
619060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
620060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
621060c9dd1Saartbik       return success();
622060c9dd1Saartbik     }
623060c9dd1Saartbik     return failure();
624060c9dd1Saartbik   }
625060c9dd1Saartbik 
626060c9dd1Saartbik private:
627060c9dd1Saartbik   const bool enableIndexOptimizations;
628060c9dd1Saartbik };
629060c9dd1Saartbik 
630563879b6SRahul Joshi class VectorShuffleOpConversion
631563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6321c81adf3SAart Bik public:
633563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6341c81adf3SAart Bik 
6353145427dSRiver Riddle   LogicalResult
636563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6371c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
638563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6392d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6401c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6411c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6421c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
643dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6441c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6451c81adf3SAart Bik 
6461c81adf3SAart Bik     // Bail if result type cannot be lowered.
6471c81adf3SAart Bik     if (!llvmType)
6483145427dSRiver Riddle       return failure();
6491c81adf3SAart Bik 
6501c81adf3SAart Bik     // Get rank and dimension sizes.
6511c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6521c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6531c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6541c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6551c81adf3SAart Bik 
6561c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6571c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6581c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
659563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6601c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
661563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6623145427dSRiver Riddle       return success();
663b36aaeafSAart Bik     }
664b36aaeafSAart Bik 
6651c81adf3SAart Bik     // For all other cases, insert the individual values individually.
666e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6671c81adf3SAart Bik     int64_t insPos = 0;
6681c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6691c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
670e62a6956SRiver Riddle       Value value = adaptor.v1();
6711c81adf3SAart Bik       if (extPos >= v1Dim) {
6721c81adf3SAart Bik         extPos -= v1Dim;
6731c81adf3SAart Bik         value = adaptor.v2();
674b36aaeafSAart Bik       }
675dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
676dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
677dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
6780f04384dSAlex Zinenko                          llvmType, rank, insPos++);
6791c81adf3SAart Bik     }
680563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
6813145427dSRiver Riddle     return success();
682b36aaeafSAart Bik   }
683b36aaeafSAart Bik };
684b36aaeafSAart Bik 
685563879b6SRahul Joshi class VectorExtractElementOpConversion
686563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
687cd5dab8aSAart Bik public:
688563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
689563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
690cd5dab8aSAart Bik 
6913145427dSRiver Riddle   LogicalResult
692563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
693563879b6SRahul Joshi                   ArrayRef<Value> operands,
694cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
6952d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
696cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
697dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
698cd5dab8aSAart Bik 
699cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
700cd5dab8aSAart Bik     if (!llvmType)
7013145427dSRiver Riddle       return failure();
702cd5dab8aSAart Bik 
703cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
704563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
7053145427dSRiver Riddle     return success();
706cd5dab8aSAart Bik   }
707cd5dab8aSAart Bik };
708cd5dab8aSAart Bik 
709563879b6SRahul Joshi class VectorExtractOpConversion
710563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7115c0c51a9SNicolas Vasilache public:
712563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7135c0c51a9SNicolas Vasilache 
7143145427dSRiver Riddle   LogicalResult
715563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7165c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
717563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7182d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7199826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7202bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
721dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7225c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7239826fe5cSAart Bik 
7249826fe5cSAart Bik     // Bail if result type cannot be lowered.
7259826fe5cSAart Bik     if (!llvmResultType)
7263145427dSRiver Riddle       return failure();
7279826fe5cSAart Bik 
7285c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7295c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
730e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7315c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
732563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7333145427dSRiver Riddle       return success();
7345c0c51a9SNicolas Vasilache     }
7355c0c51a9SNicolas Vasilache 
7369826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
737563879b6SRahul Joshi     auto *context = extractOp->getContext();
738e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7395c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7405c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7419826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7425c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7435c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7445c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
745dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7465c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7475c0c51a9SNicolas Vasilache     }
7485c0c51a9SNicolas Vasilache 
7495c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7505c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
751*7ed9cfc7SAlex Zinenko     auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
7521d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7535c0c51a9SNicolas Vasilache     extracted =
7545c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
755563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7565c0c51a9SNicolas Vasilache 
7573145427dSRiver Riddle     return success();
7585c0c51a9SNicolas Vasilache   }
7595c0c51a9SNicolas Vasilache };
7605c0c51a9SNicolas Vasilache 
761681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
762681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
763681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
764681f929fSNicolas Vasilache ///
765681f929fSNicolas Vasilache /// Example:
766681f929fSNicolas Vasilache /// ```
767681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
768681f929fSNicolas Vasilache /// ```
769681f929fSNicolas Vasilache /// is converted to:
770681f929fSNicolas Vasilache /// ```
7713bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
772681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
773681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
774681f929fSNicolas Vasilache /// ```
775563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
776681f929fSNicolas Vasilache public:
777563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
778681f929fSNicolas Vasilache 
7793145427dSRiver Riddle   LogicalResult
780563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
781681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7822d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
783681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
784681f929fSNicolas Vasilache     if (vType.getRank() != 1)
7853145427dSRiver Riddle       return failure();
786563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
7873bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
7883145427dSRiver Riddle     return success();
789681f929fSNicolas Vasilache   }
790681f929fSNicolas Vasilache };
791681f929fSNicolas Vasilache 
792563879b6SRahul Joshi class VectorInsertElementOpConversion
793563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
794cd5dab8aSAart Bik public:
795563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
796cd5dab8aSAart Bik 
7973145427dSRiver Riddle   LogicalResult
798563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
799cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8002d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
801cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
802dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
803cd5dab8aSAart Bik 
804cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
805cd5dab8aSAart Bik     if (!llvmType)
8063145427dSRiver Riddle       return failure();
807cd5dab8aSAart Bik 
808cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
809563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
810563879b6SRahul Joshi         adaptor.position());
8113145427dSRiver Riddle     return success();
812cd5dab8aSAart Bik   }
813cd5dab8aSAart Bik };
814cd5dab8aSAart Bik 
815563879b6SRahul Joshi class VectorInsertOpConversion
816563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8179826fe5cSAart Bik public:
818563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8199826fe5cSAart Bik 
8203145427dSRiver Riddle   LogicalResult
821563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8229826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
823563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8242d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8259826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8269826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
827dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8289826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8299826fe5cSAart Bik 
8309826fe5cSAart Bik     // Bail if result type cannot be lowered.
8319826fe5cSAart Bik     if (!llvmResultType)
8323145427dSRiver Riddle       return failure();
8339826fe5cSAart Bik 
8349826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8359826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
836e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8379826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8389826fe5cSAart Bik           positionArrayAttr);
839563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8403145427dSRiver Riddle       return success();
8419826fe5cSAart Bik     }
8429826fe5cSAart Bik 
8439826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
844563879b6SRahul Joshi     auto *context = insertOp->getContext();
845e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8469826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8479826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8489826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8499826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8509826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8519826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8529826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8539826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
854dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8559826fe5cSAart Bik           nMinusOnePositionAttrs);
8569826fe5cSAart Bik     }
8579826fe5cSAart Bik 
8589826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
859*7ed9cfc7SAlex Zinenko     auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
8601d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
861e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
862dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8630f04384dSAlex Zinenko         adaptor.source(), constant);
8649826fe5cSAart Bik 
8659826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8669826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8679826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8689826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8699826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8709826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8719826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8729826fe5cSAart Bik     }
8739826fe5cSAart Bik 
874563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8753145427dSRiver Riddle     return success();
8769826fe5cSAart Bik   }
8779826fe5cSAart Bik };
8789826fe5cSAart Bik 
879681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
880681f929fSNicolas Vasilache ///
881681f929fSNicolas Vasilache /// Example:
882681f929fSNicolas Vasilache /// ```
883681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
884681f929fSNicolas Vasilache /// ```
885681f929fSNicolas Vasilache /// is rewritten into:
886681f929fSNicolas Vasilache /// ```
887681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
888681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
889681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
890681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
891681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
892681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
893681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
894681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
895681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
896681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
897681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
898681f929fSNicolas Vasilache ///  // %r3 holds the final value.
899681f929fSNicolas Vasilache /// ```
900681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
901681f929fSNicolas Vasilache public:
902681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
903681f929fSNicolas Vasilache 
9043145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
905681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
906681f929fSNicolas Vasilache     auto vType = op.getVectorType();
907681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9083145427dSRiver Riddle       return failure();
909681f929fSNicolas Vasilache 
910681f929fSNicolas Vasilache     auto loc = op.getLoc();
911681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
912681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
913681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
914681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
915681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
916681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
917681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
918681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
919681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
920681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
921681f929fSNicolas Vasilache     }
922681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9233145427dSRiver Riddle     return success();
924681f929fSNicolas Vasilache   }
925681f929fSNicolas Vasilache };
926681f929fSNicolas Vasilache 
9272d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9282d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9292d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9302d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9312d515e49SNicolas Vasilache // rank.
9322d515e49SNicolas Vasilache //
9332d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9342d515e49SNicolas Vasilache // have different ranks. In this case:
9352d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9362d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9372d515e49SNicolas Vasilache //   destination subvector
9382d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9392d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9402d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9412d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9422d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9432d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9442d515e49SNicolas Vasilache public:
9452d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9462d515e49SNicolas Vasilache 
9473145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9482d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9492d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9502d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9512d515e49SNicolas Vasilache 
9522d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9533145427dSRiver Riddle       return failure();
9542d515e49SNicolas Vasilache 
9552d515e49SNicolas Vasilache     auto loc = op.getLoc();
9562d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9572d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9582d515e49SNicolas Vasilache     if (rankDiff == 0)
9593145427dSRiver Riddle       return failure();
9602d515e49SNicolas Vasilache 
9612d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9622d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9632d515e49SNicolas Vasilache     // on it.
9642d515e49SNicolas Vasilache     Value extracted =
9652d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9662d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
967dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9682d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9692d515e49SNicolas Vasilache     // ranks.
9702d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9712d515e49SNicolas Vasilache         loc, op.source(), extracted,
9722d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
973c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9742d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9752d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9762d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
977dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
9783145427dSRiver Riddle     return success();
9792d515e49SNicolas Vasilache   }
9802d515e49SNicolas Vasilache };
9812d515e49SNicolas Vasilache 
9822d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9832d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
9842d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9852d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9862d515e49SNicolas Vasilache //   destination subvector
9872d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9882d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9892d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9902d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9912d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
9922d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9932d515e49SNicolas Vasilache public:
994b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
995b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
996b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
997b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
998b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
999b99bd771SRiver Riddle   }
10002d515e49SNicolas Vasilache 
10013145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10022d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10032d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10042d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10052d515e49SNicolas Vasilache 
10062d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10073145427dSRiver Riddle       return failure();
10082d515e49SNicolas Vasilache 
10092d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10102d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10112d515e49SNicolas Vasilache     if (rankDiff != 0)
10123145427dSRiver Riddle       return failure();
10132d515e49SNicolas Vasilache 
10142d515e49SNicolas Vasilache     if (srcType == dstType) {
10152d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10163145427dSRiver Riddle       return success();
10172d515e49SNicolas Vasilache     }
10182d515e49SNicolas Vasilache 
10192d515e49SNicolas Vasilache     int64_t offset =
10202d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10212d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10222d515e49SNicolas Vasilache     int64_t stride =
10232d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10242d515e49SNicolas Vasilache 
10252d515e49SNicolas Vasilache     auto loc = op.getLoc();
10262d515e49SNicolas Vasilache     Value res = op.dest();
10272d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10282d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10292d515e49SNicolas Vasilache          off += stride, ++idx) {
10302d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10312d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10322d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10332d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10342d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10352d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10362d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10372d515e49SNicolas Vasilache         // smaller rank.
1038bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10392d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10402d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10412d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10422d515e49SNicolas Vasilache       }
10432d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10442d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10452d515e49SNicolas Vasilache     }
10462d515e49SNicolas Vasilache 
10472d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10483145427dSRiver Riddle     return success();
10492d515e49SNicolas Vasilache   }
10502d515e49SNicolas Vasilache };
10512d515e49SNicolas Vasilache 
105230e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
105330e6033bSNicolas Vasilache /// static layout.
105430e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
105530e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10562bf491c7SBenjamin Kramer   int64_t offset;
105730e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
105830e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
105930e6033bSNicolas Vasilache     return None;
106030e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
106130e6033bSNicolas Vasilache     return None;
106230e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
106330e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
106430e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
106530e6033bSNicolas Vasilache     return strides;
106630e6033bSNicolas Vasilache 
106730e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
106830e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
106930e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
107030e6033bSNicolas Vasilache   // layout.
10712bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10722bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
107330e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
107430e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
107530e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
107630e6033bSNicolas Vasilache       return None;
107730e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
107830e6033bSNicolas Vasilache       return None;
10792bf491c7SBenjamin Kramer   }
108030e6033bSNicolas Vasilache   return strides;
10812bf491c7SBenjamin Kramer }
10822bf491c7SBenjamin Kramer 
1083563879b6SRahul Joshi class VectorTypeCastOpConversion
1084563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
10855c0c51a9SNicolas Vasilache public:
1086563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
10875c0c51a9SNicolas Vasilache 
10883145427dSRiver Riddle   LogicalResult
1089563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
10905c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1091563879b6SRahul Joshi     auto loc = castOp->getLoc();
10925c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
10932bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
10945c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
10952bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
10965c0c51a9SNicolas Vasilache 
10975c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
10985c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
10995c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11003145427dSRiver Riddle       return failure();
11015c0c51a9SNicolas Vasilache 
11025c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11038de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
11048de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
11053145427dSRiver Riddle       return failure();
11065c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11075c0c51a9SNicolas Vasilache 
1108dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11098de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
11108de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
11113145427dSRiver Riddle       return failure();
11125c0c51a9SNicolas Vasilache 
111330e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
111430e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
111530e6033bSNicolas Vasilache     if (!sourceStrides)
111630e6033bSNicolas Vasilache       return failure();
111730e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
111830e6033bSNicolas Vasilache     if (!targetStrides)
111930e6033bSNicolas Vasilache       return failure();
112030e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
112130e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
112230e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
112330e6033bSNicolas Vasilache         }))
11243145427dSRiver Riddle       return failure();
11255c0c51a9SNicolas Vasilache 
1126*7ed9cfc7SAlex Zinenko     auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64);
11275c0c51a9SNicolas Vasilache 
11285c0c51a9SNicolas Vasilache     // Create descriptor.
11295c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11303a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11315c0c51a9SNicolas Vasilache     // Set allocated ptr.
1132e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11335c0c51a9SNicolas Vasilache     allocated =
11345c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11355c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11365c0c51a9SNicolas Vasilache     // Set aligned ptr.
1137e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11385c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11395c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11405c0c51a9SNicolas Vasilache     // Fill offset 0.
11415c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11425c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11435c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11445c0c51a9SNicolas Vasilache 
11455c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11465c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11475c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11485c0c51a9SNicolas Vasilache       auto sizeAttr =
11495c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11505c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11515c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
115230e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
115330e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11545c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11555c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11565c0c51a9SNicolas Vasilache     }
11575c0c51a9SNicolas Vasilache 
1158563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11593145427dSRiver Riddle     return success();
11605c0c51a9SNicolas Vasilache   }
11615c0c51a9SNicolas Vasilache };
11625c0c51a9SNicolas Vasilache 
11638345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11648345b86dSNicolas Vasilache /// sequence of:
1165060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1166060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1167060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1168060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1169060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11708345b86dSNicolas Vasilache template <typename ConcreteOp>
1171563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11728345b86dSNicolas Vasilache public:
1173563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1174060c9dd1Saartbik                                     bool enableIndexOpt)
1175563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1176060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11778345b86dSNicolas Vasilache 
11788345b86dSNicolas Vasilache   LogicalResult
1179563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
11808345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11818345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1182b2c79c50SNicolas Vasilache 
1183b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1184b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
11858345b86dSNicolas Vasilache       return failure();
11865f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
11875f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
11885f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1189563879b6SRahul Joshi                                        xferOp->getContext()))
11908345b86dSNicolas Vasilache       return failure();
119126c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
119226c8f908SThomas Raoux     if (!memRefType)
119326c8f908SThomas Raoux       return failure();
11942bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
119526c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
119630e6033bSNicolas Vasilache     if (!strides)
11972bf491c7SBenjamin Kramer       return failure();
11988345b86dSNicolas Vasilache 
1199563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1200563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1201563879b6SRahul Joshi     };
12028345b86dSNicolas Vasilache 
1203563879b6SRahul Joshi     Location loc = xferOp->getLoc();
12048345b86dSNicolas Vasilache 
120568330ee0SThomas Raoux     if (auto memrefVectorElementType =
120626c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
120768330ee0SThomas Raoux       // Memref has vector element type.
120868330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
120968330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
121068330ee0SThomas Raoux         return failure();
12110de60b55SThomas Raoux #ifndef NDEBUG
121268330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
121368330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
121468330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
121568330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
121668330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
121768330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
121868330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
121968330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
122068330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
122168330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
122268330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
122368330ee0SThomas Raoux                "result shape.");
12240de60b55SThomas Raoux #endif // ifndef NDEBUG
122568330ee0SThomas Raoux     }
122668330ee0SThomas Raoux 
12278345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1228be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1229be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1230be16075bSWen-Heng (Jack) Chung     //    address space 0.
12318345b86dSNicolas Vasilache     // TODO: support alignment when possible.
1232563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
123326c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
12348de43b92SAlex Zinenko     auto vecTy = toLLVMTy(xferOp.getVectorType())
12358de43b92SAlex Zinenko                      .template cast<LLVM::LLVMFixedVectorType>();
1236be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1237be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
12388de43b92SAlex Zinenko       vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
12398de43b92SAlex Zinenko           loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
1240be16075bSWen-Heng (Jack) Chung     else
1241be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
12428de43b92SAlex Zinenko           loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
12438345b86dSNicolas Vasilache 
12441870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1245563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1246563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1247563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12481870e787SNicolas Vasilache 
12498345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12508345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12518345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12528345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1253060c9dd1Saartbik     //
1254060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1255060c9dd1Saartbik     //       dimensions here.
12568de43b92SAlex Zinenko     unsigned vecWidth = vecTy.getNumElements();
1257060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12580c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
125926c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1260563879b6SRahul Joshi     Value mask = buildVectorComparison(
1261563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12628345b86dSNicolas Vasilache 
12638345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1264563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1265dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12668345b86dSNicolas Vasilache   }
1267060c9dd1Saartbik 
1268060c9dd1Saartbik private:
1269060c9dd1Saartbik   const bool enableIndexOptimizations;
12708345b86dSNicolas Vasilache };
12718345b86dSNicolas Vasilache 
1272563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1273d9b500d3SAart Bik public:
1274563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1275d9b500d3SAart Bik 
1276d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1277d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1278d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1279d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1280d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1281d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1282d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1283d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1284d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1285d9b500d3SAart Bik   //
12869db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1287d9b500d3SAart Bik   //
12883145427dSRiver Riddle   LogicalResult
1289563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1290d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
12912d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1292d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1293d9b500d3SAart Bik 
1294dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
12953145427dSRiver Riddle       return failure();
1296d9b500d3SAart Bik 
1297b8880f5fSAart Bik     // Make sure element type has runtime support.
1298b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1299d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1300d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1301d9b500d3SAart Bik     Operation *printer;
1302b8880f5fSAart Bik     if (eltType.isF32()) {
1303563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1304b8880f5fSAart Bik     } else if (eltType.isF64()) {
1305563879b6SRahul Joshi       printer = getPrintDouble(printOp);
130654759cefSAart Bik     } else if (eltType.isIndex()) {
1307563879b6SRahul Joshi       printer = getPrintU64(printOp);
1308b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1309b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1310b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1311b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1312b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1313b8880f5fSAart Bik       if (intTy.isUnsigned()) {
131454759cefSAart Bik         if (width <= 64) {
1315b8880f5fSAart Bik           if (width < 64)
1316b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1317563879b6SRahul Joshi           printer = getPrintU64(printOp);
1318b8880f5fSAart Bik         } else {
13193145427dSRiver Riddle           return failure();
1320b8880f5fSAart Bik         }
1321b8880f5fSAart Bik       } else {
1322b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
132354759cefSAart Bik         if (width <= 64) {
1324b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1325b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1326b8880f5fSAart Bik           if (width == 1)
132754759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
132854759cefSAart Bik           else if (width < 64)
1329b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1330563879b6SRahul Joshi           printer = getPrintI64(printOp);
1331b8880f5fSAart Bik         } else {
1332b8880f5fSAart Bik           return failure();
1333b8880f5fSAart Bik         }
1334b8880f5fSAart Bik       }
1335b8880f5fSAart Bik     } else {
1336b8880f5fSAart Bik       return failure();
1337b8880f5fSAart Bik     }
1338d9b500d3SAart Bik 
1339d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1340b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1341563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1342b8880f5fSAart Bik               conversion);
1343563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1344563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13453145427dSRiver Riddle     return success();
1346d9b500d3SAart Bik   }
1347d9b500d3SAart Bik 
1348d9b500d3SAart Bik private:
1349b8880f5fSAart Bik   enum class PrintConversion {
135030e6033bSNicolas Vasilache     // clang-format off
1351b8880f5fSAart Bik     None,
1352b8880f5fSAart Bik     ZeroExt64,
1353b8880f5fSAart Bik     SignExt64
135430e6033bSNicolas Vasilache     // clang-format on
1355b8880f5fSAart Bik   };
1356b8880f5fSAart Bik 
1357d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1358e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1359b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1360d9b500d3SAart Bik     Location loc = op->getLoc();
1361d9b500d3SAart Bik     if (rank == 0) {
1362b8880f5fSAart Bik       switch (conversion) {
1363b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1364b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1365*7ed9cfc7SAlex Zinenko             loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
1366b8880f5fSAart Bik         break;
1367b8880f5fSAart Bik       case PrintConversion::SignExt64:
1368b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1369*7ed9cfc7SAlex Zinenko             loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64));
1370b8880f5fSAart Bik         break;
1371b8880f5fSAart Bik       case PrintConversion::None:
1372b8880f5fSAart Bik         break;
1373c9eeeb38Saartbik       }
1374d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1375d9b500d3SAart Bik       return;
1376d9b500d3SAart Bik     }
1377d9b500d3SAart Bik 
1378d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1379d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1380d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1381d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1382d9b500d3SAart Bik       auto reducedType =
1383d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1384dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1385d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1386dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1387dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1388b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1389b8880f5fSAart Bik                 conversion);
1390d9b500d3SAart Bik       if (d != dim - 1)
1391d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1392d9b500d3SAart Bik     }
1393d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1394d9b500d3SAart Bik   }
1395d9b500d3SAart Bik 
1396d9b500d3SAart Bik   // Helper to emit a call.
1397d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1398d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
139908e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1400d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1401d9b500d3SAart Bik   }
1402d9b500d3SAart Bik 
1403d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14045446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
14055446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1406d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1407d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1408d9b500d3SAart Bik     if (func)
1409d9b500d3SAart Bik       return func;
1410d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1411d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1412d9b500d3SAart Bik         op->getLoc(), name,
1413*7ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
1414*7ed9cfc7SAlex Zinenko                                     params));
1415d9b500d3SAart Bik   }
1416d9b500d3SAart Bik 
1417d9b500d3SAart Bik   // Helpers for method names.
1418e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
141954759cefSAart Bik     return getPrint(op, "printI64",
1420*7ed9cfc7SAlex Zinenko                     LLVM::LLVMIntegerType::get(op->getContext(), 64));
1421e52414b1Saartbik   }
1422b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1423b8880f5fSAart Bik     return getPrint(op, "printU64",
1424*7ed9cfc7SAlex Zinenko                     LLVM::LLVMIntegerType::get(op->getContext(), 64));
1425b8880f5fSAart Bik   }
1426d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
1427*7ed9cfc7SAlex Zinenko     return getPrint(op, "printF32", LLVM::LLVMFloatType::get(op->getContext()));
1428d9b500d3SAart Bik   }
1429d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
143054759cefSAart Bik     return getPrint(op, "printF64",
1431*7ed9cfc7SAlex Zinenko                     LLVM::LLVMDoubleType::get(op->getContext()));
1432d9b500d3SAart Bik   }
1433d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
143454759cefSAart Bik     return getPrint(op, "printOpen", {});
1435d9b500d3SAart Bik   }
1436d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
143754759cefSAart Bik     return getPrint(op, "printClose", {});
1438d9b500d3SAart Bik   }
1439d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
144054759cefSAart Bik     return getPrint(op, "printComma", {});
1441d9b500d3SAart Bik   }
1442d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
144354759cefSAart Bik     return getPrint(op, "printNewline", {});
1444d9b500d3SAart Bik   }
1445d9b500d3SAart Bik };
1446d9b500d3SAart Bik 
1447334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1448c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1449c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1450c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1451334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
145265678d93SNicolas Vasilache public:
1453b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1454b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1455b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1456b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1457b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1458b99bd771SRiver Riddle   }
145965678d93SNicolas Vasilache 
1460334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
146165678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
146265678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
146365678d93SNicolas Vasilache 
146465678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
146565678d93SNicolas Vasilache 
146665678d93SNicolas Vasilache     int64_t offset =
146765678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
146865678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
146965678d93SNicolas Vasilache     int64_t stride =
147065678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
147165678d93SNicolas Vasilache 
147265678d93SNicolas Vasilache     auto loc = op.getLoc();
147365678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
147435b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1475c3c95b9cSaartbik 
1476c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1477c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1478c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1479c3c95b9cSaartbik       offsets.reserve(size);
1480c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1481c3c95b9cSaartbik            off += stride)
1482c3c95b9cSaartbik         offsets.push_back(off);
1483c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1484c3c95b9cSaartbik                                              op.vector(),
1485c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1486c3c95b9cSaartbik       return success();
1487c3c95b9cSaartbik     }
1488c3c95b9cSaartbik 
1489c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
149065678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
149165678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
149265678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
149365678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
149465678d93SNicolas Vasilache          off += stride, ++idx) {
1495c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1496c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1497c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
149865678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
149965678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
150065678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
150165678d93SNicolas Vasilache     }
1502c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15033145427dSRiver Riddle     return success();
150465678d93SNicolas Vasilache   }
150565678d93SNicolas Vasilache };
150665678d93SNicolas Vasilache 
1507df186507SBenjamin Kramer } // namespace
1508df186507SBenjamin Kramer 
15095c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15105c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1511ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1512060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
151365678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15148345b86dSNicolas Vasilache   // clang-format off
1515681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1516681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15172d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1518c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1519ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1520563879b6SRahul Joshi       converter, reassociateFPReductions);
1521060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1522060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1523060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1524563879b6SRahul Joshi       converter, enableIndexOptimizations);
15258345b86dSNicolas Vasilache   patterns
1526ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15278345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15288345b86dSNicolas Vasilache               VectorExtractOpConversion,
15298345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15308345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15318345b86dSNicolas Vasilache               VectorInsertOpConversion,
15328345b86dSNicolas Vasilache               VectorPrintOpConversion,
153319dbb230Saartbik               VectorTypeCastOpConversion,
153439379916Saartbik               VectorMaskedLoadOpConversion,
153539379916Saartbik               VectorMaskedStoreOpConversion,
153619dbb230Saartbik               VectorGatherOpConversion,
1537e8dcf5f8Saartbik               VectorScatterOpConversion,
1538e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1539563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15408345b86dSNicolas Vasilache   // clang-format on
15415c0c51a9SNicolas Vasilache }
15425c0c51a9SNicolas Vasilache 
154363b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
154463b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1545563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1546563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
154763b683a8SNicolas Vasilache }
1548