15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
14626c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
14726c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1485f9e0466SNicolas Vasilache   if (!elementTy)
1495f9e0466SNicolas Vasilache     return failure();
1505f9e0466SNicolas Vasilache 
151b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
152b2ab375dSAlex Zinenko   // stop depending on translation.
15387a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15487a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
155c69c9e0fSAlex Zinenko               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
1565f9e0466SNicolas Vasilache   return success();
1575f9e0466SNicolas Vasilache }
1585f9e0466SNicolas Vasilache 
159e8dcf5f8Saartbik // Helper that returns the base address of a memref.
160b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
161e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16219dbb230Saartbik   // Inspect stride and offset structure.
16319dbb230Saartbik   //
16419dbb230Saartbik   // TODO: flat memory only for now, generalize
16519dbb230Saartbik   //
16619dbb230Saartbik   int64_t offset;
16719dbb230Saartbik   SmallVector<int64_t, 4> strides;
16819dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
16919dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17019dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17119dbb230Saartbik     return failure();
172e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
173e8dcf5f8Saartbik   return success();
174e8dcf5f8Saartbik }
17519dbb230Saartbik 
176e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
177b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
178b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
179b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
180e8dcf5f8Saartbik   Value base;
181e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
182e8dcf5f8Saartbik     return failure();
1833a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
184e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
185e8dcf5f8Saartbik   return success();
186e8dcf5f8Saartbik }
187e8dcf5f8Saartbik 
18839379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
189b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
190b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
191b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
19239379916Saartbik   Value base;
19339379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
19439379916Saartbik     return failure();
195c69c9e0fSAlex Zinenko   auto pType = LLVM::LLVMPointerType::get(type);
19639379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
19739379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
19839379916Saartbik   return success();
19939379916Saartbik }
20039379916Saartbik 
201e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
202e8dcf5f8Saartbik // vector.
203b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
204b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
205b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
206b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
207e8dcf5f8Saartbik   Value base;
208e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
209e8dcf5f8Saartbik     return failure();
2103a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
2117ed9cfc7SAlex Zinenko   auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
2121485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
21319dbb230Saartbik   return success();
21419dbb230Saartbik }
21519dbb230Saartbik 
2165f9e0466SNicolas Vasilache static LogicalResult
2175f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2185f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2195f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2205f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
221affbc0cdSNicolas Vasilache   unsigned align;
22226c8f908SThomas Raoux   if (failed(getMemRefAlignment(
22326c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
224affbc0cdSNicolas Vasilache     return failure();
225affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2265f9e0466SNicolas Vasilache   return success();
2275f9e0466SNicolas Vasilache }
2285f9e0466SNicolas Vasilache 
2295f9e0466SNicolas Vasilache static LogicalResult
2305f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2315f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2325f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2335f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2345f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2355f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2365f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2375f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2385f9e0466SNicolas Vasilache 
2395f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2405f9e0466SNicolas Vasilache   if (!vecTy)
2415f9e0466SNicolas Vasilache     return failure();
2425f9e0466SNicolas Vasilache 
2435f9e0466SNicolas Vasilache   unsigned align;
24426c8f908SThomas Raoux   if (failed(getMemRefAlignment(
24526c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2465f9e0466SNicolas Vasilache     return failure();
2475f9e0466SNicolas Vasilache 
2485f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2495f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2505f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2515f9e0466SNicolas Vasilache   return success();
2525f9e0466SNicolas Vasilache }
2535f9e0466SNicolas Vasilache 
2545f9e0466SNicolas Vasilache static LogicalResult
2555f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2565f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2575f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2585f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
259affbc0cdSNicolas Vasilache   unsigned align;
26026c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26126c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
262affbc0cdSNicolas Vasilache     return failure();
2632d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
264affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
265affbc0cdSNicolas Vasilache                                              align);
2665f9e0466SNicolas Vasilache   return success();
2675f9e0466SNicolas Vasilache }
2685f9e0466SNicolas Vasilache 
2695f9e0466SNicolas Vasilache static LogicalResult
2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2715f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2725f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2735f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2745f9e0466SNicolas Vasilache   unsigned align;
27526c8f908SThomas Raoux   if (failed(getMemRefAlignment(
27626c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2775f9e0466SNicolas Vasilache     return failure();
2785f9e0466SNicolas Vasilache 
2792d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2805f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2815f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2825f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2835f9e0466SNicolas Vasilache   return success();
2845f9e0466SNicolas Vasilache }
2855f9e0466SNicolas Vasilache 
2862d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2872d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2882d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2895f9e0466SNicolas Vasilache }
2905f9e0466SNicolas Vasilache 
2912d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2922d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2932d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2945f9e0466SNicolas Vasilache }
2955f9e0466SNicolas Vasilache 
29690c01357SBenjamin Kramer namespace {
297e83b7b99Saartbik 
29863b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
29963b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
300563879b6SRahul Joshi class VectorMatmulOpConversion
301563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
30263b683a8SNicolas Vasilache public:
303563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
30463b683a8SNicolas Vasilache 
3053145427dSRiver Riddle   LogicalResult
306563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
30763b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3082d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
30963b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
310563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
311563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
312563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3133145427dSRiver Riddle     return success();
31463b683a8SNicolas Vasilache   }
31563b683a8SNicolas Vasilache };
31663b683a8SNicolas Vasilache 
317c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
318c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
319563879b6SRahul Joshi class VectorFlatTransposeOpConversion
320563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
321c295a65dSaartbik public:
322563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
323c295a65dSaartbik 
324c295a65dSaartbik   LogicalResult
325563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
326c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3272d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
328c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
329dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
330c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
331c295a65dSaartbik     return success();
332c295a65dSaartbik   }
333c295a65dSaartbik };
334c295a65dSaartbik 
33539379916Saartbik /// Conversion pattern for a vector.maskedload.
336563879b6SRahul Joshi class VectorMaskedLoadOpConversion
337563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
33839379916Saartbik public:
339563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
34039379916Saartbik 
34139379916Saartbik   LogicalResult
342563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
34339379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
344563879b6SRahul Joshi     auto loc = load->getLoc();
34539379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
34639379916Saartbik 
34739379916Saartbik     // Resolve alignment.
34839379916Saartbik     unsigned align;
34926c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
35026c8f908SThomas Raoux                                   align)))
35139379916Saartbik       return failure();
35239379916Saartbik 
353dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
35439379916Saartbik     Value ptr;
35539379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
35639379916Saartbik                           vtype, ptr)))
35739379916Saartbik       return failure();
35839379916Saartbik 
35939379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
36039379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
36139379916Saartbik         rewriter.getI32IntegerAttr(align));
36239379916Saartbik     return success();
36339379916Saartbik   }
36439379916Saartbik };
36539379916Saartbik 
36639379916Saartbik /// Conversion pattern for a vector.maskedstore.
367563879b6SRahul Joshi class VectorMaskedStoreOpConversion
368563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
36939379916Saartbik public:
370563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
37139379916Saartbik 
37239379916Saartbik   LogicalResult
373563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
37439379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
375563879b6SRahul Joshi     auto loc = store->getLoc();
37639379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
37739379916Saartbik 
37839379916Saartbik     // Resolve alignment.
37939379916Saartbik     unsigned align;
38026c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
38126c8f908SThomas Raoux                                   align)))
38239379916Saartbik       return failure();
38339379916Saartbik 
384dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
38539379916Saartbik     Value ptr;
38639379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
38739379916Saartbik                           vtype, ptr)))
38839379916Saartbik       return failure();
38939379916Saartbik 
39039379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
39139379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
39239379916Saartbik         rewriter.getI32IntegerAttr(align));
39339379916Saartbik     return success();
39439379916Saartbik   }
39539379916Saartbik };
39639379916Saartbik 
39719dbb230Saartbik /// Conversion pattern for a vector.gather.
398563879b6SRahul Joshi class VectorGatherOpConversion
399563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
40019dbb230Saartbik public:
401563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
40219dbb230Saartbik 
40319dbb230Saartbik   LogicalResult
404563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
40519dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
406563879b6SRahul Joshi     auto loc = gather->getLoc();
40719dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
40819dbb230Saartbik 
40919dbb230Saartbik     // Resolve alignment.
41019dbb230Saartbik     unsigned align;
41126c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
41226c8f908SThomas Raoux                                   align)))
41319dbb230Saartbik       return failure();
41419dbb230Saartbik 
41519dbb230Saartbik     // Get index ptrs.
41619dbb230Saartbik     VectorType vType = gather.getResultVectorType();
41719dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
41819dbb230Saartbik     Value ptrs;
419e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
420e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
42119dbb230Saartbik       return failure();
42219dbb230Saartbik 
42319dbb230Saartbik     // Replace with the gather intrinsic.
42419dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
425dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4260c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
42719dbb230Saartbik     return success();
42819dbb230Saartbik   }
42919dbb230Saartbik };
43019dbb230Saartbik 
43119dbb230Saartbik /// Conversion pattern for a vector.scatter.
432563879b6SRahul Joshi class VectorScatterOpConversion
433563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
43419dbb230Saartbik public:
435563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
43619dbb230Saartbik 
43719dbb230Saartbik   LogicalResult
438563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
43919dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
440563879b6SRahul Joshi     auto loc = scatter->getLoc();
44119dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
44219dbb230Saartbik 
44319dbb230Saartbik     // Resolve alignment.
44419dbb230Saartbik     unsigned align;
44526c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
44626c8f908SThomas Raoux                                   align)))
44719dbb230Saartbik       return failure();
44819dbb230Saartbik 
44919dbb230Saartbik     // Get index ptrs.
45019dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
45119dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
45219dbb230Saartbik     Value ptrs;
453e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
454e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
45519dbb230Saartbik       return failure();
45619dbb230Saartbik 
45719dbb230Saartbik     // Replace with the scatter intrinsic.
45819dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
45919dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
46019dbb230Saartbik         rewriter.getI32IntegerAttr(align));
46119dbb230Saartbik     return success();
46219dbb230Saartbik   }
46319dbb230Saartbik };
46419dbb230Saartbik 
465e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
466563879b6SRahul Joshi class VectorExpandLoadOpConversion
467563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
468e8dcf5f8Saartbik public:
469563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
470e8dcf5f8Saartbik 
471e8dcf5f8Saartbik   LogicalResult
472563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
473e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
474563879b6SRahul Joshi     auto loc = expand->getLoc();
475e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
476e8dcf5f8Saartbik 
477e8dcf5f8Saartbik     Value ptr;
478e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
479e8dcf5f8Saartbik                           ptr)))
480e8dcf5f8Saartbik       return failure();
481e8dcf5f8Saartbik 
482e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
483e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
484563879b6SRahul Joshi         expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
485e8dcf5f8Saartbik         adaptor.pass_thru());
486e8dcf5f8Saartbik     return success();
487e8dcf5f8Saartbik   }
488e8dcf5f8Saartbik };
489e8dcf5f8Saartbik 
490e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
491563879b6SRahul Joshi class VectorCompressStoreOpConversion
492563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
493e8dcf5f8Saartbik public:
494563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
495e8dcf5f8Saartbik 
496e8dcf5f8Saartbik   LogicalResult
497563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
498e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
499563879b6SRahul Joshi     auto loc = compress->getLoc();
500e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
501e8dcf5f8Saartbik 
502e8dcf5f8Saartbik     Value ptr;
503e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
504e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
505e8dcf5f8Saartbik       return failure();
506e8dcf5f8Saartbik 
507e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
508563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
509e8dcf5f8Saartbik     return success();
510e8dcf5f8Saartbik   }
511e8dcf5f8Saartbik };
512e8dcf5f8Saartbik 
51319dbb230Saartbik /// Conversion pattern for all vector reductions.
514563879b6SRahul Joshi class VectorReductionOpConversion
515563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
516e83b7b99Saartbik public:
517563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
518060c9dd1Saartbik                                        bool reassociateFPRed)
519563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
520060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
521e83b7b99Saartbik 
5223145427dSRiver Riddle   LogicalResult
523563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
524e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
525e83b7b99Saartbik     auto kind = reductionOp.kind();
526e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
527dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
528e9628955SAart Bik     if (eltType.isIntOrIndex()) {
529e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
530e83b7b99Saartbik       if (kind == "add")
531322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
532563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
533e83b7b99Saartbik       else if (kind == "mul")
534322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
535563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
536e9628955SAart Bik       else if (kind == "min" &&
537e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
538322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
539563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
540e83b7b99Saartbik       else if (kind == "min")
541322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
542563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
543e9628955SAart Bik       else if (kind == "max" &&
544e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
545322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
546563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
547e83b7b99Saartbik       else if (kind == "max")
548322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
549563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
550e83b7b99Saartbik       else if (kind == "and")
551322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
552563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
553e83b7b99Saartbik       else if (kind == "or")
554322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
555563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
556e83b7b99Saartbik       else if (kind == "xor")
557322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
558563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
559e83b7b99Saartbik       else
5603145427dSRiver Riddle         return failure();
5613145427dSRiver Riddle       return success();
562dcec2ca5SChristian Sigg     }
563e83b7b99Saartbik 
564dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
565dcec2ca5SChristian Sigg       return failure();
566dcec2ca5SChristian Sigg 
567e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
568e83b7b99Saartbik     if (kind == "add") {
5690d924700Saartbik       // Optional accumulator (or zero).
5700d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5710d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
572563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5730d924700Saartbik                                             rewriter.getZeroAttr(eltType));
574322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
575563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
576ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
577e83b7b99Saartbik     } else if (kind == "mul") {
5780d924700Saartbik       // Optional accumulator (or one).
5790d924700Saartbik       Value acc = operands.size() > 1
5800d924700Saartbik                       ? operands[1]
5810d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
582563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
5830d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
584322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
585563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
586ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
587e83b7b99Saartbik     } else if (kind == "min")
588563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
589563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
590e83b7b99Saartbik     else if (kind == "max")
591563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
592563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
593e83b7b99Saartbik     else
5943145427dSRiver Riddle       return failure();
5953145427dSRiver Riddle     return success();
596e83b7b99Saartbik   }
597ceb1b327Saartbik 
598ceb1b327Saartbik private:
599ceb1b327Saartbik   const bool reassociateFPReductions;
600e83b7b99Saartbik };
601e83b7b99Saartbik 
602060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
603563879b6SRahul Joshi class VectorCreateMaskOpConversion
604563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
605060c9dd1Saartbik public:
606563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
607060c9dd1Saartbik                                         bool enableIndexOpt)
608563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
609060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
610060c9dd1Saartbik 
611060c9dd1Saartbik   LogicalResult
612563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
613060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
6149eb3e564SChris Lattner     auto dstType = op.getType();
615060c9dd1Saartbik     int64_t rank = dstType.getRank();
616060c9dd1Saartbik     if (rank == 1) {
617060c9dd1Saartbik       rewriter.replaceOp(
618060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
619060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
620060c9dd1Saartbik       return success();
621060c9dd1Saartbik     }
622060c9dd1Saartbik     return failure();
623060c9dd1Saartbik   }
624060c9dd1Saartbik 
625060c9dd1Saartbik private:
626060c9dd1Saartbik   const bool enableIndexOptimizations;
627060c9dd1Saartbik };
628060c9dd1Saartbik 
629563879b6SRahul Joshi class VectorShuffleOpConversion
630563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6311c81adf3SAart Bik public:
632563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6331c81adf3SAart Bik 
6343145427dSRiver Riddle   LogicalResult
635563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6361c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
637563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6382d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6391c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6401c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6411c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
642dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6431c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6441c81adf3SAart Bik 
6451c81adf3SAart Bik     // Bail if result type cannot be lowered.
6461c81adf3SAart Bik     if (!llvmType)
6473145427dSRiver Riddle       return failure();
6481c81adf3SAart Bik 
6491c81adf3SAart Bik     // Get rank and dimension sizes.
6501c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6511c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6521c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6531c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6541c81adf3SAart Bik 
6551c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6561c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6571c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
658563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6591c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
660563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6613145427dSRiver Riddle       return success();
662b36aaeafSAart Bik     }
663b36aaeafSAart Bik 
6641c81adf3SAart Bik     // For all other cases, insert the individual values individually.
665e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6661c81adf3SAart Bik     int64_t insPos = 0;
6671c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6681c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
669e62a6956SRiver Riddle       Value value = adaptor.v1();
6701c81adf3SAart Bik       if (extPos >= v1Dim) {
6711c81adf3SAart Bik         extPos -= v1Dim;
6721c81adf3SAart Bik         value = adaptor.v2();
673b36aaeafSAart Bik       }
674dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
675dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
676dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
6770f04384dSAlex Zinenko                          llvmType, rank, insPos++);
6781c81adf3SAart Bik     }
679563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
6803145427dSRiver Riddle     return success();
681b36aaeafSAart Bik   }
682b36aaeafSAart Bik };
683b36aaeafSAart Bik 
684563879b6SRahul Joshi class VectorExtractElementOpConversion
685563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
686cd5dab8aSAart Bik public:
687563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
688563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
689cd5dab8aSAart Bik 
6903145427dSRiver Riddle   LogicalResult
691563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
692563879b6SRahul Joshi                   ArrayRef<Value> operands,
693cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
6942d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
695cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
696dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
697cd5dab8aSAart Bik 
698cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
699cd5dab8aSAart Bik     if (!llvmType)
7003145427dSRiver Riddle       return failure();
701cd5dab8aSAart Bik 
702cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
703563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
7043145427dSRiver Riddle     return success();
705cd5dab8aSAart Bik   }
706cd5dab8aSAart Bik };
707cd5dab8aSAart Bik 
708563879b6SRahul Joshi class VectorExtractOpConversion
709563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7105c0c51a9SNicolas Vasilache public:
711563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7125c0c51a9SNicolas Vasilache 
7133145427dSRiver Riddle   LogicalResult
714563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7155c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
716563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7172d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7189826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7192bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
720dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7215c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7229826fe5cSAart Bik 
7239826fe5cSAart Bik     // Bail if result type cannot be lowered.
7249826fe5cSAart Bik     if (!llvmResultType)
7253145427dSRiver Riddle       return failure();
7269826fe5cSAart Bik 
7275c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7285c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
729e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7305c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
731563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7323145427dSRiver Riddle       return success();
7335c0c51a9SNicolas Vasilache     }
7345c0c51a9SNicolas Vasilache 
7359826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
736563879b6SRahul Joshi     auto *context = extractOp->getContext();
737e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7385c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7395c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7409826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7415c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7425c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7435c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
744dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7455c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7465c0c51a9SNicolas Vasilache     }
7475c0c51a9SNicolas Vasilache 
7485c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7495c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
750*2230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
7511d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7525c0c51a9SNicolas Vasilache     extracted =
7535c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
754563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7555c0c51a9SNicolas Vasilache 
7563145427dSRiver Riddle     return success();
7575c0c51a9SNicolas Vasilache   }
7585c0c51a9SNicolas Vasilache };
7595c0c51a9SNicolas Vasilache 
760681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
761681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
762681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
763681f929fSNicolas Vasilache ///
764681f929fSNicolas Vasilache /// Example:
765681f929fSNicolas Vasilache /// ```
766681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
767681f929fSNicolas Vasilache /// ```
768681f929fSNicolas Vasilache /// is converted to:
769681f929fSNicolas Vasilache /// ```
7703bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
771681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
772681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
773681f929fSNicolas Vasilache /// ```
774563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
775681f929fSNicolas Vasilache public:
776563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
777681f929fSNicolas Vasilache 
7783145427dSRiver Riddle   LogicalResult
779563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
780681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7812d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
782681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
783681f929fSNicolas Vasilache     if (vType.getRank() != 1)
7843145427dSRiver Riddle       return failure();
785563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
7863bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
7873145427dSRiver Riddle     return success();
788681f929fSNicolas Vasilache   }
789681f929fSNicolas Vasilache };
790681f929fSNicolas Vasilache 
791563879b6SRahul Joshi class VectorInsertElementOpConversion
792563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
793cd5dab8aSAart Bik public:
794563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
795cd5dab8aSAart Bik 
7963145427dSRiver Riddle   LogicalResult
797563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
798cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7992d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
800cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
801dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
802cd5dab8aSAart Bik 
803cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
804cd5dab8aSAart Bik     if (!llvmType)
8053145427dSRiver Riddle       return failure();
806cd5dab8aSAart Bik 
807cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
808563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
809563879b6SRahul Joshi         adaptor.position());
8103145427dSRiver Riddle     return success();
811cd5dab8aSAart Bik   }
812cd5dab8aSAart Bik };
813cd5dab8aSAart Bik 
814563879b6SRahul Joshi class VectorInsertOpConversion
815563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8169826fe5cSAart Bik public:
817563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8189826fe5cSAart Bik 
8193145427dSRiver Riddle   LogicalResult
820563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8219826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
822563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8232d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8249826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8259826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
826dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8279826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8289826fe5cSAart Bik 
8299826fe5cSAart Bik     // Bail if result type cannot be lowered.
8309826fe5cSAart Bik     if (!llvmResultType)
8313145427dSRiver Riddle       return failure();
8329826fe5cSAart Bik 
8339826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8349826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
835e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8369826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8379826fe5cSAart Bik           positionArrayAttr);
838563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8393145427dSRiver Riddle       return success();
8409826fe5cSAart Bik     }
8419826fe5cSAart Bik 
8429826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
843563879b6SRahul Joshi     auto *context = insertOp->getContext();
844e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8459826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8469826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8479826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8489826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8499826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8509826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8519826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8529826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
853dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8549826fe5cSAart Bik           nMinusOnePositionAttrs);
8559826fe5cSAart Bik     }
8569826fe5cSAart Bik 
8579826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
858*2230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
8591d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
860e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
861dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8620f04384dSAlex Zinenko         adaptor.source(), constant);
8639826fe5cSAart Bik 
8649826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8659826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8669826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8679826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8689826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8699826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8709826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8719826fe5cSAart Bik     }
8729826fe5cSAart Bik 
873563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8743145427dSRiver Riddle     return success();
8759826fe5cSAart Bik   }
8769826fe5cSAart Bik };
8779826fe5cSAart Bik 
878681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
879681f929fSNicolas Vasilache ///
880681f929fSNicolas Vasilache /// Example:
881681f929fSNicolas Vasilache /// ```
882681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
883681f929fSNicolas Vasilache /// ```
884681f929fSNicolas Vasilache /// is rewritten into:
885681f929fSNicolas Vasilache /// ```
886681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
887681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
888681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
889681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
890681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
891681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
892681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
893681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
894681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
895681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
896681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
897681f929fSNicolas Vasilache ///  // %r3 holds the final value.
898681f929fSNicolas Vasilache /// ```
899681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
900681f929fSNicolas Vasilache public:
901681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
902681f929fSNicolas Vasilache 
9033145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
904681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
905681f929fSNicolas Vasilache     auto vType = op.getVectorType();
906681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9073145427dSRiver Riddle       return failure();
908681f929fSNicolas Vasilache 
909681f929fSNicolas Vasilache     auto loc = op.getLoc();
910681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
911681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
912681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
913681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
914681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
915681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
916681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
917681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
918681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
919681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
920681f929fSNicolas Vasilache     }
921681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9223145427dSRiver Riddle     return success();
923681f929fSNicolas Vasilache   }
924681f929fSNicolas Vasilache };
925681f929fSNicolas Vasilache 
9262d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9272d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9282d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9292d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9302d515e49SNicolas Vasilache // rank.
9312d515e49SNicolas Vasilache //
9322d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9332d515e49SNicolas Vasilache // have different ranks. In this case:
9342d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9352d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9362d515e49SNicolas Vasilache //   destination subvector
9372d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9382d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9392d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9402d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9412d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9422d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9432d515e49SNicolas Vasilache public:
9442d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9452d515e49SNicolas Vasilache 
9463145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9472d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9482d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9492d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9502d515e49SNicolas Vasilache 
9512d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9523145427dSRiver Riddle       return failure();
9532d515e49SNicolas Vasilache 
9542d515e49SNicolas Vasilache     auto loc = op.getLoc();
9552d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9562d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9572d515e49SNicolas Vasilache     if (rankDiff == 0)
9583145427dSRiver Riddle       return failure();
9592d515e49SNicolas Vasilache 
9602d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9612d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9622d515e49SNicolas Vasilache     // on it.
9632d515e49SNicolas Vasilache     Value extracted =
9642d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9652d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
966dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9672d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9682d515e49SNicolas Vasilache     // ranks.
9692d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9702d515e49SNicolas Vasilache         loc, op.source(), extracted,
9712d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
972c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9732d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9742d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9752d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
976dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
9773145427dSRiver Riddle     return success();
9782d515e49SNicolas Vasilache   }
9792d515e49SNicolas Vasilache };
9802d515e49SNicolas Vasilache 
9812d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9822d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
9832d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9842d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9852d515e49SNicolas Vasilache //   destination subvector
9862d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9872d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9882d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9892d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9902d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
9912d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9922d515e49SNicolas Vasilache public:
993b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
994b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
995b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
996b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
997b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
998b99bd771SRiver Riddle   }
9992d515e49SNicolas Vasilache 
10003145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10012d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10022d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10032d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10042d515e49SNicolas Vasilache 
10052d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10063145427dSRiver Riddle       return failure();
10072d515e49SNicolas Vasilache 
10082d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10092d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10102d515e49SNicolas Vasilache     if (rankDiff != 0)
10113145427dSRiver Riddle       return failure();
10122d515e49SNicolas Vasilache 
10132d515e49SNicolas Vasilache     if (srcType == dstType) {
10142d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10153145427dSRiver Riddle       return success();
10162d515e49SNicolas Vasilache     }
10172d515e49SNicolas Vasilache 
10182d515e49SNicolas Vasilache     int64_t offset =
10192d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10202d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10212d515e49SNicolas Vasilache     int64_t stride =
10222d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10232d515e49SNicolas Vasilache 
10242d515e49SNicolas Vasilache     auto loc = op.getLoc();
10252d515e49SNicolas Vasilache     Value res = op.dest();
10262d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10272d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10282d515e49SNicolas Vasilache          off += stride, ++idx) {
10292d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10302d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10312d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10322d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10332d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10342d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10352d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10362d515e49SNicolas Vasilache         // smaller rank.
1037bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10382d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10392d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10402d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10412d515e49SNicolas Vasilache       }
10422d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10432d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10442d515e49SNicolas Vasilache     }
10452d515e49SNicolas Vasilache 
10462d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10473145427dSRiver Riddle     return success();
10482d515e49SNicolas Vasilache   }
10492d515e49SNicolas Vasilache };
10502d515e49SNicolas Vasilache 
105130e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
105230e6033bSNicolas Vasilache /// static layout.
105330e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
105430e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10552bf491c7SBenjamin Kramer   int64_t offset;
105630e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
105730e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
105830e6033bSNicolas Vasilache     return None;
105930e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
106030e6033bSNicolas Vasilache     return None;
106130e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
106230e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
106330e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
106430e6033bSNicolas Vasilache     return strides;
106530e6033bSNicolas Vasilache 
106630e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
106730e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
106830e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
106930e6033bSNicolas Vasilache   // layout.
10702bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10712bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
107230e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
107330e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
107430e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
107530e6033bSNicolas Vasilache       return None;
107630e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
107730e6033bSNicolas Vasilache       return None;
10782bf491c7SBenjamin Kramer   }
107930e6033bSNicolas Vasilache   return strides;
10802bf491c7SBenjamin Kramer }
10812bf491c7SBenjamin Kramer 
1082563879b6SRahul Joshi class VectorTypeCastOpConversion
1083563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
10845c0c51a9SNicolas Vasilache public:
1085563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
10865c0c51a9SNicolas Vasilache 
10873145427dSRiver Riddle   LogicalResult
1088563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
10895c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1090563879b6SRahul Joshi     auto loc = castOp->getLoc();
10915c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
10922bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
10939eb3e564SChris Lattner     MemRefType targetMemRefType = castOp.getType();
10945c0c51a9SNicolas Vasilache 
10955c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
10965c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
10975c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
10983145427dSRiver Riddle       return failure();
10995c0c51a9SNicolas Vasilache 
11005c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11018de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
11028de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
11033145427dSRiver Riddle       return failure();
11045c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11055c0c51a9SNicolas Vasilache 
1106dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11078de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
11088de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
11093145427dSRiver Riddle       return failure();
11105c0c51a9SNicolas Vasilache 
111130e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
111230e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
111330e6033bSNicolas Vasilache     if (!sourceStrides)
111430e6033bSNicolas Vasilache       return failure();
111530e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
111630e6033bSNicolas Vasilache     if (!targetStrides)
111730e6033bSNicolas Vasilache       return failure();
111830e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
111930e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
112030e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
112130e6033bSNicolas Vasilache         }))
11223145427dSRiver Riddle       return failure();
11235c0c51a9SNicolas Vasilache 
1124*2230bf99SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
11255c0c51a9SNicolas Vasilache 
11265c0c51a9SNicolas Vasilache     // Create descriptor.
11275c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11283a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11295c0c51a9SNicolas Vasilache     // Set allocated ptr.
1130e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11315c0c51a9SNicolas Vasilache     allocated =
11325c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11335c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11345c0c51a9SNicolas Vasilache     // Set aligned ptr.
1135e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11365c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11375c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11385c0c51a9SNicolas Vasilache     // Fill offset 0.
11395c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11405c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11415c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11425c0c51a9SNicolas Vasilache 
11435c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11445c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11455c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11465c0c51a9SNicolas Vasilache       auto sizeAttr =
11475c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11485c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11495c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
115030e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
115130e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11525c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11535c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11545c0c51a9SNicolas Vasilache     }
11555c0c51a9SNicolas Vasilache 
1156563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11573145427dSRiver Riddle     return success();
11585c0c51a9SNicolas Vasilache   }
11595c0c51a9SNicolas Vasilache };
11605c0c51a9SNicolas Vasilache 
11618345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11628345b86dSNicolas Vasilache /// sequence of:
1163060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1164060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1165060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1166060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1167060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11688345b86dSNicolas Vasilache template <typename ConcreteOp>
1169563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11708345b86dSNicolas Vasilache public:
1171563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1172060c9dd1Saartbik                                     bool enableIndexOpt)
1173563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1174060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11758345b86dSNicolas Vasilache 
11768345b86dSNicolas Vasilache   LogicalResult
1177563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
11788345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11798345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1180b2c79c50SNicolas Vasilache 
1181b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1182b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
11838345b86dSNicolas Vasilache       return failure();
11845f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
11855f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
11865f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1187563879b6SRahul Joshi                                        xferOp->getContext()))
11888345b86dSNicolas Vasilache       return failure();
118926c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
119026c8f908SThomas Raoux     if (!memRefType)
119126c8f908SThomas Raoux       return failure();
11922bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
119326c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
119430e6033bSNicolas Vasilache     if (!strides)
11952bf491c7SBenjamin Kramer       return failure();
11968345b86dSNicolas Vasilache 
1197563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1198563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1199563879b6SRahul Joshi     };
12008345b86dSNicolas Vasilache 
1201563879b6SRahul Joshi     Location loc = xferOp->getLoc();
12028345b86dSNicolas Vasilache 
120368330ee0SThomas Raoux     if (auto memrefVectorElementType =
120426c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
120568330ee0SThomas Raoux       // Memref has vector element type.
120668330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
120768330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
120868330ee0SThomas Raoux         return failure();
12090de60b55SThomas Raoux #ifndef NDEBUG
121068330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
121168330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
121268330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
121368330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
121468330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
121568330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
121668330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
121768330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
121868330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
121968330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
122068330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
122168330ee0SThomas Raoux                "result shape.");
12220de60b55SThomas Raoux #endif // ifndef NDEBUG
122368330ee0SThomas Raoux     }
122468330ee0SThomas Raoux 
12258345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1226be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1227be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1228be16075bSWen-Heng (Jack) Chung     //    address space 0.
12298345b86dSNicolas Vasilache     // TODO: support alignment when possible.
1230563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
123126c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
12328de43b92SAlex Zinenko     auto vecTy = toLLVMTy(xferOp.getVectorType())
12338de43b92SAlex Zinenko                      .template cast<LLVM::LLVMFixedVectorType>();
1234be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1235be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
12368de43b92SAlex Zinenko       vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
12378de43b92SAlex Zinenko           loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
1238be16075bSWen-Heng (Jack) Chung     else
1239be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
12408de43b92SAlex Zinenko           loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
12418345b86dSNicolas Vasilache 
12421870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1243563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1244563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1245563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12461870e787SNicolas Vasilache 
12478345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12488345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12498345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12508345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1251060c9dd1Saartbik     //
1252060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1253060c9dd1Saartbik     //       dimensions here.
12548de43b92SAlex Zinenko     unsigned vecWidth = vecTy.getNumElements();
1255060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12560c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
125726c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1258563879b6SRahul Joshi     Value mask = buildVectorComparison(
1259563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12608345b86dSNicolas Vasilache 
12618345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1262563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1263dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12648345b86dSNicolas Vasilache   }
1265060c9dd1Saartbik 
1266060c9dd1Saartbik private:
1267060c9dd1Saartbik   const bool enableIndexOptimizations;
12688345b86dSNicolas Vasilache };
12698345b86dSNicolas Vasilache 
1270563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1271d9b500d3SAart Bik public:
1272563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1273d9b500d3SAart Bik 
1274d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1275d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1276d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1277d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1278d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1279d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1280d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1281d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1282d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1283d9b500d3SAart Bik   //
12849db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1285d9b500d3SAart Bik   //
12863145427dSRiver Riddle   LogicalResult
1287563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1288d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
12892d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1290d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1291d9b500d3SAart Bik 
1292dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
12933145427dSRiver Riddle       return failure();
1294d9b500d3SAart Bik 
1295b8880f5fSAart Bik     // Make sure element type has runtime support.
1296b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1297d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1298d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1299d9b500d3SAart Bik     Operation *printer;
1300b8880f5fSAart Bik     if (eltType.isF32()) {
1301563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1302b8880f5fSAart Bik     } else if (eltType.isF64()) {
1303563879b6SRahul Joshi       printer = getPrintDouble(printOp);
130454759cefSAart Bik     } else if (eltType.isIndex()) {
1305563879b6SRahul Joshi       printer = getPrintU64(printOp);
1306b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1307b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1308b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1309b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1310b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1311b8880f5fSAart Bik       if (intTy.isUnsigned()) {
131254759cefSAart Bik         if (width <= 64) {
1313b8880f5fSAart Bik           if (width < 64)
1314b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1315563879b6SRahul Joshi           printer = getPrintU64(printOp);
1316b8880f5fSAart Bik         } else {
13173145427dSRiver Riddle           return failure();
1318b8880f5fSAart Bik         }
1319b8880f5fSAart Bik       } else {
1320b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
132154759cefSAart Bik         if (width <= 64) {
1322b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1323b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1324b8880f5fSAart Bik           if (width == 1)
132554759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
132654759cefSAart Bik           else if (width < 64)
1327b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1328563879b6SRahul Joshi           printer = getPrintI64(printOp);
1329b8880f5fSAart Bik         } else {
1330b8880f5fSAart Bik           return failure();
1331b8880f5fSAart Bik         }
1332b8880f5fSAart Bik       }
1333b8880f5fSAart Bik     } else {
1334b8880f5fSAart Bik       return failure();
1335b8880f5fSAart Bik     }
1336d9b500d3SAart Bik 
1337d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1338b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1339563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1340b8880f5fSAart Bik               conversion);
1341563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1342563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13433145427dSRiver Riddle     return success();
1344d9b500d3SAart Bik   }
1345d9b500d3SAart Bik 
1346d9b500d3SAart Bik private:
1347b8880f5fSAart Bik   enum class PrintConversion {
134830e6033bSNicolas Vasilache     // clang-format off
1349b8880f5fSAart Bik     None,
1350b8880f5fSAart Bik     ZeroExt64,
1351b8880f5fSAart Bik     SignExt64
135230e6033bSNicolas Vasilache     // clang-format on
1353b8880f5fSAart Bik   };
1354b8880f5fSAart Bik 
1355d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1356e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1357b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1358d9b500d3SAart Bik     Location loc = op->getLoc();
1359d9b500d3SAart Bik     if (rank == 0) {
1360b8880f5fSAart Bik       switch (conversion) {
1361b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1362b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1363*2230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1364b8880f5fSAart Bik         break;
1365b8880f5fSAart Bik       case PrintConversion::SignExt64:
1366b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1367*2230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1368b8880f5fSAart Bik         break;
1369b8880f5fSAart Bik       case PrintConversion::None:
1370b8880f5fSAart Bik         break;
1371c9eeeb38Saartbik       }
1372d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1373d9b500d3SAart Bik       return;
1374d9b500d3SAart Bik     }
1375d9b500d3SAart Bik 
1376d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1377d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1378d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1379d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1380d9b500d3SAart Bik       auto reducedType =
1381d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1382dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1383d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1384dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1385dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1386b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1387b8880f5fSAart Bik                 conversion);
1388d9b500d3SAart Bik       if (d != dim - 1)
1389d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1390d9b500d3SAart Bik     }
1391d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1392d9b500d3SAart Bik   }
1393d9b500d3SAart Bik 
1394d9b500d3SAart Bik   // Helper to emit a call.
1395d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1396d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
139708e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1398d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1399d9b500d3SAart Bik   }
1400d9b500d3SAart Bik 
1401d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14025446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
1403c69c9e0fSAlex Zinenko                              ArrayRef<Type> params) {
1404d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1405d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1406d9b500d3SAart Bik     if (func)
1407d9b500d3SAart Bik       return func;
1408d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1409d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1410d9b500d3SAart Bik         op->getLoc(), name,
14117ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
14127ed9cfc7SAlex Zinenko                                     params));
1413d9b500d3SAart Bik   }
1414d9b500d3SAart Bik 
1415d9b500d3SAart Bik   // Helpers for method names.
1416e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
1417*2230bf99SAlex Zinenko     return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
1418e52414b1Saartbik   }
1419b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1420*2230bf99SAlex Zinenko     return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
1421b8880f5fSAart Bik   }
1422d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
14237ed9cfc7SAlex Zinenko     return getPrint(op, "printF32", LLVM::LLVMFloatType::get(op->getContext()));
1424d9b500d3SAart Bik   }
1425d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
142654759cefSAart Bik     return getPrint(op, "printF64",
14277ed9cfc7SAlex Zinenko                     LLVM::LLVMDoubleType::get(op->getContext()));
1428d9b500d3SAart Bik   }
1429d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
143054759cefSAart Bik     return getPrint(op, "printOpen", {});
1431d9b500d3SAart Bik   }
1432d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
143354759cefSAart Bik     return getPrint(op, "printClose", {});
1434d9b500d3SAart Bik   }
1435d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
143654759cefSAart Bik     return getPrint(op, "printComma", {});
1437d9b500d3SAart Bik   }
1438d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
143954759cefSAart Bik     return getPrint(op, "printNewline", {});
1440d9b500d3SAart Bik   }
1441d9b500d3SAart Bik };
1442d9b500d3SAart Bik 
1443334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1444c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1445c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1446c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1447334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
144865678d93SNicolas Vasilache public:
1449b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1450b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1451b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1452b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1453b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1454b99bd771SRiver Riddle   }
145565678d93SNicolas Vasilache 
1456334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
145765678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
14589eb3e564SChris Lattner     auto dstType = op.getType();
145965678d93SNicolas Vasilache 
146065678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
146165678d93SNicolas Vasilache 
146265678d93SNicolas Vasilache     int64_t offset =
146365678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
146465678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
146565678d93SNicolas Vasilache     int64_t stride =
146665678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
146765678d93SNicolas Vasilache 
146865678d93SNicolas Vasilache     auto loc = op.getLoc();
146965678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
147035b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1471c3c95b9cSaartbik 
1472c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1473c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1474c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1475c3c95b9cSaartbik       offsets.reserve(size);
1476c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1477c3c95b9cSaartbik            off += stride)
1478c3c95b9cSaartbik         offsets.push_back(off);
1479c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1480c3c95b9cSaartbik                                              op.vector(),
1481c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1482c3c95b9cSaartbik       return success();
1483c3c95b9cSaartbik     }
1484c3c95b9cSaartbik 
1485c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
148665678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
148765678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
148865678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
148965678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
149065678d93SNicolas Vasilache          off += stride, ++idx) {
1491c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1492c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1493c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
149465678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
149565678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
149665678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
149765678d93SNicolas Vasilache     }
1498c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14993145427dSRiver Riddle     return success();
150065678d93SNicolas Vasilache   }
150165678d93SNicolas Vasilache };
150265678d93SNicolas Vasilache 
1503df186507SBenjamin Kramer } // namespace
1504df186507SBenjamin Kramer 
15055c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15065c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1507ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1508060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
150965678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15108345b86dSNicolas Vasilache   // clang-format off
1511681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1512681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15132d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1514c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1515ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1516563879b6SRahul Joshi       converter, reassociateFPReductions);
1517060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1518060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1519060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1520563879b6SRahul Joshi       converter, enableIndexOptimizations);
15218345b86dSNicolas Vasilache   patterns
1522ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15238345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15248345b86dSNicolas Vasilache               VectorExtractOpConversion,
15258345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15268345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15278345b86dSNicolas Vasilache               VectorInsertOpConversion,
15288345b86dSNicolas Vasilache               VectorPrintOpConversion,
152919dbb230Saartbik               VectorTypeCastOpConversion,
153039379916Saartbik               VectorMaskedLoadOpConversion,
153139379916Saartbik               VectorMaskedStoreOpConversion,
153219dbb230Saartbik               VectorGatherOpConversion,
1533e8dcf5f8Saartbik               VectorScatterOpConversion,
1534e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1535563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15368345b86dSNicolas Vasilache   // clang-format on
15375c0c51a9SNicolas Vasilache }
15385c0c51a9SNicolas Vasilache 
153963b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
154063b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1541563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1542563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
154363b683a8SNicolas Vasilache }
1544