15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
14626c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
14726c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1485f9e0466SNicolas Vasilache   if (!elementTy)
1495f9e0466SNicolas Vasilache     return failure();
1505f9e0466SNicolas Vasilache 
151b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
152b2ab375dSAlex Zinenko   // stop depending on translation.
15387a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15487a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
155c69c9e0fSAlex Zinenko               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
1565f9e0466SNicolas Vasilache   return success();
1575f9e0466SNicolas Vasilache }
1585f9e0466SNicolas Vasilache 
159e8dcf5f8Saartbik // Helper that returns the base address of a memref.
160b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
161e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16219dbb230Saartbik   // Inspect stride and offset structure.
16319dbb230Saartbik   //
16419dbb230Saartbik   // TODO: flat memory only for now, generalize
16519dbb230Saartbik   //
16619dbb230Saartbik   int64_t offset;
16719dbb230Saartbik   SmallVector<int64_t, 4> strides;
16819dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
16919dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17019dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17119dbb230Saartbik     return failure();
172e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
173e8dcf5f8Saartbik   return success();
174e8dcf5f8Saartbik }
17519dbb230Saartbik 
176*a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector.
177b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
178b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
179b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
180b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
181e8dcf5f8Saartbik   Value base;
182e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
183e8dcf5f8Saartbik     return failure();
1843a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
1857ed9cfc7SAlex Zinenko   auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
1861485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
18719dbb230Saartbik   return success();
18819dbb230Saartbik }
18919dbb230Saartbik 
190*a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer
191*a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be
192*a57def30SAart Bik // used when source/dst memrefs are not on address space 0.
193*a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
194*a57def30SAart Bik                          Value ptr, MemRefType memRefType, Type vt) {
195*a57def30SAart Bik   auto pType =
196*a57def30SAart Bik       LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
197*a57def30SAart Bik   if (memRefType.getMemorySpace() == 0)
198*a57def30SAart Bik     return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
199*a57def30SAart Bik   return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
200*a57def30SAart Bik }
201*a57def30SAart Bik 
2025f9e0466SNicolas Vasilache static LogicalResult
2035f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2045f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2055f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2065f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
207affbc0cdSNicolas Vasilache   unsigned align;
20826c8f908SThomas Raoux   if (failed(getMemRefAlignment(
20926c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
210affbc0cdSNicolas Vasilache     return failure();
211affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2125f9e0466SNicolas Vasilache   return success();
2135f9e0466SNicolas Vasilache }
2145f9e0466SNicolas Vasilache 
2155f9e0466SNicolas Vasilache static LogicalResult
2165f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2175f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2185f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2195f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2205f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2215f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2225f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2235f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2245f9e0466SNicolas Vasilache 
2255f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2265f9e0466SNicolas Vasilache   if (!vecTy)
2275f9e0466SNicolas Vasilache     return failure();
2285f9e0466SNicolas Vasilache 
2295f9e0466SNicolas Vasilache   unsigned align;
23026c8f908SThomas Raoux   if (failed(getMemRefAlignment(
23126c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2325f9e0466SNicolas Vasilache     return failure();
2335f9e0466SNicolas Vasilache 
2345f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2355f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2365f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2375f9e0466SNicolas Vasilache   return success();
2385f9e0466SNicolas Vasilache }
2395f9e0466SNicolas Vasilache 
2405f9e0466SNicolas Vasilache static LogicalResult
2415f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2425f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2435f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2445f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
245affbc0cdSNicolas Vasilache   unsigned align;
24626c8f908SThomas Raoux   if (failed(getMemRefAlignment(
24726c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
248affbc0cdSNicolas Vasilache     return failure();
2492d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
250affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
251affbc0cdSNicolas Vasilache                                              align);
2525f9e0466SNicolas Vasilache   return success();
2535f9e0466SNicolas Vasilache }
2545f9e0466SNicolas Vasilache 
2555f9e0466SNicolas Vasilache static LogicalResult
2565f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2575f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2585f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2595f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2605f9e0466SNicolas Vasilache   unsigned align;
26126c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26226c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2635f9e0466SNicolas Vasilache     return failure();
2645f9e0466SNicolas Vasilache 
2652d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2665f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2675f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2685f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2695f9e0466SNicolas Vasilache   return success();
2705f9e0466SNicolas Vasilache }
2715f9e0466SNicolas Vasilache 
2722d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2732d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2742d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2755f9e0466SNicolas Vasilache }
2765f9e0466SNicolas Vasilache 
2772d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2782d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2792d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2805f9e0466SNicolas Vasilache }
2815f9e0466SNicolas Vasilache 
28290c01357SBenjamin Kramer namespace {
283e83b7b99Saartbik 
28463b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
28563b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
286563879b6SRahul Joshi class VectorMatmulOpConversion
287563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
28863b683a8SNicolas Vasilache public:
289563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
29063b683a8SNicolas Vasilache 
2913145427dSRiver Riddle   LogicalResult
292563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
29363b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
2942d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
29563b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
296563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
297563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
298563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
2993145427dSRiver Riddle     return success();
30063b683a8SNicolas Vasilache   }
30163b683a8SNicolas Vasilache };
30263b683a8SNicolas Vasilache 
303c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
304c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
305563879b6SRahul Joshi class VectorFlatTransposeOpConversion
306563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
307c295a65dSaartbik public:
308563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
309c295a65dSaartbik 
310c295a65dSaartbik   LogicalResult
311563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
312c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3132d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
314c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
315dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
316c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
317c295a65dSaartbik     return success();
318c295a65dSaartbik   }
319c295a65dSaartbik };
320c295a65dSaartbik 
32139379916Saartbik /// Conversion pattern for a vector.maskedload.
322563879b6SRahul Joshi class VectorMaskedLoadOpConversion
323563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
32439379916Saartbik public:
325563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
32639379916Saartbik 
32739379916Saartbik   LogicalResult
328563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
32939379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
330563879b6SRahul Joshi     auto loc = load->getLoc();
33139379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
332*a57def30SAart Bik     MemRefType memRefType = load.getMemRefType();
33339379916Saartbik 
33439379916Saartbik     // Resolve alignment.
33539379916Saartbik     unsigned align;
336*a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
33739379916Saartbik       return failure();
33839379916Saartbik 
339*a57def30SAart Bik     // Resolve address.
340dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
341*a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
342*a57def30SAart Bik                                                adaptor.indices(), rewriter);
343*a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
34439379916Saartbik 
34539379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
34639379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
34739379916Saartbik         rewriter.getI32IntegerAttr(align));
34839379916Saartbik     return success();
34939379916Saartbik   }
35039379916Saartbik };
35139379916Saartbik 
35239379916Saartbik /// Conversion pattern for a vector.maskedstore.
353563879b6SRahul Joshi class VectorMaskedStoreOpConversion
354563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
35539379916Saartbik public:
356563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
35739379916Saartbik 
35839379916Saartbik   LogicalResult
359563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
36039379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
361563879b6SRahul Joshi     auto loc = store->getLoc();
36239379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
363*a57def30SAart Bik     MemRefType memRefType = store.getMemRefType();
36439379916Saartbik 
36539379916Saartbik     // Resolve alignment.
36639379916Saartbik     unsigned align;
367*a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
36839379916Saartbik       return failure();
36939379916Saartbik 
370*a57def30SAart Bik     // Resolve address.
371dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
372*a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
373*a57def30SAart Bik                                                adaptor.indices(), rewriter);
374*a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
37539379916Saartbik 
37639379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
37739379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
37839379916Saartbik         rewriter.getI32IntegerAttr(align));
37939379916Saartbik     return success();
38039379916Saartbik   }
38139379916Saartbik };
38239379916Saartbik 
38319dbb230Saartbik /// Conversion pattern for a vector.gather.
384563879b6SRahul Joshi class VectorGatherOpConversion
385563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
38619dbb230Saartbik public:
387563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
38819dbb230Saartbik 
38919dbb230Saartbik   LogicalResult
390563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
39119dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
392563879b6SRahul Joshi     auto loc = gather->getLoc();
39319dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
39419dbb230Saartbik 
39519dbb230Saartbik     // Resolve alignment.
39619dbb230Saartbik     unsigned align;
39726c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
39826c8f908SThomas Raoux                                   align)))
39919dbb230Saartbik       return failure();
40019dbb230Saartbik 
40119dbb230Saartbik     // Get index ptrs.
40219dbb230Saartbik     VectorType vType = gather.getResultVectorType();
40319dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
40419dbb230Saartbik     Value ptrs;
405e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
406e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
40719dbb230Saartbik       return failure();
40819dbb230Saartbik 
40919dbb230Saartbik     // Replace with the gather intrinsic.
41019dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
411dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4120c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
41319dbb230Saartbik     return success();
41419dbb230Saartbik   }
41519dbb230Saartbik };
41619dbb230Saartbik 
41719dbb230Saartbik /// Conversion pattern for a vector.scatter.
418563879b6SRahul Joshi class VectorScatterOpConversion
419563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
42019dbb230Saartbik public:
421563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
42219dbb230Saartbik 
42319dbb230Saartbik   LogicalResult
424563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
42519dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
426563879b6SRahul Joshi     auto loc = scatter->getLoc();
42719dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
42819dbb230Saartbik 
42919dbb230Saartbik     // Resolve alignment.
43019dbb230Saartbik     unsigned align;
43126c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
43226c8f908SThomas Raoux                                   align)))
43319dbb230Saartbik       return failure();
43419dbb230Saartbik 
43519dbb230Saartbik     // Get index ptrs.
43619dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
43719dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
43819dbb230Saartbik     Value ptrs;
439e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
440e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
44119dbb230Saartbik       return failure();
44219dbb230Saartbik 
44319dbb230Saartbik     // Replace with the scatter intrinsic.
44419dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
44519dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
44619dbb230Saartbik         rewriter.getI32IntegerAttr(align));
44719dbb230Saartbik     return success();
44819dbb230Saartbik   }
44919dbb230Saartbik };
45019dbb230Saartbik 
451e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
452563879b6SRahul Joshi class VectorExpandLoadOpConversion
453563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
454e8dcf5f8Saartbik public:
455563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
456e8dcf5f8Saartbik 
457e8dcf5f8Saartbik   LogicalResult
458563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
459e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
460563879b6SRahul Joshi     auto loc = expand->getLoc();
461e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
462*a57def30SAart Bik     MemRefType memRefType = expand.getMemRefType();
463e8dcf5f8Saartbik 
464*a57def30SAart Bik     // Resolve address.
465*a57def30SAart Bik     auto vtype = typeConverter->convertType(expand.getResultVectorType());
466*a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
467*a57def30SAart Bik                                            adaptor.indices(), rewriter);
468e8dcf5f8Saartbik 
469e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
470*a57def30SAart Bik         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
471e8dcf5f8Saartbik     return success();
472e8dcf5f8Saartbik   }
473e8dcf5f8Saartbik };
474e8dcf5f8Saartbik 
475e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
476563879b6SRahul Joshi class VectorCompressStoreOpConversion
477563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
478e8dcf5f8Saartbik public:
479563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
480e8dcf5f8Saartbik 
481e8dcf5f8Saartbik   LogicalResult
482563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
483e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
484563879b6SRahul Joshi     auto loc = compress->getLoc();
485e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
486*a57def30SAart Bik     MemRefType memRefType = compress.getMemRefType();
487e8dcf5f8Saartbik 
488*a57def30SAart Bik     // Resolve address.
489*a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
490*a57def30SAart Bik                                            adaptor.indices(), rewriter);
491e8dcf5f8Saartbik 
492e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
493563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
494e8dcf5f8Saartbik     return success();
495e8dcf5f8Saartbik   }
496e8dcf5f8Saartbik };
497e8dcf5f8Saartbik 
49819dbb230Saartbik /// Conversion pattern for all vector reductions.
499563879b6SRahul Joshi class VectorReductionOpConversion
500563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
501e83b7b99Saartbik public:
502563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
503060c9dd1Saartbik                                        bool reassociateFPRed)
504563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
505060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
506e83b7b99Saartbik 
5073145427dSRiver Riddle   LogicalResult
508563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
509e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
510e83b7b99Saartbik     auto kind = reductionOp.kind();
511e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
512dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
513e9628955SAart Bik     if (eltType.isIntOrIndex()) {
514e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
515e83b7b99Saartbik       if (kind == "add")
516322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
517563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
518e83b7b99Saartbik       else if (kind == "mul")
519322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
520563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
521e9628955SAart Bik       else if (kind == "min" &&
522e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
523322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
524563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
525e83b7b99Saartbik       else if (kind == "min")
526322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
527563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
528e9628955SAart Bik       else if (kind == "max" &&
529e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
530322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
531563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
532e83b7b99Saartbik       else if (kind == "max")
533322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
534563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
535e83b7b99Saartbik       else if (kind == "and")
536322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
537563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
538e83b7b99Saartbik       else if (kind == "or")
539322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
540563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
541e83b7b99Saartbik       else if (kind == "xor")
542322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
543563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
544e83b7b99Saartbik       else
5453145427dSRiver Riddle         return failure();
5463145427dSRiver Riddle       return success();
547dcec2ca5SChristian Sigg     }
548e83b7b99Saartbik 
549dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
550dcec2ca5SChristian Sigg       return failure();
551dcec2ca5SChristian Sigg 
552e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
553e83b7b99Saartbik     if (kind == "add") {
5540d924700Saartbik       // Optional accumulator (or zero).
5550d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5560d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
557563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5580d924700Saartbik                                             rewriter.getZeroAttr(eltType));
559322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
560563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
561ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
562e83b7b99Saartbik     } else if (kind == "mul") {
5630d924700Saartbik       // Optional accumulator (or one).
5640d924700Saartbik       Value acc = operands.size() > 1
5650d924700Saartbik                       ? operands[1]
5660d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
567563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
5680d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
569322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
570563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
571ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
572e83b7b99Saartbik     } else if (kind == "min")
573563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
574563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
575e83b7b99Saartbik     else if (kind == "max")
576563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
577563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
578e83b7b99Saartbik     else
5793145427dSRiver Riddle       return failure();
5803145427dSRiver Riddle     return success();
581e83b7b99Saartbik   }
582ceb1b327Saartbik 
583ceb1b327Saartbik private:
584ceb1b327Saartbik   const bool reassociateFPReductions;
585e83b7b99Saartbik };
586e83b7b99Saartbik 
587060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
588563879b6SRahul Joshi class VectorCreateMaskOpConversion
589563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
590060c9dd1Saartbik public:
591563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
592060c9dd1Saartbik                                         bool enableIndexOpt)
593563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
594060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
595060c9dd1Saartbik 
596060c9dd1Saartbik   LogicalResult
597563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
598060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
5999eb3e564SChris Lattner     auto dstType = op.getType();
600060c9dd1Saartbik     int64_t rank = dstType.getRank();
601060c9dd1Saartbik     if (rank == 1) {
602060c9dd1Saartbik       rewriter.replaceOp(
603060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
604060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
605060c9dd1Saartbik       return success();
606060c9dd1Saartbik     }
607060c9dd1Saartbik     return failure();
608060c9dd1Saartbik   }
609060c9dd1Saartbik 
610060c9dd1Saartbik private:
611060c9dd1Saartbik   const bool enableIndexOptimizations;
612060c9dd1Saartbik };
613060c9dd1Saartbik 
614563879b6SRahul Joshi class VectorShuffleOpConversion
615563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6161c81adf3SAart Bik public:
617563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6181c81adf3SAart Bik 
6193145427dSRiver Riddle   LogicalResult
620563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6211c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
622563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6232d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6241c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6251c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6261c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
627dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6281c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6291c81adf3SAart Bik 
6301c81adf3SAart Bik     // Bail if result type cannot be lowered.
6311c81adf3SAart Bik     if (!llvmType)
6323145427dSRiver Riddle       return failure();
6331c81adf3SAart Bik 
6341c81adf3SAart Bik     // Get rank and dimension sizes.
6351c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6361c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6371c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6381c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6391c81adf3SAart Bik 
6401c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6411c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6421c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
643563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6441c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
645563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6463145427dSRiver Riddle       return success();
647b36aaeafSAart Bik     }
648b36aaeafSAart Bik 
6491c81adf3SAart Bik     // For all other cases, insert the individual values individually.
650e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6511c81adf3SAart Bik     int64_t insPos = 0;
6521c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6531c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
654e62a6956SRiver Riddle       Value value = adaptor.v1();
6551c81adf3SAart Bik       if (extPos >= v1Dim) {
6561c81adf3SAart Bik         extPos -= v1Dim;
6571c81adf3SAart Bik         value = adaptor.v2();
658b36aaeafSAart Bik       }
659dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
660dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
661dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
6620f04384dSAlex Zinenko                          llvmType, rank, insPos++);
6631c81adf3SAart Bik     }
664563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
6653145427dSRiver Riddle     return success();
666b36aaeafSAart Bik   }
667b36aaeafSAart Bik };
668b36aaeafSAart Bik 
669563879b6SRahul Joshi class VectorExtractElementOpConversion
670563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
671cd5dab8aSAart Bik public:
672563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
673563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
674cd5dab8aSAart Bik 
6753145427dSRiver Riddle   LogicalResult
676563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
677563879b6SRahul Joshi                   ArrayRef<Value> operands,
678cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
6792d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
680cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
681dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
682cd5dab8aSAart Bik 
683cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
684cd5dab8aSAart Bik     if (!llvmType)
6853145427dSRiver Riddle       return failure();
686cd5dab8aSAart Bik 
687cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
688563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
6893145427dSRiver Riddle     return success();
690cd5dab8aSAart Bik   }
691cd5dab8aSAart Bik };
692cd5dab8aSAart Bik 
693563879b6SRahul Joshi class VectorExtractOpConversion
694563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
6955c0c51a9SNicolas Vasilache public:
696563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
6975c0c51a9SNicolas Vasilache 
6983145427dSRiver Riddle   LogicalResult
699563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7005c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
701563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7022d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7039826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7042bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
705dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7065c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7079826fe5cSAart Bik 
7089826fe5cSAart Bik     // Bail if result type cannot be lowered.
7099826fe5cSAart Bik     if (!llvmResultType)
7103145427dSRiver Riddle       return failure();
7119826fe5cSAart Bik 
7125c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7135c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
714e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7155c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
716563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7173145427dSRiver Riddle       return success();
7185c0c51a9SNicolas Vasilache     }
7195c0c51a9SNicolas Vasilache 
7209826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
721563879b6SRahul Joshi     auto *context = extractOp->getContext();
722e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7235c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7245c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7259826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7265c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7275c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7285c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
729dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7305c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7315c0c51a9SNicolas Vasilache     }
7325c0c51a9SNicolas Vasilache 
7335c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7345c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7352230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
7361d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7375c0c51a9SNicolas Vasilache     extracted =
7385c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
739563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7405c0c51a9SNicolas Vasilache 
7413145427dSRiver Riddle     return success();
7425c0c51a9SNicolas Vasilache   }
7435c0c51a9SNicolas Vasilache };
7445c0c51a9SNicolas Vasilache 
745681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
746681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
747681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
748681f929fSNicolas Vasilache ///
749681f929fSNicolas Vasilache /// Example:
750681f929fSNicolas Vasilache /// ```
751681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
752681f929fSNicolas Vasilache /// ```
753681f929fSNicolas Vasilache /// is converted to:
754681f929fSNicolas Vasilache /// ```
7553bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
756dd5165a9SAlex Zinenko ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
757dd5165a9SAlex Zinenko ///    -> !llvm."<8 x f32>">
758681f929fSNicolas Vasilache /// ```
759563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
760681f929fSNicolas Vasilache public:
761563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
762681f929fSNicolas Vasilache 
7633145427dSRiver Riddle   LogicalResult
764563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
765681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7662d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
767681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
768681f929fSNicolas Vasilache     if (vType.getRank() != 1)
7693145427dSRiver Riddle       return failure();
770563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
7713bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
7723145427dSRiver Riddle     return success();
773681f929fSNicolas Vasilache   }
774681f929fSNicolas Vasilache };
775681f929fSNicolas Vasilache 
776563879b6SRahul Joshi class VectorInsertElementOpConversion
777563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
778cd5dab8aSAart Bik public:
779563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
780cd5dab8aSAart Bik 
7813145427dSRiver Riddle   LogicalResult
782563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
783cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7842d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
785cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
786dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
787cd5dab8aSAart Bik 
788cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
789cd5dab8aSAart Bik     if (!llvmType)
7903145427dSRiver Riddle       return failure();
791cd5dab8aSAart Bik 
792cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
793563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
794563879b6SRahul Joshi         adaptor.position());
7953145427dSRiver Riddle     return success();
796cd5dab8aSAart Bik   }
797cd5dab8aSAart Bik };
798cd5dab8aSAart Bik 
799563879b6SRahul Joshi class VectorInsertOpConversion
800563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8019826fe5cSAart Bik public:
802563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8039826fe5cSAart Bik 
8043145427dSRiver Riddle   LogicalResult
805563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8069826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
807563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8082d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8099826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8109826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
811dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8129826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8139826fe5cSAart Bik 
8149826fe5cSAart Bik     // Bail if result type cannot be lowered.
8159826fe5cSAart Bik     if (!llvmResultType)
8163145427dSRiver Riddle       return failure();
8179826fe5cSAart Bik 
8189826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8199826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
820e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8219826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8229826fe5cSAart Bik           positionArrayAttr);
823563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8243145427dSRiver Riddle       return success();
8259826fe5cSAart Bik     }
8269826fe5cSAart Bik 
8279826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
828563879b6SRahul Joshi     auto *context = insertOp->getContext();
829e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8309826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8319826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8329826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8339826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8349826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8359826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8369826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8379826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
838dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8399826fe5cSAart Bik           nMinusOnePositionAttrs);
8409826fe5cSAart Bik     }
8419826fe5cSAart Bik 
8429826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8432230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
8441d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
845e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
846dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8470f04384dSAlex Zinenko         adaptor.source(), constant);
8489826fe5cSAart Bik 
8499826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8509826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8519826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8529826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8539826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8549826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8559826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8569826fe5cSAart Bik     }
8579826fe5cSAart Bik 
858563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8593145427dSRiver Riddle     return success();
8609826fe5cSAart Bik   }
8619826fe5cSAart Bik };
8629826fe5cSAart Bik 
863681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
864681f929fSNicolas Vasilache ///
865681f929fSNicolas Vasilache /// Example:
866681f929fSNicolas Vasilache /// ```
867681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
868681f929fSNicolas Vasilache /// ```
869681f929fSNicolas Vasilache /// is rewritten into:
870681f929fSNicolas Vasilache /// ```
871681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
872681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
873681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
874681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
875681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
876681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
877681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
878681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
879681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
880681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
881681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
882681f929fSNicolas Vasilache ///  // %r3 holds the final value.
883681f929fSNicolas Vasilache /// ```
884681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
885681f929fSNicolas Vasilache public:
886681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
887681f929fSNicolas Vasilache 
8883145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
889681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
890681f929fSNicolas Vasilache     auto vType = op.getVectorType();
891681f929fSNicolas Vasilache     if (vType.getRank() < 2)
8923145427dSRiver Riddle       return failure();
893681f929fSNicolas Vasilache 
894681f929fSNicolas Vasilache     auto loc = op.getLoc();
895681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
896681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
897681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
898681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
899681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
900681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
901681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
902681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
903681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
904681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
905681f929fSNicolas Vasilache     }
906681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9073145427dSRiver Riddle     return success();
908681f929fSNicolas Vasilache   }
909681f929fSNicolas Vasilache };
910681f929fSNicolas Vasilache 
9112d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9122d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9132d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9142d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9152d515e49SNicolas Vasilache // rank.
9162d515e49SNicolas Vasilache //
9172d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9182d515e49SNicolas Vasilache // have different ranks. In this case:
9192d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9202d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9212d515e49SNicolas Vasilache //   destination subvector
9222d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9232d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9242d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9252d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9262d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9272d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9282d515e49SNicolas Vasilache public:
9292d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9302d515e49SNicolas Vasilache 
9313145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9322d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9332d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9342d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9352d515e49SNicolas Vasilache 
9362d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9373145427dSRiver Riddle       return failure();
9382d515e49SNicolas Vasilache 
9392d515e49SNicolas Vasilache     auto loc = op.getLoc();
9402d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9412d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9422d515e49SNicolas Vasilache     if (rankDiff == 0)
9433145427dSRiver Riddle       return failure();
9442d515e49SNicolas Vasilache 
9452d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9462d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9472d515e49SNicolas Vasilache     // on it.
9482d515e49SNicolas Vasilache     Value extracted =
9492d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9502d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
951dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9522d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9532d515e49SNicolas Vasilache     // ranks.
9542d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9552d515e49SNicolas Vasilache         loc, op.source(), extracted,
9562d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
957c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9582d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9592d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9602d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
961dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
9623145427dSRiver Riddle     return success();
9632d515e49SNicolas Vasilache   }
9642d515e49SNicolas Vasilache };
9652d515e49SNicolas Vasilache 
9662d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9672d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
9682d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9692d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9702d515e49SNicolas Vasilache //   destination subvector
9712d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9722d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9732d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9742d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9752d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
9762d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9772d515e49SNicolas Vasilache public:
978b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
979b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
980b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
981b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
982b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
983b99bd771SRiver Riddle   }
9842d515e49SNicolas Vasilache 
9853145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9862d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9872d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9882d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9892d515e49SNicolas Vasilache 
9902d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9913145427dSRiver Riddle       return failure();
9922d515e49SNicolas Vasilache 
9932d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9942d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9952d515e49SNicolas Vasilache     if (rankDiff != 0)
9963145427dSRiver Riddle       return failure();
9972d515e49SNicolas Vasilache 
9982d515e49SNicolas Vasilache     if (srcType == dstType) {
9992d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10003145427dSRiver Riddle       return success();
10012d515e49SNicolas Vasilache     }
10022d515e49SNicolas Vasilache 
10032d515e49SNicolas Vasilache     int64_t offset =
10042d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10052d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10062d515e49SNicolas Vasilache     int64_t stride =
10072d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10082d515e49SNicolas Vasilache 
10092d515e49SNicolas Vasilache     auto loc = op.getLoc();
10102d515e49SNicolas Vasilache     Value res = op.dest();
10112d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10122d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10132d515e49SNicolas Vasilache          off += stride, ++idx) {
10142d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10152d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10162d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10172d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10182d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10192d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10202d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10212d515e49SNicolas Vasilache         // smaller rank.
1022bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10232d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10242d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10252d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10262d515e49SNicolas Vasilache       }
10272d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10282d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10292d515e49SNicolas Vasilache     }
10302d515e49SNicolas Vasilache 
10312d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10323145427dSRiver Riddle     return success();
10332d515e49SNicolas Vasilache   }
10342d515e49SNicolas Vasilache };
10352d515e49SNicolas Vasilache 
103630e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
103730e6033bSNicolas Vasilache /// static layout.
103830e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
103930e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10402bf491c7SBenjamin Kramer   int64_t offset;
104130e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
104230e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
104330e6033bSNicolas Vasilache     return None;
104430e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
104530e6033bSNicolas Vasilache     return None;
104630e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
104730e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
104830e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
104930e6033bSNicolas Vasilache     return strides;
105030e6033bSNicolas Vasilache 
105130e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
105230e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
105330e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
105430e6033bSNicolas Vasilache   // layout.
10552bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10562bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
105730e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
105830e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
105930e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
106030e6033bSNicolas Vasilache       return None;
106130e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
106230e6033bSNicolas Vasilache       return None;
10632bf491c7SBenjamin Kramer   }
106430e6033bSNicolas Vasilache   return strides;
10652bf491c7SBenjamin Kramer }
10662bf491c7SBenjamin Kramer 
1067563879b6SRahul Joshi class VectorTypeCastOpConversion
1068563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
10695c0c51a9SNicolas Vasilache public:
1070563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
10715c0c51a9SNicolas Vasilache 
10723145427dSRiver Riddle   LogicalResult
1073563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
10745c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1075563879b6SRahul Joshi     auto loc = castOp->getLoc();
10765c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
10772bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
10789eb3e564SChris Lattner     MemRefType targetMemRefType = castOp.getType();
10795c0c51a9SNicolas Vasilache 
10805c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
10815c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
10825c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
10833145427dSRiver Riddle       return failure();
10845c0c51a9SNicolas Vasilache 
10855c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
10868de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
10878de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
10883145427dSRiver Riddle       return failure();
10895c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
10905c0c51a9SNicolas Vasilache 
1091dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
10928de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
10938de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
10943145427dSRiver Riddle       return failure();
10955c0c51a9SNicolas Vasilache 
109630e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
109730e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
109830e6033bSNicolas Vasilache     if (!sourceStrides)
109930e6033bSNicolas Vasilache       return failure();
110030e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
110130e6033bSNicolas Vasilache     if (!targetStrides)
110230e6033bSNicolas Vasilache       return failure();
110330e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
110430e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
110530e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
110630e6033bSNicolas Vasilache         }))
11073145427dSRiver Riddle       return failure();
11085c0c51a9SNicolas Vasilache 
11092230bf99SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
11105c0c51a9SNicolas Vasilache 
11115c0c51a9SNicolas Vasilache     // Create descriptor.
11125c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11133a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11145c0c51a9SNicolas Vasilache     // Set allocated ptr.
1115e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11165c0c51a9SNicolas Vasilache     allocated =
11175c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11185c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11195c0c51a9SNicolas Vasilache     // Set aligned ptr.
1120e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11215c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11225c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11235c0c51a9SNicolas Vasilache     // Fill offset 0.
11245c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11255c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11265c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11275c0c51a9SNicolas Vasilache 
11285c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11295c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11305c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11315c0c51a9SNicolas Vasilache       auto sizeAttr =
11325c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11335c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11345c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
113530e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
113630e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11375c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11385c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11395c0c51a9SNicolas Vasilache     }
11405c0c51a9SNicolas Vasilache 
1141563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11423145427dSRiver Riddle     return success();
11435c0c51a9SNicolas Vasilache   }
11445c0c51a9SNicolas Vasilache };
11455c0c51a9SNicolas Vasilache 
11468345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11478345b86dSNicolas Vasilache /// sequence of:
1148060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1149060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1150060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1151060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1152060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11538345b86dSNicolas Vasilache template <typename ConcreteOp>
1154563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11558345b86dSNicolas Vasilache public:
1156563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1157060c9dd1Saartbik                                     bool enableIndexOpt)
1158563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1159060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11608345b86dSNicolas Vasilache 
11618345b86dSNicolas Vasilache   LogicalResult
1162563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
11638345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11648345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1165b2c79c50SNicolas Vasilache 
1166b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1167b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
11688345b86dSNicolas Vasilache       return failure();
11695f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
11705f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
11715f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1172563879b6SRahul Joshi                                        xferOp->getContext()))
11738345b86dSNicolas Vasilache       return failure();
117426c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
117526c8f908SThomas Raoux     if (!memRefType)
117626c8f908SThomas Raoux       return failure();
11772bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
117826c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
117930e6033bSNicolas Vasilache     if (!strides)
11802bf491c7SBenjamin Kramer       return failure();
11818345b86dSNicolas Vasilache 
1182563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1183563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1184563879b6SRahul Joshi     };
11858345b86dSNicolas Vasilache 
1186563879b6SRahul Joshi     Location loc = xferOp->getLoc();
11878345b86dSNicolas Vasilache 
118868330ee0SThomas Raoux     if (auto memrefVectorElementType =
118926c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
119068330ee0SThomas Raoux       // Memref has vector element type.
119168330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
119268330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
119368330ee0SThomas Raoux         return failure();
11940de60b55SThomas Raoux #ifndef NDEBUG
119568330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
119668330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
119768330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
119868330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
119968330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
120068330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
120168330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
120268330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
120368330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
120468330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
120568330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
120668330ee0SThomas Raoux                "result shape.");
12070de60b55SThomas Raoux #endif // ifndef NDEBUG
120868330ee0SThomas Raoux     }
120968330ee0SThomas Raoux 
12108345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1211*a57def30SAart Bik     VectorType vtp = xferOp.getVectorType();
1212563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
121326c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1214*a57def30SAart Bik     Value vectorDataPtr =
1215*a57def30SAart Bik         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
12168345b86dSNicolas Vasilache 
12171870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1218563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1219563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1220563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12211870e787SNicolas Vasilache 
12228345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12238345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12248345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12258345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1226060c9dd1Saartbik     //
1227060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1228060c9dd1Saartbik     //       dimensions here.
1229*a57def30SAart Bik     unsigned vecWidth = vtp.getNumElements();
1230060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12310c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
123226c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1233563879b6SRahul Joshi     Value mask = buildVectorComparison(
1234563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12358345b86dSNicolas Vasilache 
12368345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1237563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1238dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12398345b86dSNicolas Vasilache   }
1240060c9dd1Saartbik 
1241060c9dd1Saartbik private:
1242060c9dd1Saartbik   const bool enableIndexOptimizations;
12438345b86dSNicolas Vasilache };
12448345b86dSNicolas Vasilache 
1245563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1246d9b500d3SAart Bik public:
1247563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1248d9b500d3SAart Bik 
1249d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1250d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1251d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1252d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1253d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1254d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1255d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1256d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1257d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1258d9b500d3SAart Bik   //
12599db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1260d9b500d3SAart Bik   //
12613145427dSRiver Riddle   LogicalResult
1262563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1263d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
12642d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1265d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1266d9b500d3SAart Bik 
1267dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
12683145427dSRiver Riddle       return failure();
1269d9b500d3SAart Bik 
1270b8880f5fSAart Bik     // Make sure element type has runtime support.
1271b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1272d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1273d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1274d9b500d3SAart Bik     Operation *printer;
1275b8880f5fSAart Bik     if (eltType.isF32()) {
1276563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1277b8880f5fSAart Bik     } else if (eltType.isF64()) {
1278563879b6SRahul Joshi       printer = getPrintDouble(printOp);
127954759cefSAart Bik     } else if (eltType.isIndex()) {
1280563879b6SRahul Joshi       printer = getPrintU64(printOp);
1281b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1282b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1283b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1284b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1285b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1286b8880f5fSAart Bik       if (intTy.isUnsigned()) {
128754759cefSAart Bik         if (width <= 64) {
1288b8880f5fSAart Bik           if (width < 64)
1289b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1290563879b6SRahul Joshi           printer = getPrintU64(printOp);
1291b8880f5fSAart Bik         } else {
12923145427dSRiver Riddle           return failure();
1293b8880f5fSAart Bik         }
1294b8880f5fSAart Bik       } else {
1295b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
129654759cefSAart Bik         if (width <= 64) {
1297b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1298b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1299b8880f5fSAart Bik           if (width == 1)
130054759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
130154759cefSAart Bik           else if (width < 64)
1302b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1303563879b6SRahul Joshi           printer = getPrintI64(printOp);
1304b8880f5fSAart Bik         } else {
1305b8880f5fSAart Bik           return failure();
1306b8880f5fSAart Bik         }
1307b8880f5fSAart Bik       }
1308b8880f5fSAart Bik     } else {
1309b8880f5fSAart Bik       return failure();
1310b8880f5fSAart Bik     }
1311d9b500d3SAart Bik 
1312d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1313b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1314563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1315b8880f5fSAart Bik               conversion);
1316563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1317563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13183145427dSRiver Riddle     return success();
1319d9b500d3SAart Bik   }
1320d9b500d3SAart Bik 
1321d9b500d3SAart Bik private:
1322b8880f5fSAart Bik   enum class PrintConversion {
132330e6033bSNicolas Vasilache     // clang-format off
1324b8880f5fSAart Bik     None,
1325b8880f5fSAart Bik     ZeroExt64,
1326b8880f5fSAart Bik     SignExt64
132730e6033bSNicolas Vasilache     // clang-format on
1328b8880f5fSAart Bik   };
1329b8880f5fSAart Bik 
1330d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1331e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1332b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1333d9b500d3SAart Bik     Location loc = op->getLoc();
1334d9b500d3SAart Bik     if (rank == 0) {
1335b8880f5fSAart Bik       switch (conversion) {
1336b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1337b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
13382230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1339b8880f5fSAart Bik         break;
1340b8880f5fSAart Bik       case PrintConversion::SignExt64:
1341b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
13422230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1343b8880f5fSAart Bik         break;
1344b8880f5fSAart Bik       case PrintConversion::None:
1345b8880f5fSAart Bik         break;
1346c9eeeb38Saartbik       }
1347d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1348d9b500d3SAart Bik       return;
1349d9b500d3SAart Bik     }
1350d9b500d3SAart Bik 
1351d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1352d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1353d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1354d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1355d9b500d3SAart Bik       auto reducedType =
1356d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1357dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1358d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1359dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1360dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1361b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1362b8880f5fSAart Bik                 conversion);
1363d9b500d3SAart Bik       if (d != dim - 1)
1364d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1365d9b500d3SAart Bik     }
1366d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1367d9b500d3SAart Bik   }
1368d9b500d3SAart Bik 
1369d9b500d3SAart Bik   // Helper to emit a call.
1370d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1371d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
137208e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1373d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1374d9b500d3SAart Bik   }
1375d9b500d3SAart Bik 
1376d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
13775446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
1378c69c9e0fSAlex Zinenko                              ArrayRef<Type> params) {
1379d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1380d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1381d9b500d3SAart Bik     if (func)
1382d9b500d3SAart Bik       return func;
1383d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1384d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1385d9b500d3SAart Bik         op->getLoc(), name,
13867ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
13877ed9cfc7SAlex Zinenko                                     params));
1388d9b500d3SAart Bik   }
1389d9b500d3SAart Bik 
1390d9b500d3SAart Bik   // Helpers for method names.
1391e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
13922230bf99SAlex Zinenko     return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
1393e52414b1Saartbik   }
1394b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
13952230bf99SAlex Zinenko     return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
1396b8880f5fSAart Bik   }
1397d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
1398dd5165a9SAlex Zinenko     return getPrint(op, "printF32", Float32Type::get(op->getContext()));
1399d9b500d3SAart Bik   }
1400d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
1401dd5165a9SAlex Zinenko     return getPrint(op, "printF64", Float64Type::get(op->getContext()));
1402d9b500d3SAart Bik   }
1403d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
140454759cefSAart Bik     return getPrint(op, "printOpen", {});
1405d9b500d3SAart Bik   }
1406d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
140754759cefSAart Bik     return getPrint(op, "printClose", {});
1408d9b500d3SAart Bik   }
1409d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
141054759cefSAart Bik     return getPrint(op, "printComma", {});
1411d9b500d3SAart Bik   }
1412d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
141354759cefSAart Bik     return getPrint(op, "printNewline", {});
1414d9b500d3SAart Bik   }
1415d9b500d3SAart Bik };
1416d9b500d3SAart Bik 
1417334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1418c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1419c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1420c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1421334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
142265678d93SNicolas Vasilache public:
1423b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1424b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1425b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1426b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1427b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1428b99bd771SRiver Riddle   }
142965678d93SNicolas Vasilache 
1430334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
143165678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
14329eb3e564SChris Lattner     auto dstType = op.getType();
143365678d93SNicolas Vasilache 
143465678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
143565678d93SNicolas Vasilache 
143665678d93SNicolas Vasilache     int64_t offset =
143765678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
143865678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
143965678d93SNicolas Vasilache     int64_t stride =
144065678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
144165678d93SNicolas Vasilache 
144265678d93SNicolas Vasilache     auto loc = op.getLoc();
144365678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
144435b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1445c3c95b9cSaartbik 
1446c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1447c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1448c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1449c3c95b9cSaartbik       offsets.reserve(size);
1450c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1451c3c95b9cSaartbik            off += stride)
1452c3c95b9cSaartbik         offsets.push_back(off);
1453c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1454c3c95b9cSaartbik                                              op.vector(),
1455c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1456c3c95b9cSaartbik       return success();
1457c3c95b9cSaartbik     }
1458c3c95b9cSaartbik 
1459c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
146065678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
146165678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
146265678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
146365678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
146465678d93SNicolas Vasilache          off += stride, ++idx) {
1465c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1466c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1467c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
146865678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
146965678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
147065678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
147165678d93SNicolas Vasilache     }
1472c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14733145427dSRiver Riddle     return success();
147465678d93SNicolas Vasilache   }
147565678d93SNicolas Vasilache };
147665678d93SNicolas Vasilache 
1477df186507SBenjamin Kramer } // namespace
1478df186507SBenjamin Kramer 
14795c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
14805c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1481ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1482060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
148365678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
14848345b86dSNicolas Vasilache   // clang-format off
1485681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1486681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
14872d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1488c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1489ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1490563879b6SRahul Joshi       converter, reassociateFPReductions);
1491060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1492060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1493060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1494563879b6SRahul Joshi       converter, enableIndexOptimizations);
14958345b86dSNicolas Vasilache   patterns
1496ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
14978345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
14988345b86dSNicolas Vasilache               VectorExtractOpConversion,
14998345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15008345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15018345b86dSNicolas Vasilache               VectorInsertOpConversion,
15028345b86dSNicolas Vasilache               VectorPrintOpConversion,
150319dbb230Saartbik               VectorTypeCastOpConversion,
150439379916Saartbik               VectorMaskedLoadOpConversion,
150539379916Saartbik               VectorMaskedStoreOpConversion,
150619dbb230Saartbik               VectorGatherOpConversion,
1507e8dcf5f8Saartbik               VectorScatterOpConversion,
1508e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1509563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15108345b86dSNicolas Vasilache   // clang-format on
15115c0c51a9SNicolas Vasilache }
15125c0c51a9SNicolas Vasilache 
151363b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
151463b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1515563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1516563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
151763b683a8SNicolas Vasilache }
1518