15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
13*e332c22cSNicolas Vasilache #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1709f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
18ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
195c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
205c0c51a9SNicolas Vasilache 
215c0c51a9SNicolas Vasilache using namespace mlir;
2265678d93SNicolas Vasilache using namespace mlir::vector;
235c0c51a9SNicolas Vasilache 
249826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
259826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
269826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
279826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
289826fe5cSAart Bik }
299826fe5cSAart Bik 
309826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
319826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
329826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
339826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
349826fe5cSAart Bik }
359826fe5cSAart Bik 
361c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
37e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
380f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
390f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
400f04384dSAlex Zinenko                        int64_t pos) {
411c81adf3SAart Bik   if (rank == 1) {
421c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
431c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
440f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
451c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
461c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
471c81adf3SAart Bik                                                   constant);
481c81adf3SAart Bik   }
491c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
501c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
511c81adf3SAart Bik }
521c81adf3SAart Bik 
532d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
542d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
552d515e49SNicolas Vasilache                        Value into, int64_t offset) {
562d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
572d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
582d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
592d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
602d515e49SNicolas Vasilache       loc, vectorType, from, into,
612d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
622d515e49SNicolas Vasilache }
632d515e49SNicolas Vasilache 
641c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
65e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
660f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
670f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
681c81adf3SAart Bik   if (rank == 1) {
691c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
701c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
710f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
721c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
731c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
741c81adf3SAart Bik                                                    constant);
751c81adf3SAart Bik   }
761c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
771c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
781c81adf3SAart Bik }
791c81adf3SAart Bik 
802d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
812d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
822d515e49SNicolas Vasilache                         int64_t offset) {
832d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
842d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
852d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
862d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
872d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
882d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
892d515e49SNicolas Vasilache }
902d515e49SNicolas Vasilache 
912d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
929db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
932d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
942d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
952d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
962d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
972d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
982d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
992d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
1002d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1012d515e49SNicolas Vasilache        it != eit; ++it)
1022d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1032d515e49SNicolas Vasilache   return res;
1042d515e49SNicolas Vasilache }
1052d515e49SNicolas Vasilache 
106ba87f991SAlex Zinenko static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
107ba87f991SAlex Zinenko                                    Location loc, Type targetType, Value value) {
108ba87f991SAlex Zinenko   if (targetType == value.getType())
109ba87f991SAlex Zinenko     return value;
110ba87f991SAlex Zinenko 
111ba87f991SAlex Zinenko   bool targetIsIndex = targetType.isIndex();
112ba87f991SAlex Zinenko   bool valueIsIndex = value.getType().isIndex();
113ba87f991SAlex Zinenko   if (targetIsIndex ^ valueIsIndex)
114ba87f991SAlex Zinenko     return rewriter.create<IndexCastOp>(loc, targetType, value);
115ba87f991SAlex Zinenko 
116ba87f991SAlex Zinenko   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
117ba87f991SAlex Zinenko   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
118ba87f991SAlex Zinenko   assert(targetIntegerType && valueIntegerType &&
119ba87f991SAlex Zinenko          "unexpected cast between types other than integers and index");
120ba87f991SAlex Zinenko   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
121ba87f991SAlex Zinenko 
122ba87f991SAlex Zinenko   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
123ba87f991SAlex Zinenko     return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
124ba87f991SAlex Zinenko   return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
125ba87f991SAlex Zinenko }
126ba87f991SAlex Zinenko 
127060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
128060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
129060c9dd1Saartbik //
130060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
131060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
132060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
133060c9dd1Saartbik //       is very conservative on the boundary conditions.
134060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
135060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
136060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
137060c9dd1Saartbik   auto loc = op->getLoc();
138060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
139060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
140060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
141060c9dd1Saartbik   Value indices;
142060c9dd1Saartbik   Type idxType;
143060c9dd1Saartbik   if (enableIndexOptimizations) {
1440c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1450c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1460c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
147060c9dd1Saartbik     idxType = rewriter.getI32Type();
148060c9dd1Saartbik   } else {
1490c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1500c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1510c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
152060c9dd1Saartbik     idxType = rewriter.getI64Type();
153060c9dd1Saartbik   }
154060c9dd1Saartbik   // Add in an offset if requested.
155060c9dd1Saartbik   if (off) {
156ba87f991SAlex Zinenko     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
157060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
158060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
159060c9dd1Saartbik   }
160060c9dd1Saartbik   // Construct the vector comparison.
161ba87f991SAlex Zinenko   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
162060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
163060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
164060c9dd1Saartbik }
165060c9dd1Saartbik 
16626c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
16726c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
16826c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
16926c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1705f9e0466SNicolas Vasilache   if (!elementTy)
1715f9e0466SNicolas Vasilache     return failure();
1725f9e0466SNicolas Vasilache 
173b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
174b2ab375dSAlex Zinenko   // stop depending on translation.
17587a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
17687a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
177c69c9e0fSAlex Zinenko               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
1785f9e0466SNicolas Vasilache   return success();
1795f9e0466SNicolas Vasilache }
1805f9e0466SNicolas Vasilache 
181e8dcf5f8Saartbik // Helper that returns the base address of a memref.
182b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
183e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
18419dbb230Saartbik   // Inspect stride and offset structure.
18519dbb230Saartbik   //
18619dbb230Saartbik   // TODO: flat memory only for now, generalize
18719dbb230Saartbik   //
18819dbb230Saartbik   int64_t offset;
18919dbb230Saartbik   SmallVector<int64_t, 4> strides;
19019dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
19119dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
19219dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
19319dbb230Saartbik     return failure();
194e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
195e8dcf5f8Saartbik   return success();
196e8dcf5f8Saartbik }
19719dbb230Saartbik 
198a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector.
199b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
200b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
201b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
202b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
203e8dcf5f8Saartbik   Value base;
204e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
205e8dcf5f8Saartbik     return failure();
2063a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
207bd30a796SAlex Zinenko   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
2081485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
20919dbb230Saartbik   return success();
21019dbb230Saartbik }
21119dbb230Saartbik 
212a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer
213a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be
214a57def30SAart Bik // used when source/dst memrefs are not on address space 0.
215a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
216a57def30SAart Bik                          Value ptr, MemRefType memRefType, Type vt) {
217bd30a796SAlex Zinenko   auto pType = LLVM::LLVMPointerType::get(vt);
218a57def30SAart Bik   if (memRefType.getMemorySpace() == 0)
219a57def30SAart Bik     return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
220a57def30SAart Bik   return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
221a57def30SAart Bik }
222a57def30SAart Bik 
2235f9e0466SNicolas Vasilache static LogicalResult
2245f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2255f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2265f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2275f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
228affbc0cdSNicolas Vasilache   unsigned align;
22926c8f908SThomas Raoux   if (failed(getMemRefAlignment(
23026c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
231affbc0cdSNicolas Vasilache     return failure();
232affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2335f9e0466SNicolas Vasilache   return success();
2345f9e0466SNicolas Vasilache }
2355f9e0466SNicolas Vasilache 
2365f9e0466SNicolas Vasilache static LogicalResult
2375f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2385f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2395f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2405f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2415f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2425f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2435f9e0466SNicolas Vasilache 
2445f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2455f9e0466SNicolas Vasilache   if (!vecTy)
2465f9e0466SNicolas Vasilache     return failure();
2475f9e0466SNicolas Vasilache 
2485f9e0466SNicolas Vasilache   unsigned align;
24926c8f908SThomas Raoux   if (failed(getMemRefAlignment(
25026c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2515f9e0466SNicolas Vasilache     return failure();
2525f9e0466SNicolas Vasilache 
2535f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2545f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2555f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2565f9e0466SNicolas Vasilache   return success();
2575f9e0466SNicolas Vasilache }
2585f9e0466SNicolas Vasilache 
2595f9e0466SNicolas Vasilache static LogicalResult
2605f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2615f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2625f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2635f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
264affbc0cdSNicolas Vasilache   unsigned align;
26526c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26626c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
267affbc0cdSNicolas Vasilache     return failure();
2682d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
269affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
270affbc0cdSNicolas Vasilache                                              align);
2715f9e0466SNicolas Vasilache   return success();
2725f9e0466SNicolas Vasilache }
2735f9e0466SNicolas Vasilache 
2745f9e0466SNicolas Vasilache static LogicalResult
2755f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2765f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2775f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2785f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2795f9e0466SNicolas Vasilache   unsigned align;
28026c8f908SThomas Raoux   if (failed(getMemRefAlignment(
28126c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2825f9e0466SNicolas Vasilache     return failure();
2835f9e0466SNicolas Vasilache 
2842d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2855f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2865f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2875f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2885f9e0466SNicolas Vasilache   return success();
2895f9e0466SNicolas Vasilache }
2905f9e0466SNicolas Vasilache 
2912d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2922d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2932d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2945f9e0466SNicolas Vasilache }
2955f9e0466SNicolas Vasilache 
2962d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2972d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2982d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2995f9e0466SNicolas Vasilache }
3005f9e0466SNicolas Vasilache 
30190c01357SBenjamin Kramer namespace {
302e83b7b99Saartbik 
303cf5c517cSDiego Caballero /// Conversion pattern for a vector.bitcast.
304cf5c517cSDiego Caballero class VectorBitCastOpConversion
305cf5c517cSDiego Caballero     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
306cf5c517cSDiego Caballero public:
307cf5c517cSDiego Caballero   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
308cf5c517cSDiego Caballero 
309cf5c517cSDiego Caballero   LogicalResult
310cf5c517cSDiego Caballero   matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
311cf5c517cSDiego Caballero                   ConversionPatternRewriter &rewriter) const override {
312cf5c517cSDiego Caballero     // Only 1-D vectors can be lowered to LLVM.
313cf5c517cSDiego Caballero     VectorType resultTy = bitCastOp.getType();
314cf5c517cSDiego Caballero     if (resultTy.getRank() != 1)
315cf5c517cSDiego Caballero       return failure();
316cf5c517cSDiego Caballero     Type newResultTy = typeConverter->convertType(resultTy);
317cf5c517cSDiego Caballero     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
318cf5c517cSDiego Caballero                                                  operands[0]);
319cf5c517cSDiego Caballero     return success();
320cf5c517cSDiego Caballero   }
321cf5c517cSDiego Caballero };
322cf5c517cSDiego Caballero 
32363b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
32463b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
325563879b6SRahul Joshi class VectorMatmulOpConversion
326563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
32763b683a8SNicolas Vasilache public:
328563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
32963b683a8SNicolas Vasilache 
3303145427dSRiver Riddle   LogicalResult
331563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
33263b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3332d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
33463b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
335563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
336563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
337563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3383145427dSRiver Riddle     return success();
33963b683a8SNicolas Vasilache   }
34063b683a8SNicolas Vasilache };
34163b683a8SNicolas Vasilache 
342c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
343c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
344563879b6SRahul Joshi class VectorFlatTransposeOpConversion
345563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
346c295a65dSaartbik public:
347563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
348c295a65dSaartbik 
349c295a65dSaartbik   LogicalResult
350563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
351c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3522d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
353c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
354dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
355c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
356c295a65dSaartbik     return success();
357c295a65dSaartbik   }
358c295a65dSaartbik };
359c295a65dSaartbik 
36039379916Saartbik /// Conversion pattern for a vector.maskedload.
361563879b6SRahul Joshi class VectorMaskedLoadOpConversion
362563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
36339379916Saartbik public:
364563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
36539379916Saartbik 
36639379916Saartbik   LogicalResult
367563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
36839379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
369563879b6SRahul Joshi     auto loc = load->getLoc();
37039379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
371a57def30SAart Bik     MemRefType memRefType = load.getMemRefType();
37239379916Saartbik 
37339379916Saartbik     // Resolve alignment.
37439379916Saartbik     unsigned align;
375a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
37639379916Saartbik       return failure();
37739379916Saartbik 
378a57def30SAart Bik     // Resolve address.
379dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
380a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
381a57def30SAart Bik                                                adaptor.indices(), rewriter);
382a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
38339379916Saartbik 
38439379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
38539379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
38639379916Saartbik         rewriter.getI32IntegerAttr(align));
38739379916Saartbik     return success();
38839379916Saartbik   }
38939379916Saartbik };
39039379916Saartbik 
39139379916Saartbik /// Conversion pattern for a vector.maskedstore.
392563879b6SRahul Joshi class VectorMaskedStoreOpConversion
393563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
39439379916Saartbik public:
395563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
39639379916Saartbik 
39739379916Saartbik   LogicalResult
398563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
39939379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
400563879b6SRahul Joshi     auto loc = store->getLoc();
40139379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
402a57def30SAart Bik     MemRefType memRefType = store.getMemRefType();
40339379916Saartbik 
40439379916Saartbik     // Resolve alignment.
40539379916Saartbik     unsigned align;
406a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
40739379916Saartbik       return failure();
40839379916Saartbik 
409a57def30SAart Bik     // Resolve address.
410dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
411a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
412a57def30SAart Bik                                                adaptor.indices(), rewriter);
413a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
41439379916Saartbik 
41539379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
41639379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
41739379916Saartbik         rewriter.getI32IntegerAttr(align));
41839379916Saartbik     return success();
41939379916Saartbik   }
42039379916Saartbik };
42139379916Saartbik 
42219dbb230Saartbik /// Conversion pattern for a vector.gather.
423563879b6SRahul Joshi class VectorGatherOpConversion
424563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
42519dbb230Saartbik public:
426563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
42719dbb230Saartbik 
42819dbb230Saartbik   LogicalResult
429563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
43019dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
431563879b6SRahul Joshi     auto loc = gather->getLoc();
43219dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
43319dbb230Saartbik 
43419dbb230Saartbik     // Resolve alignment.
43519dbb230Saartbik     unsigned align;
43626c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
43726c8f908SThomas Raoux                                   align)))
43819dbb230Saartbik       return failure();
43919dbb230Saartbik 
44019dbb230Saartbik     // Get index ptrs.
44119dbb230Saartbik     VectorType vType = gather.getResultVectorType();
44219dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
44319dbb230Saartbik     Value ptrs;
444e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
445e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
44619dbb230Saartbik       return failure();
44719dbb230Saartbik 
44819dbb230Saartbik     // Replace with the gather intrinsic.
44919dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
450dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4510c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
45219dbb230Saartbik     return success();
45319dbb230Saartbik   }
45419dbb230Saartbik };
45519dbb230Saartbik 
45619dbb230Saartbik /// Conversion pattern for a vector.scatter.
457563879b6SRahul Joshi class VectorScatterOpConversion
458563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
45919dbb230Saartbik public:
460563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
46119dbb230Saartbik 
46219dbb230Saartbik   LogicalResult
463563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
46419dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
465563879b6SRahul Joshi     auto loc = scatter->getLoc();
46619dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
46719dbb230Saartbik 
46819dbb230Saartbik     // Resolve alignment.
46919dbb230Saartbik     unsigned align;
47026c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
47126c8f908SThomas Raoux                                   align)))
47219dbb230Saartbik       return failure();
47319dbb230Saartbik 
47419dbb230Saartbik     // Get index ptrs.
47519dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
47619dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
47719dbb230Saartbik     Value ptrs;
478e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
479e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
48019dbb230Saartbik       return failure();
48119dbb230Saartbik 
48219dbb230Saartbik     // Replace with the scatter intrinsic.
48319dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
48419dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
48519dbb230Saartbik         rewriter.getI32IntegerAttr(align));
48619dbb230Saartbik     return success();
48719dbb230Saartbik   }
48819dbb230Saartbik };
48919dbb230Saartbik 
490e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
491563879b6SRahul Joshi class VectorExpandLoadOpConversion
492563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
493e8dcf5f8Saartbik public:
494563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
495e8dcf5f8Saartbik 
496e8dcf5f8Saartbik   LogicalResult
497563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
498e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
499563879b6SRahul Joshi     auto loc = expand->getLoc();
500e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
501a57def30SAart Bik     MemRefType memRefType = expand.getMemRefType();
502e8dcf5f8Saartbik 
503a57def30SAart Bik     // Resolve address.
504a57def30SAart Bik     auto vtype = typeConverter->convertType(expand.getResultVectorType());
505a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
506a57def30SAart Bik                                            adaptor.indices(), rewriter);
507e8dcf5f8Saartbik 
508e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
509a57def30SAart Bik         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
510e8dcf5f8Saartbik     return success();
511e8dcf5f8Saartbik   }
512e8dcf5f8Saartbik };
513e8dcf5f8Saartbik 
514e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
515563879b6SRahul Joshi class VectorCompressStoreOpConversion
516563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
517e8dcf5f8Saartbik public:
518563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
519e8dcf5f8Saartbik 
520e8dcf5f8Saartbik   LogicalResult
521563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
522e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
523563879b6SRahul Joshi     auto loc = compress->getLoc();
524e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
525a57def30SAart Bik     MemRefType memRefType = compress.getMemRefType();
526e8dcf5f8Saartbik 
527a57def30SAart Bik     // Resolve address.
528a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
529a57def30SAart Bik                                            adaptor.indices(), rewriter);
530e8dcf5f8Saartbik 
531e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
532563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
533e8dcf5f8Saartbik     return success();
534e8dcf5f8Saartbik   }
535e8dcf5f8Saartbik };
536e8dcf5f8Saartbik 
53719dbb230Saartbik /// Conversion pattern for all vector reductions.
538563879b6SRahul Joshi class VectorReductionOpConversion
539563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
540e83b7b99Saartbik public:
541563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
542060c9dd1Saartbik                                        bool reassociateFPRed)
543563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
544060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
545e83b7b99Saartbik 
5463145427dSRiver Riddle   LogicalResult
547563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
548e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
549e83b7b99Saartbik     auto kind = reductionOp.kind();
550e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
551dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
552e9628955SAart Bik     if (eltType.isIntOrIndex()) {
553e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
554e83b7b99Saartbik       if (kind == "add")
555322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
556563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
557e83b7b99Saartbik       else if (kind == "mul")
558322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
559563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
560e9628955SAart Bik       else if (kind == "min" &&
561e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
562322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
563563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
564e83b7b99Saartbik       else if (kind == "min")
565322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
566563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
567e9628955SAart Bik       else if (kind == "max" &&
568e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
569322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
570563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
571e83b7b99Saartbik       else if (kind == "max")
572322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
573563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
574e83b7b99Saartbik       else if (kind == "and")
575322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
576563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
577e83b7b99Saartbik       else if (kind == "or")
578322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
579563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
580e83b7b99Saartbik       else if (kind == "xor")
581322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
582563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
583e83b7b99Saartbik       else
5843145427dSRiver Riddle         return failure();
5853145427dSRiver Riddle       return success();
586dcec2ca5SChristian Sigg     }
587e83b7b99Saartbik 
588dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
589dcec2ca5SChristian Sigg       return failure();
590dcec2ca5SChristian Sigg 
591e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
592e83b7b99Saartbik     if (kind == "add") {
5930d924700Saartbik       // Optional accumulator (or zero).
5940d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5950d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
596563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5970d924700Saartbik                                             rewriter.getZeroAttr(eltType));
598322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
599563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
600ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
601e83b7b99Saartbik     } else if (kind == "mul") {
6020d924700Saartbik       // Optional accumulator (or one).
6030d924700Saartbik       Value acc = operands.size() > 1
6040d924700Saartbik                       ? operands[1]
6050d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
606563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
6070d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
608322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
609563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
610ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
611e83b7b99Saartbik     } else if (kind == "min")
612563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
613563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
614e83b7b99Saartbik     else if (kind == "max")
615563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
616563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
617e83b7b99Saartbik     else
6183145427dSRiver Riddle       return failure();
6193145427dSRiver Riddle     return success();
620e83b7b99Saartbik   }
621ceb1b327Saartbik 
622ceb1b327Saartbik private:
623ceb1b327Saartbik   const bool reassociateFPReductions;
624e83b7b99Saartbik };
625e83b7b99Saartbik 
626060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
627563879b6SRahul Joshi class VectorCreateMaskOpConversion
628563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
629060c9dd1Saartbik public:
630563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
631060c9dd1Saartbik                                         bool enableIndexOpt)
632563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
633060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
634060c9dd1Saartbik 
635060c9dd1Saartbik   LogicalResult
636563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
637060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
6389eb3e564SChris Lattner     auto dstType = op.getType();
639060c9dd1Saartbik     int64_t rank = dstType.getRank();
640060c9dd1Saartbik     if (rank == 1) {
641060c9dd1Saartbik       rewriter.replaceOp(
642060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
643060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
644060c9dd1Saartbik       return success();
645060c9dd1Saartbik     }
646060c9dd1Saartbik     return failure();
647060c9dd1Saartbik   }
648060c9dd1Saartbik 
649060c9dd1Saartbik private:
650060c9dd1Saartbik   const bool enableIndexOptimizations;
651060c9dd1Saartbik };
652060c9dd1Saartbik 
653563879b6SRahul Joshi class VectorShuffleOpConversion
654563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6551c81adf3SAart Bik public:
656563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6571c81adf3SAart Bik 
6583145427dSRiver Riddle   LogicalResult
659563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6601c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
661563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6622d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6631c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6641c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6651c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
666dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6671c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6681c81adf3SAart Bik 
6691c81adf3SAart Bik     // Bail if result type cannot be lowered.
6701c81adf3SAart Bik     if (!llvmType)
6713145427dSRiver Riddle       return failure();
6721c81adf3SAart Bik 
6731c81adf3SAart Bik     // Get rank and dimension sizes.
6741c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6751c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6761c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6771c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6781c81adf3SAart Bik 
6791c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6801c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6811c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
682563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6831c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
684563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6853145427dSRiver Riddle       return success();
686b36aaeafSAart Bik     }
687b36aaeafSAart Bik 
6881c81adf3SAart Bik     // For all other cases, insert the individual values individually.
689e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6901c81adf3SAart Bik     int64_t insPos = 0;
6911c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6921c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
693e62a6956SRiver Riddle       Value value = adaptor.v1();
6941c81adf3SAart Bik       if (extPos >= v1Dim) {
6951c81adf3SAart Bik         extPos -= v1Dim;
6961c81adf3SAart Bik         value = adaptor.v2();
697b36aaeafSAart Bik       }
698dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
699dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
700dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
7010f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7021c81adf3SAart Bik     }
703563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
7043145427dSRiver Riddle     return success();
705b36aaeafSAart Bik   }
706b36aaeafSAart Bik };
707b36aaeafSAart Bik 
708563879b6SRahul Joshi class VectorExtractElementOpConversion
709563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
710cd5dab8aSAart Bik public:
711563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
712563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
713cd5dab8aSAart Bik 
7143145427dSRiver Riddle   LogicalResult
715563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
716563879b6SRahul Joshi                   ArrayRef<Value> operands,
717cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7182d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
719cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
720dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
721cd5dab8aSAart Bik 
722cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
723cd5dab8aSAart Bik     if (!llvmType)
7243145427dSRiver Riddle       return failure();
725cd5dab8aSAart Bik 
726cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
727563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
7283145427dSRiver Riddle     return success();
729cd5dab8aSAart Bik   }
730cd5dab8aSAart Bik };
731cd5dab8aSAart Bik 
732563879b6SRahul Joshi class VectorExtractOpConversion
733563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7345c0c51a9SNicolas Vasilache public:
735563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7365c0c51a9SNicolas Vasilache 
7373145427dSRiver Riddle   LogicalResult
738563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7395c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
740563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7412d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7429826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7432bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
744dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7455c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7469826fe5cSAart Bik 
7479826fe5cSAart Bik     // Bail if result type cannot be lowered.
7489826fe5cSAart Bik     if (!llvmResultType)
7493145427dSRiver Riddle       return failure();
7509826fe5cSAart Bik 
7515c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7525c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
753e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7545c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
755563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7563145427dSRiver Riddle       return success();
7575c0c51a9SNicolas Vasilache     }
7585c0c51a9SNicolas Vasilache 
7599826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
760563879b6SRahul Joshi     auto *context = extractOp->getContext();
761e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7625c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7635c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7649826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7655c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
766c2c83e97STres Popp           ArrayAttr::get(context, positionAttrs.drop_back());
7675c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
768dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7695c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7705c0c51a9SNicolas Vasilache     }
7715c0c51a9SNicolas Vasilache 
7725c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7735c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7742230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
7751d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7765c0c51a9SNicolas Vasilache     extracted =
7775c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
778563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7795c0c51a9SNicolas Vasilache 
7803145427dSRiver Riddle     return success();
7815c0c51a9SNicolas Vasilache   }
7825c0c51a9SNicolas Vasilache };
7835c0c51a9SNicolas Vasilache 
784681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
785681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
786681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
787681f929fSNicolas Vasilache ///
788681f929fSNicolas Vasilache /// Example:
789681f929fSNicolas Vasilache /// ```
790681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
791681f929fSNicolas Vasilache /// ```
792681f929fSNicolas Vasilache /// is converted to:
793681f929fSNicolas Vasilache /// ```
7943bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
795dd5165a9SAlex Zinenko ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
796dd5165a9SAlex Zinenko ///    -> !llvm."<8 x f32>">
797681f929fSNicolas Vasilache /// ```
798563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
799681f929fSNicolas Vasilache public:
800563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
801681f929fSNicolas Vasilache 
8023145427dSRiver Riddle   LogicalResult
803563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
804681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8052d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
806681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
807681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8083145427dSRiver Riddle       return failure();
809563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
8103bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8113145427dSRiver Riddle     return success();
812681f929fSNicolas Vasilache   }
813681f929fSNicolas Vasilache };
814681f929fSNicolas Vasilache 
815563879b6SRahul Joshi class VectorInsertElementOpConversion
816563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
817cd5dab8aSAart Bik public:
818563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
819cd5dab8aSAart Bik 
8203145427dSRiver Riddle   LogicalResult
821563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
822cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8232d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
824cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
825dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
826cd5dab8aSAart Bik 
827cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
828cd5dab8aSAart Bik     if (!llvmType)
8293145427dSRiver Riddle       return failure();
830cd5dab8aSAart Bik 
831cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
832563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
833563879b6SRahul Joshi         adaptor.position());
8343145427dSRiver Riddle     return success();
835cd5dab8aSAart Bik   }
836cd5dab8aSAart Bik };
837cd5dab8aSAart Bik 
838563879b6SRahul Joshi class VectorInsertOpConversion
839563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8409826fe5cSAart Bik public:
841563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8429826fe5cSAart Bik 
8433145427dSRiver Riddle   LogicalResult
844563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8459826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
846563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8472d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8489826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8499826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
850dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8519826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8529826fe5cSAart Bik 
8539826fe5cSAart Bik     // Bail if result type cannot be lowered.
8549826fe5cSAart Bik     if (!llvmResultType)
8553145427dSRiver Riddle       return failure();
8569826fe5cSAart Bik 
8579826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8589826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
859e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8609826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8619826fe5cSAart Bik           positionArrayAttr);
862563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8633145427dSRiver Riddle       return success();
8649826fe5cSAart Bik     }
8659826fe5cSAart Bik 
8669826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
867563879b6SRahul Joshi     auto *context = insertOp->getContext();
868e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8699826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8709826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8719826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8729826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8739826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8749826fe5cSAart Bik       auto nMinusOnePositionAttrs =
875c2c83e97STres Popp           ArrayAttr::get(context, positionAttrs.drop_back());
8769826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
877dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8789826fe5cSAart Bik           nMinusOnePositionAttrs);
8799826fe5cSAart Bik     }
8809826fe5cSAart Bik 
8819826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8822230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
8831d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
884e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
885dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8860f04384dSAlex Zinenko         adaptor.source(), constant);
8879826fe5cSAart Bik 
8889826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8899826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8909826fe5cSAart Bik       auto nMinusOnePositionAttrs =
891c2c83e97STres Popp           ArrayAttr::get(context, positionAttrs.drop_back());
8929826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8939826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8949826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8959826fe5cSAart Bik     }
8969826fe5cSAart Bik 
897563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8983145427dSRiver Riddle     return success();
8999826fe5cSAart Bik   }
9009826fe5cSAart Bik };
9019826fe5cSAart Bik 
902681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
903681f929fSNicolas Vasilache ///
904681f929fSNicolas Vasilache /// Example:
905681f929fSNicolas Vasilache /// ```
906681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
907681f929fSNicolas Vasilache /// ```
908681f929fSNicolas Vasilache /// is rewritten into:
909681f929fSNicolas Vasilache /// ```
910681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
911681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
912681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
913681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
914681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
915681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
916681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
917681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
918681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
919681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
920681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
921681f929fSNicolas Vasilache ///  // %r3 holds the final value.
922681f929fSNicolas Vasilache /// ```
923681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
924681f929fSNicolas Vasilache public:
925681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
926681f929fSNicolas Vasilache 
9273145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
928681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
929681f929fSNicolas Vasilache     auto vType = op.getVectorType();
930681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9313145427dSRiver Riddle       return failure();
932681f929fSNicolas Vasilache 
933681f929fSNicolas Vasilache     auto loc = op.getLoc();
934681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
935681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
936681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
937681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
938681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
939681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
940681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
941681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
942681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
943681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
944681f929fSNicolas Vasilache     }
945681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9463145427dSRiver Riddle     return success();
947681f929fSNicolas Vasilache   }
948681f929fSNicolas Vasilache };
949681f929fSNicolas Vasilache 
9502d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9512d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9522d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9532d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9542d515e49SNicolas Vasilache // rank.
9552d515e49SNicolas Vasilache //
9562d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9572d515e49SNicolas Vasilache // have different ranks. In this case:
9582d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9592d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9602d515e49SNicolas Vasilache //   destination subvector
9612d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9622d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9632d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9642d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9652d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9662d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9672d515e49SNicolas Vasilache public:
9682d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9692d515e49SNicolas Vasilache 
9703145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9712d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9722d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9732d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9742d515e49SNicolas Vasilache 
9752d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9763145427dSRiver Riddle       return failure();
9772d515e49SNicolas Vasilache 
9782d515e49SNicolas Vasilache     auto loc = op.getLoc();
9792d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9802d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9812d515e49SNicolas Vasilache     if (rankDiff == 0)
9823145427dSRiver Riddle       return failure();
9832d515e49SNicolas Vasilache 
9842d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9852d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9862d515e49SNicolas Vasilache     // on it.
9872d515e49SNicolas Vasilache     Value extracted =
9882d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9892d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
990dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9912d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9922d515e49SNicolas Vasilache     // ranks.
9932d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9942d515e49SNicolas Vasilache         loc, op.source(), extracted,
9952d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
996c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9972d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9982d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9992d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
1000dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
10013145427dSRiver Riddle     return success();
10022d515e49SNicolas Vasilache   }
10032d515e49SNicolas Vasilache };
10042d515e49SNicolas Vasilache 
10052d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10062d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10072d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10082d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10092d515e49SNicolas Vasilache //   destination subvector
10102d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10112d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10122d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10132d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10142d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10152d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10162d515e49SNicolas Vasilache public:
1017b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1018b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1019b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1020b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1021b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1022b99bd771SRiver Riddle   }
10232d515e49SNicolas Vasilache 
10243145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10252d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10262d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10272d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10282d515e49SNicolas Vasilache 
10292d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10303145427dSRiver Riddle       return failure();
10312d515e49SNicolas Vasilache 
10322d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10332d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10342d515e49SNicolas Vasilache     if (rankDiff != 0)
10353145427dSRiver Riddle       return failure();
10362d515e49SNicolas Vasilache 
10372d515e49SNicolas Vasilache     if (srcType == dstType) {
10382d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10393145427dSRiver Riddle       return success();
10402d515e49SNicolas Vasilache     }
10412d515e49SNicolas Vasilache 
10422d515e49SNicolas Vasilache     int64_t offset =
10432d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10442d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10452d515e49SNicolas Vasilache     int64_t stride =
10462d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10472d515e49SNicolas Vasilache 
10482d515e49SNicolas Vasilache     auto loc = op.getLoc();
10492d515e49SNicolas Vasilache     Value res = op.dest();
10502d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10512d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10522d515e49SNicolas Vasilache          off += stride, ++idx) {
10532d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10542d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10552d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10562d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10572d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10582d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10592d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10602d515e49SNicolas Vasilache         // smaller rank.
1061bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10622d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10632d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10642d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10652d515e49SNicolas Vasilache       }
10662d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10672d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10682d515e49SNicolas Vasilache     }
10692d515e49SNicolas Vasilache 
10702d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10713145427dSRiver Riddle     return success();
10722d515e49SNicolas Vasilache   }
10732d515e49SNicolas Vasilache };
10742d515e49SNicolas Vasilache 
107530e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
107630e6033bSNicolas Vasilache /// static layout.
107730e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
107830e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10792bf491c7SBenjamin Kramer   int64_t offset;
108030e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
108130e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
108230e6033bSNicolas Vasilache     return None;
108330e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
108430e6033bSNicolas Vasilache     return None;
108530e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
108630e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
108730e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
108830e6033bSNicolas Vasilache     return strides;
108930e6033bSNicolas Vasilache 
109030e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
109130e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
109230e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
109330e6033bSNicolas Vasilache   // layout.
10942bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10952bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
109630e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
109730e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
109830e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
109930e6033bSNicolas Vasilache       return None;
110030e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
110130e6033bSNicolas Vasilache       return None;
11022bf491c7SBenjamin Kramer   }
110330e6033bSNicolas Vasilache   return strides;
11042bf491c7SBenjamin Kramer }
11052bf491c7SBenjamin Kramer 
1106563879b6SRahul Joshi class VectorTypeCastOpConversion
1107563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
11085c0c51a9SNicolas Vasilache public:
1109563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
11105c0c51a9SNicolas Vasilache 
11113145427dSRiver Riddle   LogicalResult
1112563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
11135c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1114563879b6SRahul Joshi     auto loc = castOp->getLoc();
11155c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11162bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11179eb3e564SChris Lattner     MemRefType targetMemRefType = castOp.getType();
11185c0c51a9SNicolas Vasilache 
11195c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11205c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11215c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11223145427dSRiver Riddle       return failure();
11235c0c51a9SNicolas Vasilache 
11245c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11258de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
11268de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
11273145427dSRiver Riddle       return failure();
11285c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11295c0c51a9SNicolas Vasilache 
1130dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11318de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
11328de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
11333145427dSRiver Riddle       return failure();
11345c0c51a9SNicolas Vasilache 
113530e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
113630e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
113730e6033bSNicolas Vasilache     if (!sourceStrides)
113830e6033bSNicolas Vasilache       return failure();
113930e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
114030e6033bSNicolas Vasilache     if (!targetStrides)
114130e6033bSNicolas Vasilache       return failure();
114230e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
114330e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
114430e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
114530e6033bSNicolas Vasilache         }))
11463145427dSRiver Riddle       return failure();
11475c0c51a9SNicolas Vasilache 
11482230bf99SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
11495c0c51a9SNicolas Vasilache 
11505c0c51a9SNicolas Vasilache     // Create descriptor.
11515c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11523a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11535c0c51a9SNicolas Vasilache     // Set allocated ptr.
1154e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11555c0c51a9SNicolas Vasilache     allocated =
11565c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11575c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11585c0c51a9SNicolas Vasilache     // Set aligned ptr.
1159e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11605c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11615c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11625c0c51a9SNicolas Vasilache     // Fill offset 0.
11635c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11645c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11655c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11665c0c51a9SNicolas Vasilache 
11675c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11685c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11695c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11705c0c51a9SNicolas Vasilache       auto sizeAttr =
11715c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11725c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11735c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
117430e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
117530e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11765c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11775c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11785c0c51a9SNicolas Vasilache     }
11795c0c51a9SNicolas Vasilache 
1180563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11813145427dSRiver Riddle     return success();
11825c0c51a9SNicolas Vasilache   }
11835c0c51a9SNicolas Vasilache };
11845c0c51a9SNicolas Vasilache 
11858345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11868345b86dSNicolas Vasilache /// sequence of:
1187060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1188060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1189060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1190060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1191060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11928345b86dSNicolas Vasilache template <typename ConcreteOp>
1193563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11948345b86dSNicolas Vasilache public:
1195563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1196060c9dd1Saartbik                                     bool enableIndexOpt)
1197563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1198060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11998345b86dSNicolas Vasilache 
12008345b86dSNicolas Vasilache   LogicalResult
1201563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
12028345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12038345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1204b2c79c50SNicolas Vasilache 
1205b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1206b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12078345b86dSNicolas Vasilache       return failure();
12085f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12095f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12105f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1211563879b6SRahul Joshi                                        xferOp->getContext()))
12128345b86dSNicolas Vasilache       return failure();
121326c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
121426c8f908SThomas Raoux     if (!memRefType)
121526c8f908SThomas Raoux       return failure();
12162bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
121726c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
121830e6033bSNicolas Vasilache     if (!strides)
12192bf491c7SBenjamin Kramer       return failure();
12208345b86dSNicolas Vasilache 
1221563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1222563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1223563879b6SRahul Joshi     };
12248345b86dSNicolas Vasilache 
1225563879b6SRahul Joshi     Location loc = xferOp->getLoc();
12268345b86dSNicolas Vasilache 
122768330ee0SThomas Raoux     if (auto memrefVectorElementType =
122826c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
122968330ee0SThomas Raoux       // Memref has vector element type.
123068330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
123168330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
123268330ee0SThomas Raoux         return failure();
12330de60b55SThomas Raoux #ifndef NDEBUG
123468330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
123568330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
123668330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
123768330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
123868330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
123968330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
124068330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
124168330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
124268330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
124368330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
124468330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
124568330ee0SThomas Raoux                "result shape.");
12460de60b55SThomas Raoux #endif // ifndef NDEBUG
124768330ee0SThomas Raoux     }
124868330ee0SThomas Raoux 
12498345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1250a57def30SAart Bik     VectorType vtp = xferOp.getVectorType();
1251563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
125226c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1253a57def30SAart Bik     Value vectorDataPtr =
1254a57def30SAart Bik         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
12558345b86dSNicolas Vasilache 
12561870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1257563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1258563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1259563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12601870e787SNicolas Vasilache 
12618345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12628345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12638345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12648345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1265060c9dd1Saartbik     //
1266060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1267060c9dd1Saartbik     //       dimensions here.
1268bd30a796SAlex Zinenko     unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
1269060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12700c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
127126c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1272563879b6SRahul Joshi     Value mask = buildVectorComparison(
1273563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12748345b86dSNicolas Vasilache 
12758345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1276563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1277dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12788345b86dSNicolas Vasilache   }
1279060c9dd1Saartbik 
1280060c9dd1Saartbik private:
1281060c9dd1Saartbik   const bool enableIndexOptimizations;
12828345b86dSNicolas Vasilache };
12838345b86dSNicolas Vasilache 
1284563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1285d9b500d3SAart Bik public:
1286563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1287d9b500d3SAart Bik 
1288d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1289d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1290d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1291d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1292d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1293d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1294d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1295d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1296d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1297d9b500d3SAart Bik   //
12989db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1299d9b500d3SAart Bik   //
13003145427dSRiver Riddle   LogicalResult
1301563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1302d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
13032d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1304d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1305d9b500d3SAart Bik 
1306dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
13073145427dSRiver Riddle       return failure();
1308d9b500d3SAart Bik 
1309b8880f5fSAart Bik     // Make sure element type has runtime support.
1310b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1311d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1312d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1313d9b500d3SAart Bik     Operation *printer;
1314b8880f5fSAart Bik     if (eltType.isF32()) {
1315*e332c22cSNicolas Vasilache       printer =
1316*e332c22cSNicolas Vasilache           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
1317b8880f5fSAart Bik     } else if (eltType.isF64()) {
1318*e332c22cSNicolas Vasilache       printer =
1319*e332c22cSNicolas Vasilache           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
132054759cefSAart Bik     } else if (eltType.isIndex()) {
1321*e332c22cSNicolas Vasilache       printer =
1322*e332c22cSNicolas Vasilache           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
1323b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1324b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1325b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1326b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1327b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1328b8880f5fSAart Bik       if (intTy.isUnsigned()) {
132954759cefSAart Bik         if (width <= 64) {
1330b8880f5fSAart Bik           if (width < 64)
1331b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1332*e332c22cSNicolas Vasilache           printer = LLVM::lookupOrCreatePrintU64Fn(
1333*e332c22cSNicolas Vasilache               printOp->getParentOfType<ModuleOp>());
1334b8880f5fSAart Bik         } else {
13353145427dSRiver Riddle           return failure();
1336b8880f5fSAart Bik         }
1337b8880f5fSAart Bik       } else {
1338b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
133954759cefSAart Bik         if (width <= 64) {
1340b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1341b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1342b8880f5fSAart Bik           if (width == 1)
134354759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
134454759cefSAart Bik           else if (width < 64)
1345b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1346*e332c22cSNicolas Vasilache           printer = LLVM::lookupOrCreatePrintI64Fn(
1347*e332c22cSNicolas Vasilache               printOp->getParentOfType<ModuleOp>());
1348b8880f5fSAart Bik         } else {
1349b8880f5fSAart Bik           return failure();
1350b8880f5fSAart Bik         }
1351b8880f5fSAart Bik       }
1352b8880f5fSAart Bik     } else {
1353b8880f5fSAart Bik       return failure();
1354b8880f5fSAart Bik     }
1355d9b500d3SAart Bik 
1356d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1357b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1358563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1359b8880f5fSAart Bik               conversion);
1360*e332c22cSNicolas Vasilache     emitCall(rewriter, printOp->getLoc(),
1361*e332c22cSNicolas Vasilache              LLVM::lookupOrCreatePrintNewlineFn(
1362*e332c22cSNicolas Vasilache                  printOp->getParentOfType<ModuleOp>()));
1363563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13643145427dSRiver Riddle     return success();
1365d9b500d3SAart Bik   }
1366d9b500d3SAart Bik 
1367d9b500d3SAart Bik private:
1368b8880f5fSAart Bik   enum class PrintConversion {
136930e6033bSNicolas Vasilache     // clang-format off
1370b8880f5fSAart Bik     None,
1371b8880f5fSAart Bik     ZeroExt64,
1372b8880f5fSAart Bik     SignExt64
137330e6033bSNicolas Vasilache     // clang-format on
1374b8880f5fSAart Bik   };
1375b8880f5fSAart Bik 
1376d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1377e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1378b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1379d9b500d3SAart Bik     Location loc = op->getLoc();
1380d9b500d3SAart Bik     if (rank == 0) {
1381b8880f5fSAart Bik       switch (conversion) {
1382b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1383b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
13842230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1385b8880f5fSAart Bik         break;
1386b8880f5fSAart Bik       case PrintConversion::SignExt64:
1387b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
13882230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1389b8880f5fSAart Bik         break;
1390b8880f5fSAart Bik       case PrintConversion::None:
1391b8880f5fSAart Bik         break;
1392c9eeeb38Saartbik       }
1393d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1394d9b500d3SAart Bik       return;
1395d9b500d3SAart Bik     }
1396d9b500d3SAart Bik 
1397*e332c22cSNicolas Vasilache     emitCall(rewriter, loc,
1398*e332c22cSNicolas Vasilache              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1399*e332c22cSNicolas Vasilache     Operation *printComma =
1400*e332c22cSNicolas Vasilache         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1401d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1402d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1403d9b500d3SAart Bik       auto reducedType =
1404d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1405dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1406d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1407dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1408dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1409b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1410b8880f5fSAart Bik                 conversion);
1411d9b500d3SAart Bik       if (d != dim - 1)
1412d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1413d9b500d3SAart Bik     }
1414*e332c22cSNicolas Vasilache     emitCall(rewriter, loc,
1415*e332c22cSNicolas Vasilache              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1416d9b500d3SAart Bik   }
1417d9b500d3SAart Bik 
1418d9b500d3SAart Bik   // Helper to emit a call.
1419d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1420d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
142108e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1422d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1423d9b500d3SAart Bik   }
1424d9b500d3SAart Bik };
1425d9b500d3SAart Bik 
1426334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1427c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1428c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1429c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1430334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
143165678d93SNicolas Vasilache public:
1432b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1433b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1434b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1435b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1436b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1437b99bd771SRiver Riddle   }
143865678d93SNicolas Vasilache 
1439334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
144065678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
14419eb3e564SChris Lattner     auto dstType = op.getType();
144265678d93SNicolas Vasilache 
144365678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
144465678d93SNicolas Vasilache 
144565678d93SNicolas Vasilache     int64_t offset =
144665678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
144765678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
144865678d93SNicolas Vasilache     int64_t stride =
144965678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
145065678d93SNicolas Vasilache 
145165678d93SNicolas Vasilache     auto loc = op.getLoc();
145265678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
145335b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1454c3c95b9cSaartbik 
1455c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1456c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1457c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1458c3c95b9cSaartbik       offsets.reserve(size);
1459c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1460c3c95b9cSaartbik            off += stride)
1461c3c95b9cSaartbik         offsets.push_back(off);
1462c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1463c3c95b9cSaartbik                                              op.vector(),
1464c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1465c3c95b9cSaartbik       return success();
1466c3c95b9cSaartbik     }
1467c3c95b9cSaartbik 
1468c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
146965678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
147065678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
147165678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
147265678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
147365678d93SNicolas Vasilache          off += stride, ++idx) {
1474c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1475c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1476c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
147765678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
147865678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
147965678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
148065678d93SNicolas Vasilache     }
1481c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14823145427dSRiver Riddle     return success();
148365678d93SNicolas Vasilache   }
148465678d93SNicolas Vasilache };
148565678d93SNicolas Vasilache 
1486df186507SBenjamin Kramer } // namespace
1487df186507SBenjamin Kramer 
14885c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
14895c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1490ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1491060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
149265678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
14938345b86dSNicolas Vasilache   // clang-format off
1494681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1495681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
14962d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1497c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1498ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1499563879b6SRahul Joshi       converter, reassociateFPReductions);
1500060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1501060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1502060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1503563879b6SRahul Joshi       converter, enableIndexOptimizations);
15048345b86dSNicolas Vasilache   patterns
1505cf5c517cSDiego Caballero       .insert<VectorBitCastOpConversion,
1506cf5c517cSDiego Caballero               VectorShuffleOpConversion,
15078345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15088345b86dSNicolas Vasilache               VectorExtractOpConversion,
15098345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15108345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15118345b86dSNicolas Vasilache               VectorInsertOpConversion,
15128345b86dSNicolas Vasilache               VectorPrintOpConversion,
151319dbb230Saartbik               VectorTypeCastOpConversion,
151439379916Saartbik               VectorMaskedLoadOpConversion,
151539379916Saartbik               VectorMaskedStoreOpConversion,
151619dbb230Saartbik               VectorGatherOpConversion,
1517e8dcf5f8Saartbik               VectorScatterOpConversion,
1518e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1519563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15208345b86dSNicolas Vasilache   // clang-format on
15215c0c51a9SNicolas Vasilache }
15225c0c51a9SNicolas Vasilache 
152363b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
152463b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1525563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1526563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
152763b683a8SNicolas Vasilache }
1528