15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105*ba87f991SAlex Zinenko static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
106*ba87f991SAlex Zinenko                                    Location loc, Type targetType, Value value) {
107*ba87f991SAlex Zinenko   if (targetType == value.getType())
108*ba87f991SAlex Zinenko     return value;
109*ba87f991SAlex Zinenko 
110*ba87f991SAlex Zinenko   bool targetIsIndex = targetType.isIndex();
111*ba87f991SAlex Zinenko   bool valueIsIndex = value.getType().isIndex();
112*ba87f991SAlex Zinenko   if (targetIsIndex ^ valueIsIndex)
113*ba87f991SAlex Zinenko     return rewriter.create<IndexCastOp>(loc, targetType, value);
114*ba87f991SAlex Zinenko 
115*ba87f991SAlex Zinenko   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
116*ba87f991SAlex Zinenko   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
117*ba87f991SAlex Zinenko   assert(targetIntegerType && valueIntegerType &&
118*ba87f991SAlex Zinenko          "unexpected cast between types other than integers and index");
119*ba87f991SAlex Zinenko   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
120*ba87f991SAlex Zinenko 
121*ba87f991SAlex Zinenko   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
122*ba87f991SAlex Zinenko     return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
123*ba87f991SAlex Zinenko   return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
124*ba87f991SAlex Zinenko }
125*ba87f991SAlex Zinenko 
126060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
127060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
128060c9dd1Saartbik //
129060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
130060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
131060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
132060c9dd1Saartbik //       is very conservative on the boundary conditions.
133060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
134060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
135060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
136060c9dd1Saartbik   auto loc = op->getLoc();
137060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
138060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
139060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
140060c9dd1Saartbik   Value indices;
141060c9dd1Saartbik   Type idxType;
142060c9dd1Saartbik   if (enableIndexOptimizations) {
1430c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1440c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1450c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
146060c9dd1Saartbik     idxType = rewriter.getI32Type();
147060c9dd1Saartbik   } else {
1480c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1490c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1500c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
151060c9dd1Saartbik     idxType = rewriter.getI64Type();
152060c9dd1Saartbik   }
153060c9dd1Saartbik   // Add in an offset if requested.
154060c9dd1Saartbik   if (off) {
155*ba87f991SAlex Zinenko     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
156060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
157060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
158060c9dd1Saartbik   }
159060c9dd1Saartbik   // Construct the vector comparison.
160*ba87f991SAlex Zinenko   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
161060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
162060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
163060c9dd1Saartbik }
164060c9dd1Saartbik 
16526c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
16626c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
16726c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
16826c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1695f9e0466SNicolas Vasilache   if (!elementTy)
1705f9e0466SNicolas Vasilache     return failure();
1715f9e0466SNicolas Vasilache 
172b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
173b2ab375dSAlex Zinenko   // stop depending on translation.
17487a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
17587a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
176c69c9e0fSAlex Zinenko               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
1775f9e0466SNicolas Vasilache   return success();
1785f9e0466SNicolas Vasilache }
1795f9e0466SNicolas Vasilache 
180e8dcf5f8Saartbik // Helper that returns the base address of a memref.
181b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
182e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
18319dbb230Saartbik   // Inspect stride and offset structure.
18419dbb230Saartbik   //
18519dbb230Saartbik   // TODO: flat memory only for now, generalize
18619dbb230Saartbik   //
18719dbb230Saartbik   int64_t offset;
18819dbb230Saartbik   SmallVector<int64_t, 4> strides;
18919dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
19019dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
19119dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
19219dbb230Saartbik     return failure();
193e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
194e8dcf5f8Saartbik   return success();
195e8dcf5f8Saartbik }
19619dbb230Saartbik 
197a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector.
198b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
199b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
200b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
201b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
202e8dcf5f8Saartbik   Value base;
203e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
204e8dcf5f8Saartbik     return failure();
2053a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
206bd30a796SAlex Zinenko   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
2071485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
20819dbb230Saartbik   return success();
20919dbb230Saartbik }
21019dbb230Saartbik 
211a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer
212a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be
213a57def30SAart Bik // used when source/dst memrefs are not on address space 0.
214a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
215a57def30SAart Bik                          Value ptr, MemRefType memRefType, Type vt) {
216bd30a796SAlex Zinenko   auto pType = LLVM::LLVMPointerType::get(vt);
217a57def30SAart Bik   if (memRefType.getMemorySpace() == 0)
218a57def30SAart Bik     return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
219a57def30SAart Bik   return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
220a57def30SAart Bik }
221a57def30SAart Bik 
2225f9e0466SNicolas Vasilache static LogicalResult
2235f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2245f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2255f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2265f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
227affbc0cdSNicolas Vasilache   unsigned align;
22826c8f908SThomas Raoux   if (failed(getMemRefAlignment(
22926c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
230affbc0cdSNicolas Vasilache     return failure();
231affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2325f9e0466SNicolas Vasilache   return success();
2335f9e0466SNicolas Vasilache }
2345f9e0466SNicolas Vasilache 
2355f9e0466SNicolas Vasilache static LogicalResult
2365f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2375f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2385f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2395f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2405f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2415f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2425f9e0466SNicolas Vasilache 
2435f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2445f9e0466SNicolas Vasilache   if (!vecTy)
2455f9e0466SNicolas Vasilache     return failure();
2465f9e0466SNicolas Vasilache 
2475f9e0466SNicolas Vasilache   unsigned align;
24826c8f908SThomas Raoux   if (failed(getMemRefAlignment(
24926c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2505f9e0466SNicolas Vasilache     return failure();
2515f9e0466SNicolas Vasilache 
2525f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2535f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2545f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2555f9e0466SNicolas Vasilache   return success();
2565f9e0466SNicolas Vasilache }
2575f9e0466SNicolas Vasilache 
2585f9e0466SNicolas Vasilache static LogicalResult
2595f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2605f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2615f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2625f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
263affbc0cdSNicolas Vasilache   unsigned align;
26426c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26526c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
266affbc0cdSNicolas Vasilache     return failure();
2672d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
268affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
269affbc0cdSNicolas Vasilache                                              align);
2705f9e0466SNicolas Vasilache   return success();
2715f9e0466SNicolas Vasilache }
2725f9e0466SNicolas Vasilache 
2735f9e0466SNicolas Vasilache static LogicalResult
2745f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2755f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2765f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2775f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2785f9e0466SNicolas Vasilache   unsigned align;
27926c8f908SThomas Raoux   if (failed(getMemRefAlignment(
28026c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2815f9e0466SNicolas Vasilache     return failure();
2825f9e0466SNicolas Vasilache 
2832d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2845f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2855f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2865f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2875f9e0466SNicolas Vasilache   return success();
2885f9e0466SNicolas Vasilache }
2895f9e0466SNicolas Vasilache 
2902d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2912d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2922d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2935f9e0466SNicolas Vasilache }
2945f9e0466SNicolas Vasilache 
2952d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2962d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2972d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2985f9e0466SNicolas Vasilache }
2995f9e0466SNicolas Vasilache 
30090c01357SBenjamin Kramer namespace {
301e83b7b99Saartbik 
302cf5c517cSDiego Caballero /// Conversion pattern for a vector.bitcast.
303cf5c517cSDiego Caballero class VectorBitCastOpConversion
304cf5c517cSDiego Caballero     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
305cf5c517cSDiego Caballero public:
306cf5c517cSDiego Caballero   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
307cf5c517cSDiego Caballero 
308cf5c517cSDiego Caballero   LogicalResult
309cf5c517cSDiego Caballero   matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
310cf5c517cSDiego Caballero                   ConversionPatternRewriter &rewriter) const override {
311cf5c517cSDiego Caballero     // Only 1-D vectors can be lowered to LLVM.
312cf5c517cSDiego Caballero     VectorType resultTy = bitCastOp.getType();
313cf5c517cSDiego Caballero     if (resultTy.getRank() != 1)
314cf5c517cSDiego Caballero       return failure();
315cf5c517cSDiego Caballero     Type newResultTy = typeConverter->convertType(resultTy);
316cf5c517cSDiego Caballero     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
317cf5c517cSDiego Caballero                                                  operands[0]);
318cf5c517cSDiego Caballero     return success();
319cf5c517cSDiego Caballero   }
320cf5c517cSDiego Caballero };
321cf5c517cSDiego Caballero 
32263b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
32363b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
324563879b6SRahul Joshi class VectorMatmulOpConversion
325563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
32663b683a8SNicolas Vasilache public:
327563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
32863b683a8SNicolas Vasilache 
3293145427dSRiver Riddle   LogicalResult
330563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
33163b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3322d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
33363b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
334563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
335563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
336563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3373145427dSRiver Riddle     return success();
33863b683a8SNicolas Vasilache   }
33963b683a8SNicolas Vasilache };
34063b683a8SNicolas Vasilache 
341c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
342c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
343563879b6SRahul Joshi class VectorFlatTransposeOpConversion
344563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
345c295a65dSaartbik public:
346563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
347c295a65dSaartbik 
348c295a65dSaartbik   LogicalResult
349563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
350c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3512d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
352c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
353dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
354c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
355c295a65dSaartbik     return success();
356c295a65dSaartbik   }
357c295a65dSaartbik };
358c295a65dSaartbik 
35939379916Saartbik /// Conversion pattern for a vector.maskedload.
360563879b6SRahul Joshi class VectorMaskedLoadOpConversion
361563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
36239379916Saartbik public:
363563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
36439379916Saartbik 
36539379916Saartbik   LogicalResult
366563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
36739379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
368563879b6SRahul Joshi     auto loc = load->getLoc();
36939379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
370a57def30SAart Bik     MemRefType memRefType = load.getMemRefType();
37139379916Saartbik 
37239379916Saartbik     // Resolve alignment.
37339379916Saartbik     unsigned align;
374a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
37539379916Saartbik       return failure();
37639379916Saartbik 
377a57def30SAart Bik     // Resolve address.
378dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
379a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
380a57def30SAart Bik                                                adaptor.indices(), rewriter);
381a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
38239379916Saartbik 
38339379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
38439379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
38539379916Saartbik         rewriter.getI32IntegerAttr(align));
38639379916Saartbik     return success();
38739379916Saartbik   }
38839379916Saartbik };
38939379916Saartbik 
39039379916Saartbik /// Conversion pattern for a vector.maskedstore.
391563879b6SRahul Joshi class VectorMaskedStoreOpConversion
392563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
39339379916Saartbik public:
394563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
39539379916Saartbik 
39639379916Saartbik   LogicalResult
397563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
39839379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
399563879b6SRahul Joshi     auto loc = store->getLoc();
40039379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
401a57def30SAart Bik     MemRefType memRefType = store.getMemRefType();
40239379916Saartbik 
40339379916Saartbik     // Resolve alignment.
40439379916Saartbik     unsigned align;
405a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
40639379916Saartbik       return failure();
40739379916Saartbik 
408a57def30SAart Bik     // Resolve address.
409dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
410a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
411a57def30SAart Bik                                                adaptor.indices(), rewriter);
412a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
41339379916Saartbik 
41439379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
41539379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
41639379916Saartbik         rewriter.getI32IntegerAttr(align));
41739379916Saartbik     return success();
41839379916Saartbik   }
41939379916Saartbik };
42039379916Saartbik 
42119dbb230Saartbik /// Conversion pattern for a vector.gather.
422563879b6SRahul Joshi class VectorGatherOpConversion
423563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
42419dbb230Saartbik public:
425563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
42619dbb230Saartbik 
42719dbb230Saartbik   LogicalResult
428563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
42919dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
430563879b6SRahul Joshi     auto loc = gather->getLoc();
43119dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
43219dbb230Saartbik 
43319dbb230Saartbik     // Resolve alignment.
43419dbb230Saartbik     unsigned align;
43526c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
43626c8f908SThomas Raoux                                   align)))
43719dbb230Saartbik       return failure();
43819dbb230Saartbik 
43919dbb230Saartbik     // Get index ptrs.
44019dbb230Saartbik     VectorType vType = gather.getResultVectorType();
44119dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
44219dbb230Saartbik     Value ptrs;
443e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
444e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
44519dbb230Saartbik       return failure();
44619dbb230Saartbik 
44719dbb230Saartbik     // Replace with the gather intrinsic.
44819dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
449dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4500c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
45119dbb230Saartbik     return success();
45219dbb230Saartbik   }
45319dbb230Saartbik };
45419dbb230Saartbik 
45519dbb230Saartbik /// Conversion pattern for a vector.scatter.
456563879b6SRahul Joshi class VectorScatterOpConversion
457563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
45819dbb230Saartbik public:
459563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
46019dbb230Saartbik 
46119dbb230Saartbik   LogicalResult
462563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
46319dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
464563879b6SRahul Joshi     auto loc = scatter->getLoc();
46519dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
46619dbb230Saartbik 
46719dbb230Saartbik     // Resolve alignment.
46819dbb230Saartbik     unsigned align;
46926c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
47026c8f908SThomas Raoux                                   align)))
47119dbb230Saartbik       return failure();
47219dbb230Saartbik 
47319dbb230Saartbik     // Get index ptrs.
47419dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
47519dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
47619dbb230Saartbik     Value ptrs;
477e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
478e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
47919dbb230Saartbik       return failure();
48019dbb230Saartbik 
48119dbb230Saartbik     // Replace with the scatter intrinsic.
48219dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
48319dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
48419dbb230Saartbik         rewriter.getI32IntegerAttr(align));
48519dbb230Saartbik     return success();
48619dbb230Saartbik   }
48719dbb230Saartbik };
48819dbb230Saartbik 
489e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
490563879b6SRahul Joshi class VectorExpandLoadOpConversion
491563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
492e8dcf5f8Saartbik public:
493563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
494e8dcf5f8Saartbik 
495e8dcf5f8Saartbik   LogicalResult
496563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
497e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
498563879b6SRahul Joshi     auto loc = expand->getLoc();
499e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
500a57def30SAart Bik     MemRefType memRefType = expand.getMemRefType();
501e8dcf5f8Saartbik 
502a57def30SAart Bik     // Resolve address.
503a57def30SAart Bik     auto vtype = typeConverter->convertType(expand.getResultVectorType());
504a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
505a57def30SAart Bik                                            adaptor.indices(), rewriter);
506e8dcf5f8Saartbik 
507e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
508a57def30SAart Bik         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
509e8dcf5f8Saartbik     return success();
510e8dcf5f8Saartbik   }
511e8dcf5f8Saartbik };
512e8dcf5f8Saartbik 
513e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
514563879b6SRahul Joshi class VectorCompressStoreOpConversion
515563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
516e8dcf5f8Saartbik public:
517563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
518e8dcf5f8Saartbik 
519e8dcf5f8Saartbik   LogicalResult
520563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
521e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
522563879b6SRahul Joshi     auto loc = compress->getLoc();
523e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
524a57def30SAart Bik     MemRefType memRefType = compress.getMemRefType();
525e8dcf5f8Saartbik 
526a57def30SAart Bik     // Resolve address.
527a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
528a57def30SAart Bik                                            adaptor.indices(), rewriter);
529e8dcf5f8Saartbik 
530e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
531563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
532e8dcf5f8Saartbik     return success();
533e8dcf5f8Saartbik   }
534e8dcf5f8Saartbik };
535e8dcf5f8Saartbik 
53619dbb230Saartbik /// Conversion pattern for all vector reductions.
537563879b6SRahul Joshi class VectorReductionOpConversion
538563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
539e83b7b99Saartbik public:
540563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
541060c9dd1Saartbik                                        bool reassociateFPRed)
542563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
543060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
544e83b7b99Saartbik 
5453145427dSRiver Riddle   LogicalResult
546563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
547e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
548e83b7b99Saartbik     auto kind = reductionOp.kind();
549e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
550dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
551e9628955SAart Bik     if (eltType.isIntOrIndex()) {
552e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
553e83b7b99Saartbik       if (kind == "add")
554322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
555563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
556e83b7b99Saartbik       else if (kind == "mul")
557322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
558563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
559e9628955SAart Bik       else if (kind == "min" &&
560e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
561322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
562563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
563e83b7b99Saartbik       else if (kind == "min")
564322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
565563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
566e9628955SAart Bik       else if (kind == "max" &&
567e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
568322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
569563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
570e83b7b99Saartbik       else if (kind == "max")
571322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
572563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
573e83b7b99Saartbik       else if (kind == "and")
574322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
575563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
576e83b7b99Saartbik       else if (kind == "or")
577322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
578563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
579e83b7b99Saartbik       else if (kind == "xor")
580322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
581563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
582e83b7b99Saartbik       else
5833145427dSRiver Riddle         return failure();
5843145427dSRiver Riddle       return success();
585dcec2ca5SChristian Sigg     }
586e83b7b99Saartbik 
587dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
588dcec2ca5SChristian Sigg       return failure();
589dcec2ca5SChristian Sigg 
590e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
591e83b7b99Saartbik     if (kind == "add") {
5920d924700Saartbik       // Optional accumulator (or zero).
5930d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5940d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
595563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5960d924700Saartbik                                             rewriter.getZeroAttr(eltType));
597322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
598563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
599ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
600e83b7b99Saartbik     } else if (kind == "mul") {
6010d924700Saartbik       // Optional accumulator (or one).
6020d924700Saartbik       Value acc = operands.size() > 1
6030d924700Saartbik                       ? operands[1]
6040d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
605563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
6060d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
607322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
608563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
609ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
610e83b7b99Saartbik     } else if (kind == "min")
611563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
612563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
613e83b7b99Saartbik     else if (kind == "max")
614563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
615563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
616e83b7b99Saartbik     else
6173145427dSRiver Riddle       return failure();
6183145427dSRiver Riddle     return success();
619e83b7b99Saartbik   }
620ceb1b327Saartbik 
621ceb1b327Saartbik private:
622ceb1b327Saartbik   const bool reassociateFPReductions;
623e83b7b99Saartbik };
624e83b7b99Saartbik 
625060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
626563879b6SRahul Joshi class VectorCreateMaskOpConversion
627563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
628060c9dd1Saartbik public:
629563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
630060c9dd1Saartbik                                         bool enableIndexOpt)
631563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
632060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
633060c9dd1Saartbik 
634060c9dd1Saartbik   LogicalResult
635563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
636060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
6379eb3e564SChris Lattner     auto dstType = op.getType();
638060c9dd1Saartbik     int64_t rank = dstType.getRank();
639060c9dd1Saartbik     if (rank == 1) {
640060c9dd1Saartbik       rewriter.replaceOp(
641060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
642060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
643060c9dd1Saartbik       return success();
644060c9dd1Saartbik     }
645060c9dd1Saartbik     return failure();
646060c9dd1Saartbik   }
647060c9dd1Saartbik 
648060c9dd1Saartbik private:
649060c9dd1Saartbik   const bool enableIndexOptimizations;
650060c9dd1Saartbik };
651060c9dd1Saartbik 
652563879b6SRahul Joshi class VectorShuffleOpConversion
653563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6541c81adf3SAart Bik public:
655563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6561c81adf3SAart Bik 
6573145427dSRiver Riddle   LogicalResult
658563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6591c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
660563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6612d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6621c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6631c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6641c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
665dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6661c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6671c81adf3SAart Bik 
6681c81adf3SAart Bik     // Bail if result type cannot be lowered.
6691c81adf3SAart Bik     if (!llvmType)
6703145427dSRiver Riddle       return failure();
6711c81adf3SAart Bik 
6721c81adf3SAart Bik     // Get rank and dimension sizes.
6731c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6741c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6751c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6761c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6771c81adf3SAart Bik 
6781c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6791c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6801c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
681563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6821c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
683563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6843145427dSRiver Riddle       return success();
685b36aaeafSAart Bik     }
686b36aaeafSAart Bik 
6871c81adf3SAart Bik     // For all other cases, insert the individual values individually.
688e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6891c81adf3SAart Bik     int64_t insPos = 0;
6901c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6911c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
692e62a6956SRiver Riddle       Value value = adaptor.v1();
6931c81adf3SAart Bik       if (extPos >= v1Dim) {
6941c81adf3SAart Bik         extPos -= v1Dim;
6951c81adf3SAart Bik         value = adaptor.v2();
696b36aaeafSAart Bik       }
697dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
698dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
699dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
7000f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7011c81adf3SAart Bik     }
702563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
7033145427dSRiver Riddle     return success();
704b36aaeafSAart Bik   }
705b36aaeafSAart Bik };
706b36aaeafSAart Bik 
707563879b6SRahul Joshi class VectorExtractElementOpConversion
708563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
709cd5dab8aSAart Bik public:
710563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
711563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
712cd5dab8aSAart Bik 
7133145427dSRiver Riddle   LogicalResult
714563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
715563879b6SRahul Joshi                   ArrayRef<Value> operands,
716cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7172d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
718cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
719dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
720cd5dab8aSAart Bik 
721cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
722cd5dab8aSAart Bik     if (!llvmType)
7233145427dSRiver Riddle       return failure();
724cd5dab8aSAart Bik 
725cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
726563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
7273145427dSRiver Riddle     return success();
728cd5dab8aSAart Bik   }
729cd5dab8aSAart Bik };
730cd5dab8aSAart Bik 
731563879b6SRahul Joshi class VectorExtractOpConversion
732563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7335c0c51a9SNicolas Vasilache public:
734563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7355c0c51a9SNicolas Vasilache 
7363145427dSRiver Riddle   LogicalResult
737563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7385c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
739563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7402d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7419826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7422bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
743dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7445c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7459826fe5cSAart Bik 
7469826fe5cSAart Bik     // Bail if result type cannot be lowered.
7479826fe5cSAart Bik     if (!llvmResultType)
7483145427dSRiver Riddle       return failure();
7499826fe5cSAart Bik 
7505c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7515c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
752e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7535c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
754563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7553145427dSRiver Riddle       return success();
7565c0c51a9SNicolas Vasilache     }
7575c0c51a9SNicolas Vasilache 
7589826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
759563879b6SRahul Joshi     auto *context = extractOp->getContext();
760e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7615c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7625c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7639826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7645c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7655c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7665c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
767dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7685c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7695c0c51a9SNicolas Vasilache     }
7705c0c51a9SNicolas Vasilache 
7715c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7725c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7732230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
7741d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7755c0c51a9SNicolas Vasilache     extracted =
7765c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
777563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7785c0c51a9SNicolas Vasilache 
7793145427dSRiver Riddle     return success();
7805c0c51a9SNicolas Vasilache   }
7815c0c51a9SNicolas Vasilache };
7825c0c51a9SNicolas Vasilache 
783681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
784681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
785681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
786681f929fSNicolas Vasilache ///
787681f929fSNicolas Vasilache /// Example:
788681f929fSNicolas Vasilache /// ```
789681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
790681f929fSNicolas Vasilache /// ```
791681f929fSNicolas Vasilache /// is converted to:
792681f929fSNicolas Vasilache /// ```
7933bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
794dd5165a9SAlex Zinenko ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
795dd5165a9SAlex Zinenko ///    -> !llvm."<8 x f32>">
796681f929fSNicolas Vasilache /// ```
797563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
798681f929fSNicolas Vasilache public:
799563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
800681f929fSNicolas Vasilache 
8013145427dSRiver Riddle   LogicalResult
802563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
803681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8042d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
805681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
806681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8073145427dSRiver Riddle       return failure();
808563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
8093bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8103145427dSRiver Riddle     return success();
811681f929fSNicolas Vasilache   }
812681f929fSNicolas Vasilache };
813681f929fSNicolas Vasilache 
814563879b6SRahul Joshi class VectorInsertElementOpConversion
815563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
816cd5dab8aSAart Bik public:
817563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
818cd5dab8aSAart Bik 
8193145427dSRiver Riddle   LogicalResult
820563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
821cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8222d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
823cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
824dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
825cd5dab8aSAart Bik 
826cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
827cd5dab8aSAart Bik     if (!llvmType)
8283145427dSRiver Riddle       return failure();
829cd5dab8aSAart Bik 
830cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
831563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
832563879b6SRahul Joshi         adaptor.position());
8333145427dSRiver Riddle     return success();
834cd5dab8aSAart Bik   }
835cd5dab8aSAart Bik };
836cd5dab8aSAart Bik 
837563879b6SRahul Joshi class VectorInsertOpConversion
838563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8399826fe5cSAart Bik public:
840563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8419826fe5cSAart Bik 
8423145427dSRiver Riddle   LogicalResult
843563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8449826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
845563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8462d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8479826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8489826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
849dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8509826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8519826fe5cSAart Bik 
8529826fe5cSAart Bik     // Bail if result type cannot be lowered.
8539826fe5cSAart Bik     if (!llvmResultType)
8543145427dSRiver Riddle       return failure();
8559826fe5cSAart Bik 
8569826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8579826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
858e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8599826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8609826fe5cSAart Bik           positionArrayAttr);
861563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8623145427dSRiver Riddle       return success();
8639826fe5cSAart Bik     }
8649826fe5cSAart Bik 
8659826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
866563879b6SRahul Joshi     auto *context = insertOp->getContext();
867e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8689826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8699826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8709826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8719826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8729826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8739826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8749826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8759826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
876dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8779826fe5cSAart Bik           nMinusOnePositionAttrs);
8789826fe5cSAart Bik     }
8799826fe5cSAart Bik 
8809826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8812230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
8821d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
883e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
884dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8850f04384dSAlex Zinenko         adaptor.source(), constant);
8869826fe5cSAart Bik 
8879826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8889826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8899826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8909826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8919826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8929826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8939826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8949826fe5cSAart Bik     }
8959826fe5cSAart Bik 
896563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8973145427dSRiver Riddle     return success();
8989826fe5cSAart Bik   }
8999826fe5cSAart Bik };
9009826fe5cSAart Bik 
901681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
902681f929fSNicolas Vasilache ///
903681f929fSNicolas Vasilache /// Example:
904681f929fSNicolas Vasilache /// ```
905681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
906681f929fSNicolas Vasilache /// ```
907681f929fSNicolas Vasilache /// is rewritten into:
908681f929fSNicolas Vasilache /// ```
909681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
910681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
911681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
912681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
913681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
914681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
915681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
916681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
917681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
918681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
919681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
920681f929fSNicolas Vasilache ///  // %r3 holds the final value.
921681f929fSNicolas Vasilache /// ```
922681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
923681f929fSNicolas Vasilache public:
924681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
925681f929fSNicolas Vasilache 
9263145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
927681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
928681f929fSNicolas Vasilache     auto vType = op.getVectorType();
929681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9303145427dSRiver Riddle       return failure();
931681f929fSNicolas Vasilache 
932681f929fSNicolas Vasilache     auto loc = op.getLoc();
933681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
934681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
935681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
936681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
937681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
938681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
939681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
940681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
941681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
942681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
943681f929fSNicolas Vasilache     }
944681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9453145427dSRiver Riddle     return success();
946681f929fSNicolas Vasilache   }
947681f929fSNicolas Vasilache };
948681f929fSNicolas Vasilache 
9492d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9502d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9512d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9522d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9532d515e49SNicolas Vasilache // rank.
9542d515e49SNicolas Vasilache //
9552d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9562d515e49SNicolas Vasilache // have different ranks. In this case:
9572d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9582d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9592d515e49SNicolas Vasilache //   destination subvector
9602d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9612d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9622d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9632d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9642d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9652d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9662d515e49SNicolas Vasilache public:
9672d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9682d515e49SNicolas Vasilache 
9693145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9702d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9712d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9722d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9732d515e49SNicolas Vasilache 
9742d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9753145427dSRiver Riddle       return failure();
9762d515e49SNicolas Vasilache 
9772d515e49SNicolas Vasilache     auto loc = op.getLoc();
9782d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9792d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9802d515e49SNicolas Vasilache     if (rankDiff == 0)
9813145427dSRiver Riddle       return failure();
9822d515e49SNicolas Vasilache 
9832d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9842d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9852d515e49SNicolas Vasilache     // on it.
9862d515e49SNicolas Vasilache     Value extracted =
9872d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9882d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
989dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9902d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9912d515e49SNicolas Vasilache     // ranks.
9922d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9932d515e49SNicolas Vasilache         loc, op.source(), extracted,
9942d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
995c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9962d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9972d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9982d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
999dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
10003145427dSRiver Riddle     return success();
10012d515e49SNicolas Vasilache   }
10022d515e49SNicolas Vasilache };
10032d515e49SNicolas Vasilache 
10042d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10052d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10062d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10072d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10082d515e49SNicolas Vasilache //   destination subvector
10092d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10102d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10112d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10122d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10132d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10142d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10152d515e49SNicolas Vasilache public:
1016b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1017b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1018b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1019b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1020b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1021b99bd771SRiver Riddle   }
10222d515e49SNicolas Vasilache 
10233145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10242d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10252d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10262d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10272d515e49SNicolas Vasilache 
10282d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10293145427dSRiver Riddle       return failure();
10302d515e49SNicolas Vasilache 
10312d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10322d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10332d515e49SNicolas Vasilache     if (rankDiff != 0)
10343145427dSRiver Riddle       return failure();
10352d515e49SNicolas Vasilache 
10362d515e49SNicolas Vasilache     if (srcType == dstType) {
10372d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10383145427dSRiver Riddle       return success();
10392d515e49SNicolas Vasilache     }
10402d515e49SNicolas Vasilache 
10412d515e49SNicolas Vasilache     int64_t offset =
10422d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10432d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10442d515e49SNicolas Vasilache     int64_t stride =
10452d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10462d515e49SNicolas Vasilache 
10472d515e49SNicolas Vasilache     auto loc = op.getLoc();
10482d515e49SNicolas Vasilache     Value res = op.dest();
10492d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10502d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10512d515e49SNicolas Vasilache          off += stride, ++idx) {
10522d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10532d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10542d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10552d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10562d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10572d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10582d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10592d515e49SNicolas Vasilache         // smaller rank.
1060bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10612d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10622d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10632d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10642d515e49SNicolas Vasilache       }
10652d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10662d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10672d515e49SNicolas Vasilache     }
10682d515e49SNicolas Vasilache 
10692d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10703145427dSRiver Riddle     return success();
10712d515e49SNicolas Vasilache   }
10722d515e49SNicolas Vasilache };
10732d515e49SNicolas Vasilache 
107430e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
107530e6033bSNicolas Vasilache /// static layout.
107630e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
107730e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10782bf491c7SBenjamin Kramer   int64_t offset;
107930e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
108030e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
108130e6033bSNicolas Vasilache     return None;
108230e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
108330e6033bSNicolas Vasilache     return None;
108430e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
108530e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
108630e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
108730e6033bSNicolas Vasilache     return strides;
108830e6033bSNicolas Vasilache 
108930e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
109030e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
109130e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
109230e6033bSNicolas Vasilache   // layout.
10932bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10942bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
109530e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
109630e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
109730e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
109830e6033bSNicolas Vasilache       return None;
109930e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
110030e6033bSNicolas Vasilache       return None;
11012bf491c7SBenjamin Kramer   }
110230e6033bSNicolas Vasilache   return strides;
11032bf491c7SBenjamin Kramer }
11042bf491c7SBenjamin Kramer 
1105563879b6SRahul Joshi class VectorTypeCastOpConversion
1106563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
11075c0c51a9SNicolas Vasilache public:
1108563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
11095c0c51a9SNicolas Vasilache 
11103145427dSRiver Riddle   LogicalResult
1111563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
11125c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1113563879b6SRahul Joshi     auto loc = castOp->getLoc();
11145c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11152bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11169eb3e564SChris Lattner     MemRefType targetMemRefType = castOp.getType();
11175c0c51a9SNicolas Vasilache 
11185c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11195c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11205c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11213145427dSRiver Riddle       return failure();
11225c0c51a9SNicolas Vasilache 
11235c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11248de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
11258de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
11263145427dSRiver Riddle       return failure();
11275c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11285c0c51a9SNicolas Vasilache 
1129dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11308de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
11318de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
11323145427dSRiver Riddle       return failure();
11335c0c51a9SNicolas Vasilache 
113430e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
113530e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
113630e6033bSNicolas Vasilache     if (!sourceStrides)
113730e6033bSNicolas Vasilache       return failure();
113830e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
113930e6033bSNicolas Vasilache     if (!targetStrides)
114030e6033bSNicolas Vasilache       return failure();
114130e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
114230e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
114330e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
114430e6033bSNicolas Vasilache         }))
11453145427dSRiver Riddle       return failure();
11465c0c51a9SNicolas Vasilache 
11472230bf99SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
11485c0c51a9SNicolas Vasilache 
11495c0c51a9SNicolas Vasilache     // Create descriptor.
11505c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11513a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11525c0c51a9SNicolas Vasilache     // Set allocated ptr.
1153e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11545c0c51a9SNicolas Vasilache     allocated =
11555c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11565c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11575c0c51a9SNicolas Vasilache     // Set aligned ptr.
1158e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11595c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11605c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11615c0c51a9SNicolas Vasilache     // Fill offset 0.
11625c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11635c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11645c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11655c0c51a9SNicolas Vasilache 
11665c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11675c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11685c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11695c0c51a9SNicolas Vasilache       auto sizeAttr =
11705c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11715c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11725c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
117330e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
117430e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11755c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11765c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11775c0c51a9SNicolas Vasilache     }
11785c0c51a9SNicolas Vasilache 
1179563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11803145427dSRiver Riddle     return success();
11815c0c51a9SNicolas Vasilache   }
11825c0c51a9SNicolas Vasilache };
11835c0c51a9SNicolas Vasilache 
11848345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11858345b86dSNicolas Vasilache /// sequence of:
1186060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1187060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1188060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1189060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1190060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11918345b86dSNicolas Vasilache template <typename ConcreteOp>
1192563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11938345b86dSNicolas Vasilache public:
1194563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1195060c9dd1Saartbik                                     bool enableIndexOpt)
1196563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1197060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11988345b86dSNicolas Vasilache 
11998345b86dSNicolas Vasilache   LogicalResult
1200563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
12018345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12028345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1203b2c79c50SNicolas Vasilache 
1204b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1205b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12068345b86dSNicolas Vasilache       return failure();
12075f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12085f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12095f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1210563879b6SRahul Joshi                                        xferOp->getContext()))
12118345b86dSNicolas Vasilache       return failure();
121226c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
121326c8f908SThomas Raoux     if (!memRefType)
121426c8f908SThomas Raoux       return failure();
12152bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
121626c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
121730e6033bSNicolas Vasilache     if (!strides)
12182bf491c7SBenjamin Kramer       return failure();
12198345b86dSNicolas Vasilache 
1220563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1221563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1222563879b6SRahul Joshi     };
12238345b86dSNicolas Vasilache 
1224563879b6SRahul Joshi     Location loc = xferOp->getLoc();
12258345b86dSNicolas Vasilache 
122668330ee0SThomas Raoux     if (auto memrefVectorElementType =
122726c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
122868330ee0SThomas Raoux       // Memref has vector element type.
122968330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
123068330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
123168330ee0SThomas Raoux         return failure();
12320de60b55SThomas Raoux #ifndef NDEBUG
123368330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
123468330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
123568330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
123668330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
123768330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
123868330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
123968330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
124068330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
124168330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
124268330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
124368330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
124468330ee0SThomas Raoux                "result shape.");
12450de60b55SThomas Raoux #endif // ifndef NDEBUG
124668330ee0SThomas Raoux     }
124768330ee0SThomas Raoux 
12488345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1249a57def30SAart Bik     VectorType vtp = xferOp.getVectorType();
1250563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
125126c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1252a57def30SAart Bik     Value vectorDataPtr =
1253a57def30SAart Bik         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
12548345b86dSNicolas Vasilache 
12551870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1256563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1257563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1258563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12591870e787SNicolas Vasilache 
12608345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12618345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12628345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12638345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1264060c9dd1Saartbik     //
1265060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1266060c9dd1Saartbik     //       dimensions here.
1267bd30a796SAlex Zinenko     unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
1268060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12690c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
127026c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1271563879b6SRahul Joshi     Value mask = buildVectorComparison(
1272563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12738345b86dSNicolas Vasilache 
12748345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1275563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1276dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12778345b86dSNicolas Vasilache   }
1278060c9dd1Saartbik 
1279060c9dd1Saartbik private:
1280060c9dd1Saartbik   const bool enableIndexOptimizations;
12818345b86dSNicolas Vasilache };
12828345b86dSNicolas Vasilache 
1283563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1284d9b500d3SAart Bik public:
1285563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1286d9b500d3SAart Bik 
1287d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1288d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1289d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1290d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1291d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1292d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1293d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1294d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1295d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1296d9b500d3SAart Bik   //
12979db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1298d9b500d3SAart Bik   //
12993145427dSRiver Riddle   LogicalResult
1300563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1301d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
13022d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1303d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1304d9b500d3SAart Bik 
1305dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
13063145427dSRiver Riddle       return failure();
1307d9b500d3SAart Bik 
1308b8880f5fSAart Bik     // Make sure element type has runtime support.
1309b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1310d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1311d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1312d9b500d3SAart Bik     Operation *printer;
1313b8880f5fSAart Bik     if (eltType.isF32()) {
1314563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1315b8880f5fSAart Bik     } else if (eltType.isF64()) {
1316563879b6SRahul Joshi       printer = getPrintDouble(printOp);
131754759cefSAart Bik     } else if (eltType.isIndex()) {
1318563879b6SRahul Joshi       printer = getPrintU64(printOp);
1319b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1320b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1321b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1322b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1323b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1324b8880f5fSAart Bik       if (intTy.isUnsigned()) {
132554759cefSAart Bik         if (width <= 64) {
1326b8880f5fSAart Bik           if (width < 64)
1327b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1328563879b6SRahul Joshi           printer = getPrintU64(printOp);
1329b8880f5fSAart Bik         } else {
13303145427dSRiver Riddle           return failure();
1331b8880f5fSAart Bik         }
1332b8880f5fSAart Bik       } else {
1333b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
133454759cefSAart Bik         if (width <= 64) {
1335b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1336b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1337b8880f5fSAart Bik           if (width == 1)
133854759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
133954759cefSAart Bik           else if (width < 64)
1340b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1341563879b6SRahul Joshi           printer = getPrintI64(printOp);
1342b8880f5fSAart Bik         } else {
1343b8880f5fSAart Bik           return failure();
1344b8880f5fSAart Bik         }
1345b8880f5fSAart Bik       }
1346b8880f5fSAart Bik     } else {
1347b8880f5fSAart Bik       return failure();
1348b8880f5fSAart Bik     }
1349d9b500d3SAart Bik 
1350d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1351b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1352563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1353b8880f5fSAart Bik               conversion);
1354563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1355563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13563145427dSRiver Riddle     return success();
1357d9b500d3SAart Bik   }
1358d9b500d3SAart Bik 
1359d9b500d3SAart Bik private:
1360b8880f5fSAart Bik   enum class PrintConversion {
136130e6033bSNicolas Vasilache     // clang-format off
1362b8880f5fSAart Bik     None,
1363b8880f5fSAart Bik     ZeroExt64,
1364b8880f5fSAart Bik     SignExt64
136530e6033bSNicolas Vasilache     // clang-format on
1366b8880f5fSAart Bik   };
1367b8880f5fSAart Bik 
1368d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1369e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1370b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1371d9b500d3SAart Bik     Location loc = op->getLoc();
1372d9b500d3SAart Bik     if (rank == 0) {
1373b8880f5fSAart Bik       switch (conversion) {
1374b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1375b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
13762230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1377b8880f5fSAart Bik         break;
1378b8880f5fSAart Bik       case PrintConversion::SignExt64:
1379b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
13802230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1381b8880f5fSAart Bik         break;
1382b8880f5fSAart Bik       case PrintConversion::None:
1383b8880f5fSAart Bik         break;
1384c9eeeb38Saartbik       }
1385d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1386d9b500d3SAart Bik       return;
1387d9b500d3SAart Bik     }
1388d9b500d3SAart Bik 
1389d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1390d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1391d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1392d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1393d9b500d3SAart Bik       auto reducedType =
1394d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1395dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1396d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1397dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1398dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1399b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1400b8880f5fSAart Bik                 conversion);
1401d9b500d3SAart Bik       if (d != dim - 1)
1402d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1403d9b500d3SAart Bik     }
1404d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1405d9b500d3SAart Bik   }
1406d9b500d3SAart Bik 
1407d9b500d3SAart Bik   // Helper to emit a call.
1408d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1409d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
141008e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1411d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1412d9b500d3SAart Bik   }
1413d9b500d3SAart Bik 
1414d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14155446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
1416c69c9e0fSAlex Zinenko                              ArrayRef<Type> params) {
1417d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1418d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1419d9b500d3SAart Bik     if (func)
1420d9b500d3SAart Bik       return func;
1421d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1422d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1423d9b500d3SAart Bik         op->getLoc(), name,
14247ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
14257ed9cfc7SAlex Zinenko                                     params));
1426d9b500d3SAart Bik   }
1427d9b500d3SAart Bik 
1428d9b500d3SAart Bik   // Helpers for method names.
1429e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
14302230bf99SAlex Zinenko     return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
1431e52414b1Saartbik   }
1432b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
14332230bf99SAlex Zinenko     return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
1434b8880f5fSAart Bik   }
1435d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
1436dd5165a9SAlex Zinenko     return getPrint(op, "printF32", Float32Type::get(op->getContext()));
1437d9b500d3SAart Bik   }
1438d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
1439dd5165a9SAlex Zinenko     return getPrint(op, "printF64", Float64Type::get(op->getContext()));
1440d9b500d3SAart Bik   }
1441d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
144254759cefSAart Bik     return getPrint(op, "printOpen", {});
1443d9b500d3SAart Bik   }
1444d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
144554759cefSAart Bik     return getPrint(op, "printClose", {});
1446d9b500d3SAart Bik   }
1447d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
144854759cefSAart Bik     return getPrint(op, "printComma", {});
1449d9b500d3SAart Bik   }
1450d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
145154759cefSAart Bik     return getPrint(op, "printNewline", {});
1452d9b500d3SAart Bik   }
1453d9b500d3SAart Bik };
1454d9b500d3SAart Bik 
1455334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1456c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1457c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1458c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1459334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
146065678d93SNicolas Vasilache public:
1461b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1462b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1463b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1464b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1465b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1466b99bd771SRiver Riddle   }
146765678d93SNicolas Vasilache 
1468334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
146965678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
14709eb3e564SChris Lattner     auto dstType = op.getType();
147165678d93SNicolas Vasilache 
147265678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
147365678d93SNicolas Vasilache 
147465678d93SNicolas Vasilache     int64_t offset =
147565678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
147665678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
147765678d93SNicolas Vasilache     int64_t stride =
147865678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
147965678d93SNicolas Vasilache 
148065678d93SNicolas Vasilache     auto loc = op.getLoc();
148165678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
148235b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1483c3c95b9cSaartbik 
1484c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1485c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1486c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1487c3c95b9cSaartbik       offsets.reserve(size);
1488c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1489c3c95b9cSaartbik            off += stride)
1490c3c95b9cSaartbik         offsets.push_back(off);
1491c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1492c3c95b9cSaartbik                                              op.vector(),
1493c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1494c3c95b9cSaartbik       return success();
1495c3c95b9cSaartbik     }
1496c3c95b9cSaartbik 
1497c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
149865678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
149965678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
150065678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
150165678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
150265678d93SNicolas Vasilache          off += stride, ++idx) {
1503c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1504c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1505c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
150665678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
150765678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
150865678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
150965678d93SNicolas Vasilache     }
1510c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15113145427dSRiver Riddle     return success();
151265678d93SNicolas Vasilache   }
151365678d93SNicolas Vasilache };
151465678d93SNicolas Vasilache 
1515df186507SBenjamin Kramer } // namespace
1516df186507SBenjamin Kramer 
15175c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15185c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1519ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1520060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
152165678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15228345b86dSNicolas Vasilache   // clang-format off
1523681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1524681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15252d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1526c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1527ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1528563879b6SRahul Joshi       converter, reassociateFPReductions);
1529060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1530060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1531060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1532563879b6SRahul Joshi       converter, enableIndexOptimizations);
15338345b86dSNicolas Vasilache   patterns
1534cf5c517cSDiego Caballero       .insert<VectorBitCastOpConversion,
1535cf5c517cSDiego Caballero               VectorShuffleOpConversion,
15368345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15378345b86dSNicolas Vasilache               VectorExtractOpConversion,
15388345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15398345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15408345b86dSNicolas Vasilache               VectorInsertOpConversion,
15418345b86dSNicolas Vasilache               VectorPrintOpConversion,
154219dbb230Saartbik               VectorTypeCastOpConversion,
154339379916Saartbik               VectorMaskedLoadOpConversion,
154439379916Saartbik               VectorMaskedStoreOpConversion,
154519dbb230Saartbik               VectorGatherOpConversion,
1546e8dcf5f8Saartbik               VectorScatterOpConversion,
1547e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1548563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15498345b86dSNicolas Vasilache   // clang-format on
15505c0c51a9SNicolas Vasilache }
15515c0c51a9SNicolas Vasilache 
155263b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
155363b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1554563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1555563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
155663b683a8SNicolas Vasilache }
1557