15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
111834ad4aSRiver Riddle #include "../PassDetail.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
135c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
178345b86dSNicolas Vasilache #include "mlir/IR/AffineMap.h"
185c0c51a9SNicolas Vasilache #include "mlir/IR/Attributes.h"
195c0c51a9SNicolas Vasilache #include "mlir/IR/Builders.h"
205c0c51a9SNicolas Vasilache #include "mlir/IR/MLIRContext.h"
215c0c51a9SNicolas Vasilache #include "mlir/IR/Module.h"
225c0c51a9SNicolas Vasilache #include "mlir/IR/Operation.h"
235c0c51a9SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
245c0c51a9SNicolas Vasilache #include "mlir/IR/StandardTypes.h"
255c0c51a9SNicolas Vasilache #include "mlir/IR/Types.h"
26ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
275c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
285c0c51a9SNicolas Vasilache #include "mlir/Transforms/Passes.h"
295c0c51a9SNicolas Vasilache #include "llvm/IR/DerivedTypes.h"
305c0c51a9SNicolas Vasilache #include "llvm/IR/Module.h"
315c0c51a9SNicolas Vasilache #include "llvm/IR/Type.h"
325c0c51a9SNicolas Vasilache #include "llvm/Support/Allocator.h"
335c0c51a9SNicolas Vasilache #include "llvm/Support/ErrorHandling.h"
345c0c51a9SNicolas Vasilache 
355c0c51a9SNicolas Vasilache using namespace mlir;
3665678d93SNicolas Vasilache using namespace mlir::vector;
375c0c51a9SNicolas Vasilache 
389826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
399826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
409826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
419826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
429826fe5cSAart Bik }
439826fe5cSAart Bik 
449826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
459826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
469826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
479826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
489826fe5cSAart Bik }
499826fe5cSAart Bik 
501c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
51e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
520f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
530f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
540f04384dSAlex Zinenko                        int64_t pos) {
551c81adf3SAart Bik   if (rank == 1) {
561c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
571c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
580f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
591c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
601c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
611c81adf3SAart Bik                                                   constant);
621c81adf3SAart Bik   }
631c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
641c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
651c81adf3SAart Bik }
661c81adf3SAart Bik 
672d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
682d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
692d515e49SNicolas Vasilache                        Value into, int64_t offset) {
702d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
712d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
722d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
732d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
742d515e49SNicolas Vasilache       loc, vectorType, from, into,
752d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
762d515e49SNicolas Vasilache }
772d515e49SNicolas Vasilache 
781c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
79e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
800f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
810f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
821c81adf3SAart Bik   if (rank == 1) {
831c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
841c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
850f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
861c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
871c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
881c81adf3SAart Bik                                                    constant);
891c81adf3SAart Bik   }
901c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
911c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
921c81adf3SAart Bik }
931c81adf3SAart Bik 
942d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
952d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
962d515e49SNicolas Vasilache                         int64_t offset) {
972d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
982d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
992d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
1002d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
1012d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
1022d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
1052d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
1069db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
1072d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
1082d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
1092d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
1102d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
1112d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
1122d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
1132d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
1142d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1152d515e49SNicolas Vasilache        it != eit; ++it)
1162d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1172d515e49SNicolas Vasilache   return res;
1182d515e49SNicolas Vasilache }
1192d515e49SNicolas Vasilache 
120060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
121060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
122060c9dd1Saartbik //
123060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
124060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
125060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
126060c9dd1Saartbik //       is very conservative on the boundary conditions.
127060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
128060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
129060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
130060c9dd1Saartbik   auto loc = op->getLoc();
131060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
132060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
133060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
134060c9dd1Saartbik   Value indices;
135060c9dd1Saartbik   Type idxType;
136060c9dd1Saartbik   if (enableIndexOptimizations) {
1370c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1380c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1390c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
140060c9dd1Saartbik     idxType = rewriter.getI32Type();
141060c9dd1Saartbik   } else {
1420c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1430c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1440c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
145060c9dd1Saartbik     idxType = rewriter.getI64Type();
146060c9dd1Saartbik   }
147060c9dd1Saartbik   // Add in an offset if requested.
148060c9dd1Saartbik   if (off) {
149060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
150060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
151060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
152060c9dd1Saartbik   }
153060c9dd1Saartbik   // Construct the vector comparison.
154060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
155060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
156060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
157060c9dd1Saartbik }
158060c9dd1Saartbik 
15919dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
16019dbb230Saartbik template <typename T>
16119dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
16219dbb230Saartbik                                  unsigned &align) {
1635f9e0466SNicolas Vasilache   Type elementTy =
16419dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1655f9e0466SNicolas Vasilache   if (!elementTy)
1665f9e0466SNicolas Vasilache     return failure();
1675f9e0466SNicolas Vasilache 
168b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
169b2ab375dSAlex Zinenko   // stop depending on translation.
17087a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
17187a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
172b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
173168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1745f9e0466SNicolas Vasilache   return success();
1755f9e0466SNicolas Vasilache }
1765f9e0466SNicolas Vasilache 
177e8dcf5f8Saartbik // Helper that returns the base address of a memref.
178b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
179e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
18019dbb230Saartbik   // Inspect stride and offset structure.
18119dbb230Saartbik   //
18219dbb230Saartbik   // TODO: flat memory only for now, generalize
18319dbb230Saartbik   //
18419dbb230Saartbik   int64_t offset;
18519dbb230Saartbik   SmallVector<int64_t, 4> strides;
18619dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
18719dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
18819dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
18919dbb230Saartbik     return failure();
190e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
191e8dcf5f8Saartbik   return success();
192e8dcf5f8Saartbik }
19319dbb230Saartbik 
194e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
195b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
196b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
197b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
198e8dcf5f8Saartbik   Value base;
199e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
200e8dcf5f8Saartbik     return failure();
2013a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
202e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
203e8dcf5f8Saartbik   return success();
204e8dcf5f8Saartbik }
205e8dcf5f8Saartbik 
20639379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
207b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
208b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
209b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
21039379916Saartbik   Value base;
21139379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
21239379916Saartbik     return failure();
21339379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
21439379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
21539379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
21639379916Saartbik   return success();
21739379916Saartbik }
21839379916Saartbik 
219e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
220e8dcf5f8Saartbik // vector.
221b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
222b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
223b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
224b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
225e8dcf5f8Saartbik   Value base;
226e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
227e8dcf5f8Saartbik     return failure();
2283a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
229e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2301485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
23119dbb230Saartbik   return success();
23219dbb230Saartbik }
23319dbb230Saartbik 
2345f9e0466SNicolas Vasilache static LogicalResult
2355f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2365f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2375f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2385f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
239affbc0cdSNicolas Vasilache   unsigned align;
24019dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
241affbc0cdSNicolas Vasilache     return failure();
242affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2435f9e0466SNicolas Vasilache   return success();
2445f9e0466SNicolas Vasilache }
2455f9e0466SNicolas Vasilache 
2465f9e0466SNicolas Vasilache static LogicalResult
2475f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2485f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2495f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2505f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2515f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2525f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2535f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2545f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2555f9e0466SNicolas Vasilache 
2565f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2575f9e0466SNicolas Vasilache   if (!vecTy)
2585f9e0466SNicolas Vasilache     return failure();
2595f9e0466SNicolas Vasilache 
2605f9e0466SNicolas Vasilache   unsigned align;
26119dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2625f9e0466SNicolas Vasilache     return failure();
2635f9e0466SNicolas Vasilache 
2645f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2655f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2665f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2675f9e0466SNicolas Vasilache   return success();
2685f9e0466SNicolas Vasilache }
2695f9e0466SNicolas Vasilache 
2705f9e0466SNicolas Vasilache static LogicalResult
2715f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2725f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2735f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2745f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
275affbc0cdSNicolas Vasilache   unsigned align;
27619dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
277affbc0cdSNicolas Vasilache     return failure();
2782d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
279affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
280affbc0cdSNicolas Vasilache                                              align);
2815f9e0466SNicolas Vasilache   return success();
2825f9e0466SNicolas Vasilache }
2835f9e0466SNicolas Vasilache 
2845f9e0466SNicolas Vasilache static LogicalResult
2855f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2865f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2875f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2885f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2895f9e0466SNicolas Vasilache   unsigned align;
29019dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2915f9e0466SNicolas Vasilache     return failure();
2925f9e0466SNicolas Vasilache 
2932d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2945f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2955f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2965f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2975f9e0466SNicolas Vasilache   return success();
2985f9e0466SNicolas Vasilache }
2995f9e0466SNicolas Vasilache 
3002d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
3012d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
3022d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
3035f9e0466SNicolas Vasilache }
3045f9e0466SNicolas Vasilache 
3052d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
3062d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
3072d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
3085f9e0466SNicolas Vasilache }
3095f9e0466SNicolas Vasilache 
31090c01357SBenjamin Kramer namespace {
311e83b7b99Saartbik 
31263b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
31363b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
31463b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern {
31563b683a8SNicolas Vasilache public:
31663b683a8SNicolas Vasilache   explicit VectorMatmulOpConversion(MLIRContext *context,
31763b683a8SNicolas Vasilache                                     LLVMTypeConverter &typeConverter)
31863b683a8SNicolas Vasilache       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
31963b683a8SNicolas Vasilache                              typeConverter) {}
32063b683a8SNicolas Vasilache 
3213145427dSRiver Riddle   LogicalResult
32263b683a8SNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
32363b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
32463b683a8SNicolas Vasilache     auto matmulOp = cast<vector::MatmulOp>(op);
3252d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
32663b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
32763b683a8SNicolas Vasilache         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
32863b683a8SNicolas Vasilache         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
32963b683a8SNicolas Vasilache         matmulOp.rhs_columns());
3303145427dSRiver Riddle     return success();
33163b683a8SNicolas Vasilache   }
33263b683a8SNicolas Vasilache };
33363b683a8SNicolas Vasilache 
334c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
335c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
336c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
337c295a65dSaartbik public:
338c295a65dSaartbik   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
339c295a65dSaartbik                                            LLVMTypeConverter &typeConverter)
340c295a65dSaartbik       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
341c295a65dSaartbik                              context, typeConverter) {}
342c295a65dSaartbik 
343c295a65dSaartbik   LogicalResult
344c295a65dSaartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
345c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
346c295a65dSaartbik     auto transOp = cast<vector::FlatTransposeOp>(op);
3472d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
348c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
349c295a65dSaartbik         transOp, typeConverter.convertType(transOp.res().getType()),
350c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
351c295a65dSaartbik     return success();
352c295a65dSaartbik   }
353c295a65dSaartbik };
354c295a65dSaartbik 
35539379916Saartbik /// Conversion pattern for a vector.maskedload.
35639379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
35739379916Saartbik public:
35839379916Saartbik   explicit VectorMaskedLoadOpConversion(MLIRContext *context,
35939379916Saartbik                                         LLVMTypeConverter &typeConverter)
36039379916Saartbik       : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
36139379916Saartbik                              typeConverter) {}
36239379916Saartbik 
36339379916Saartbik   LogicalResult
36439379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
36539379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
36639379916Saartbik     auto loc = op->getLoc();
36739379916Saartbik     auto load = cast<vector::MaskedLoadOp>(op);
36839379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
36939379916Saartbik 
37039379916Saartbik     // Resolve alignment.
37139379916Saartbik     unsigned align;
37239379916Saartbik     if (failed(getMemRefAlignment(typeConverter, load, align)))
37339379916Saartbik       return failure();
37439379916Saartbik 
37539379916Saartbik     auto vtype = typeConverter.convertType(load.getResultVectorType());
37639379916Saartbik     Value ptr;
37739379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
37839379916Saartbik                           vtype, ptr)))
37939379916Saartbik       return failure();
38039379916Saartbik 
38139379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
38239379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
38339379916Saartbik         rewriter.getI32IntegerAttr(align));
38439379916Saartbik     return success();
38539379916Saartbik   }
38639379916Saartbik };
38739379916Saartbik 
38839379916Saartbik /// Conversion pattern for a vector.maskedstore.
38939379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
39039379916Saartbik public:
39139379916Saartbik   explicit VectorMaskedStoreOpConversion(MLIRContext *context,
39239379916Saartbik                                          LLVMTypeConverter &typeConverter)
39339379916Saartbik       : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
39439379916Saartbik                              typeConverter) {}
39539379916Saartbik 
39639379916Saartbik   LogicalResult
39739379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
39839379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
39939379916Saartbik     auto loc = op->getLoc();
40039379916Saartbik     auto store = cast<vector::MaskedStoreOp>(op);
40139379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
40239379916Saartbik 
40339379916Saartbik     // Resolve alignment.
40439379916Saartbik     unsigned align;
40539379916Saartbik     if (failed(getMemRefAlignment(typeConverter, store, align)))
40639379916Saartbik       return failure();
40739379916Saartbik 
40839379916Saartbik     auto vtype = typeConverter.convertType(store.getValueVectorType());
40939379916Saartbik     Value ptr;
41039379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
41139379916Saartbik                           vtype, ptr)))
41239379916Saartbik       return failure();
41339379916Saartbik 
41439379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
41539379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
41639379916Saartbik         rewriter.getI32IntegerAttr(align));
41739379916Saartbik     return success();
41839379916Saartbik   }
41939379916Saartbik };
42039379916Saartbik 
42119dbb230Saartbik /// Conversion pattern for a vector.gather.
42219dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern {
42319dbb230Saartbik public:
42419dbb230Saartbik   explicit VectorGatherOpConversion(MLIRContext *context,
42519dbb230Saartbik                                     LLVMTypeConverter &typeConverter)
42619dbb230Saartbik       : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
42719dbb230Saartbik                              typeConverter) {}
42819dbb230Saartbik 
42919dbb230Saartbik   LogicalResult
43019dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
43119dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
43219dbb230Saartbik     auto loc = op->getLoc();
43319dbb230Saartbik     auto gather = cast<vector::GatherOp>(op);
43419dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
43519dbb230Saartbik 
43619dbb230Saartbik     // Resolve alignment.
43719dbb230Saartbik     unsigned align;
43819dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, gather, align)))
43919dbb230Saartbik       return failure();
44019dbb230Saartbik 
44119dbb230Saartbik     // Get index ptrs.
44219dbb230Saartbik     VectorType vType = gather.getResultVectorType();
44319dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
44419dbb230Saartbik     Value ptrs;
445e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
446e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
44719dbb230Saartbik       return failure();
44819dbb230Saartbik 
44919dbb230Saartbik     // Replace with the gather intrinsic.
45019dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
4510c2a4d3cSBenjamin Kramer         gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
4520c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
45319dbb230Saartbik     return success();
45419dbb230Saartbik   }
45519dbb230Saartbik };
45619dbb230Saartbik 
45719dbb230Saartbik /// Conversion pattern for a vector.scatter.
45819dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern {
45919dbb230Saartbik public:
46019dbb230Saartbik   explicit VectorScatterOpConversion(MLIRContext *context,
46119dbb230Saartbik                                      LLVMTypeConverter &typeConverter)
46219dbb230Saartbik       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
46319dbb230Saartbik                              typeConverter) {}
46419dbb230Saartbik 
46519dbb230Saartbik   LogicalResult
46619dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
46719dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
46819dbb230Saartbik     auto loc = op->getLoc();
46919dbb230Saartbik     auto scatter = cast<vector::ScatterOp>(op);
47019dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
47119dbb230Saartbik 
47219dbb230Saartbik     // Resolve alignment.
47319dbb230Saartbik     unsigned align;
47419dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, scatter, align)))
47519dbb230Saartbik       return failure();
47619dbb230Saartbik 
47719dbb230Saartbik     // Get index ptrs.
47819dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
47919dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
48019dbb230Saartbik     Value ptrs;
481e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
482e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
48319dbb230Saartbik       return failure();
48419dbb230Saartbik 
48519dbb230Saartbik     // Replace with the scatter intrinsic.
48619dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
48719dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
48819dbb230Saartbik         rewriter.getI32IntegerAttr(align));
48919dbb230Saartbik     return success();
49019dbb230Saartbik   }
49119dbb230Saartbik };
49219dbb230Saartbik 
493e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
494e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
495e8dcf5f8Saartbik public:
496e8dcf5f8Saartbik   explicit VectorExpandLoadOpConversion(MLIRContext *context,
497e8dcf5f8Saartbik                                         LLVMTypeConverter &typeConverter)
498e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
499e8dcf5f8Saartbik                              typeConverter) {}
500e8dcf5f8Saartbik 
501e8dcf5f8Saartbik   LogicalResult
502e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
503e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
504e8dcf5f8Saartbik     auto loc = op->getLoc();
505e8dcf5f8Saartbik     auto expand = cast<vector::ExpandLoadOp>(op);
506e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
507e8dcf5f8Saartbik 
508e8dcf5f8Saartbik     Value ptr;
509e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
510e8dcf5f8Saartbik                           ptr)))
511e8dcf5f8Saartbik       return failure();
512e8dcf5f8Saartbik 
513e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
514e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
515e8dcf5f8Saartbik         op, typeConverter.convertType(vType), ptr, adaptor.mask(),
516e8dcf5f8Saartbik         adaptor.pass_thru());
517e8dcf5f8Saartbik     return success();
518e8dcf5f8Saartbik   }
519e8dcf5f8Saartbik };
520e8dcf5f8Saartbik 
521e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
522e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
523e8dcf5f8Saartbik public:
524e8dcf5f8Saartbik   explicit VectorCompressStoreOpConversion(MLIRContext *context,
525e8dcf5f8Saartbik                                            LLVMTypeConverter &typeConverter)
526e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
527e8dcf5f8Saartbik                              context, typeConverter) {}
528e8dcf5f8Saartbik 
529e8dcf5f8Saartbik   LogicalResult
530e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
531e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
532e8dcf5f8Saartbik     auto loc = op->getLoc();
533e8dcf5f8Saartbik     auto compress = cast<vector::CompressStoreOp>(op);
534e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
535e8dcf5f8Saartbik 
536e8dcf5f8Saartbik     Value ptr;
537e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
538e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
539e8dcf5f8Saartbik       return failure();
540e8dcf5f8Saartbik 
541e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
542e8dcf5f8Saartbik         op, adaptor.value(), ptr, adaptor.mask());
543e8dcf5f8Saartbik     return success();
544e8dcf5f8Saartbik   }
545e8dcf5f8Saartbik };
546e8dcf5f8Saartbik 
54719dbb230Saartbik /// Conversion pattern for all vector reductions.
548870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern {
549e83b7b99Saartbik public:
550e83b7b99Saartbik   explicit VectorReductionOpConversion(MLIRContext *context,
551ceb1b327Saartbik                                        LLVMTypeConverter &typeConverter,
552060c9dd1Saartbik                                        bool reassociateFPRed)
553870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
554ceb1b327Saartbik                              typeConverter),
555060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
556e83b7b99Saartbik 
5573145427dSRiver Riddle   LogicalResult
558e83b7b99Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
559e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
560e83b7b99Saartbik     auto reductionOp = cast<vector::ReductionOp>(op);
561e83b7b99Saartbik     auto kind = reductionOp.kind();
562e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
5630f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(eltType);
564e9628955SAart Bik     if (eltType.isIntOrIndex()) {
565e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
566e83b7b99Saartbik       if (kind == "add")
567322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
568e83b7b99Saartbik             op, llvmType, operands[0]);
569e83b7b99Saartbik       else if (kind == "mul")
570322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
571e83b7b99Saartbik             op, llvmType, operands[0]);
572e9628955SAart Bik       else if (kind == "min" &&
573e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
574322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
575e9628955SAart Bik             op, llvmType, operands[0]);
576e83b7b99Saartbik       else if (kind == "min")
577322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
578e83b7b99Saartbik             op, llvmType, operands[0]);
579e9628955SAart Bik       else if (kind == "max" &&
580e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
581322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
582e9628955SAart Bik             op, llvmType, operands[0]);
583e83b7b99Saartbik       else if (kind == "max")
584322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
585e83b7b99Saartbik             op, llvmType, operands[0]);
586e83b7b99Saartbik       else if (kind == "and")
587322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
588e83b7b99Saartbik             op, llvmType, operands[0]);
589e83b7b99Saartbik       else if (kind == "or")
590322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
591e83b7b99Saartbik             op, llvmType, operands[0]);
592e83b7b99Saartbik       else if (kind == "xor")
593322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
594e83b7b99Saartbik             op, llvmType, operands[0]);
595e83b7b99Saartbik       else
5963145427dSRiver Riddle         return failure();
5973145427dSRiver Riddle       return success();
598e83b7b99Saartbik 
5992d76274bSBenjamin Kramer     } else if (eltType.isa<FloatType>()) {
600e83b7b99Saartbik       // Floating-point reductions: add/mul/min/max
601e83b7b99Saartbik       if (kind == "add") {
6020d924700Saartbik         // Optional accumulator (or zero).
6030d924700Saartbik         Value acc = operands.size() > 1 ? operands[1]
6040d924700Saartbik                                         : rewriter.create<LLVM::ConstantOp>(
6050d924700Saartbik                                               op->getLoc(), llvmType,
6060d924700Saartbik                                               rewriter.getZeroAttr(eltType));
607322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
608ceb1b327Saartbik             op, llvmType, acc, operands[0],
609ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
610e83b7b99Saartbik       } else if (kind == "mul") {
6110d924700Saartbik         // Optional accumulator (or one).
6120d924700Saartbik         Value acc = operands.size() > 1
6130d924700Saartbik                         ? operands[1]
6140d924700Saartbik                         : rewriter.create<LLVM::ConstantOp>(
6150d924700Saartbik                               op->getLoc(), llvmType,
6160d924700Saartbik                               rewriter.getFloatAttr(eltType, 1.0));
617322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
618ceb1b327Saartbik             op, llvmType, acc, operands[0],
619ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
620e83b7b99Saartbik       } else if (kind == "min")
621322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
622e83b7b99Saartbik             op, llvmType, operands[0]);
623e83b7b99Saartbik       else if (kind == "max")
624322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
625e83b7b99Saartbik             op, llvmType, operands[0]);
626e83b7b99Saartbik       else
6273145427dSRiver Riddle         return failure();
6283145427dSRiver Riddle       return success();
629e83b7b99Saartbik     }
6303145427dSRiver Riddle     return failure();
631e83b7b99Saartbik   }
632ceb1b327Saartbik 
633ceb1b327Saartbik private:
634ceb1b327Saartbik   const bool reassociateFPReductions;
635e83b7b99Saartbik };
636e83b7b99Saartbik 
637060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
638060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
639060c9dd1Saartbik public:
640060c9dd1Saartbik   explicit VectorCreateMaskOpConversion(MLIRContext *context,
641060c9dd1Saartbik                                         LLVMTypeConverter &typeConverter,
642060c9dd1Saartbik                                         bool enableIndexOpt)
643060c9dd1Saartbik       : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
644060c9dd1Saartbik                              typeConverter),
645060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
646060c9dd1Saartbik 
647060c9dd1Saartbik   LogicalResult
648060c9dd1Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
649060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
650060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
651060c9dd1Saartbik     int64_t rank = dstType.getRank();
652060c9dd1Saartbik     if (rank == 1) {
653060c9dd1Saartbik       rewriter.replaceOp(
654060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
655060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
656060c9dd1Saartbik       return success();
657060c9dd1Saartbik     }
658060c9dd1Saartbik     return failure();
659060c9dd1Saartbik   }
660060c9dd1Saartbik 
661060c9dd1Saartbik private:
662060c9dd1Saartbik   const bool enableIndexOptimizations;
663060c9dd1Saartbik };
664060c9dd1Saartbik 
665870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern {
6661c81adf3SAart Bik public:
6671c81adf3SAart Bik   explicit VectorShuffleOpConversion(MLIRContext *context,
6681c81adf3SAart Bik                                      LLVMTypeConverter &typeConverter)
669870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
6701c81adf3SAart Bik                              typeConverter) {}
6711c81adf3SAart Bik 
6723145427dSRiver Riddle   LogicalResult
673e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6741c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6751c81adf3SAart Bik     auto loc = op->getLoc();
6762d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6771c81adf3SAart Bik     auto shuffleOp = cast<vector::ShuffleOp>(op);
6781c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6791c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6801c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
6810f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(vectorType);
6821c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6831c81adf3SAart Bik 
6841c81adf3SAart Bik     // Bail if result type cannot be lowered.
6851c81adf3SAart Bik     if (!llvmType)
6863145427dSRiver Riddle       return failure();
6871c81adf3SAart Bik 
6881c81adf3SAart Bik     // Get rank and dimension sizes.
6891c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6901c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6911c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6921c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6931c81adf3SAart Bik 
6941c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6951c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6961c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
697e62a6956SRiver Riddle       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
6981c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
6991c81adf3SAart Bik       rewriter.replaceOp(op, shuffle);
7003145427dSRiver Riddle       return success();
701b36aaeafSAart Bik     }
702b36aaeafSAart Bik 
7031c81adf3SAart Bik     // For all other cases, insert the individual values individually.
704e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
7051c81adf3SAart Bik     int64_t insPos = 0;
7061c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
7071c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
708e62a6956SRiver Riddle       Value value = adaptor.v1();
7091c81adf3SAart Bik       if (extPos >= v1Dim) {
7101c81adf3SAart Bik         extPos -= v1Dim;
7111c81adf3SAart Bik         value = adaptor.v2();
712b36aaeafSAart Bik       }
7130f04384dSAlex Zinenko       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
7140f04384dSAlex Zinenko                                  rank, extPos);
7150f04384dSAlex Zinenko       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
7160f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7171c81adf3SAart Bik     }
7181c81adf3SAart Bik     rewriter.replaceOp(op, insert);
7193145427dSRiver Riddle     return success();
720b36aaeafSAart Bik   }
721b36aaeafSAart Bik };
722b36aaeafSAart Bik 
723870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
724cd5dab8aSAart Bik public:
725cd5dab8aSAart Bik   explicit VectorExtractElementOpConversion(MLIRContext *context,
726cd5dab8aSAart Bik                                             LLVMTypeConverter &typeConverter)
727870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
728870c1fd4SAlex Zinenko                              context, typeConverter) {}
729cd5dab8aSAart Bik 
7303145427dSRiver Riddle   LogicalResult
731e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
732cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7332d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
734cd5dab8aSAart Bik     auto extractEltOp = cast<vector::ExtractElementOp>(op);
735cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
7360f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType.getElementType());
737cd5dab8aSAart Bik 
738cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
739cd5dab8aSAart Bik     if (!llvmType)
7403145427dSRiver Riddle       return failure();
741cd5dab8aSAart Bik 
742cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
743cd5dab8aSAart Bik         op, llvmType, adaptor.vector(), adaptor.position());
7443145427dSRiver Riddle     return success();
745cd5dab8aSAart Bik   }
746cd5dab8aSAart Bik };
747cd5dab8aSAart Bik 
748870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern {
7495c0c51a9SNicolas Vasilache public:
7509826fe5cSAart Bik   explicit VectorExtractOpConversion(MLIRContext *context,
7515c0c51a9SNicolas Vasilache                                      LLVMTypeConverter &typeConverter)
752870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
7535c0c51a9SNicolas Vasilache                              typeConverter) {}
7545c0c51a9SNicolas Vasilache 
7553145427dSRiver Riddle   LogicalResult
756e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7575c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7585c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
7592d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
760d37f2725SAart Bik     auto extractOp = cast<vector::ExtractOp>(op);
7619826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7622bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
7630f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(resultType);
7645c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7659826fe5cSAart Bik 
7669826fe5cSAart Bik     // Bail if result type cannot be lowered.
7679826fe5cSAart Bik     if (!llvmResultType)
7683145427dSRiver Riddle       return failure();
7699826fe5cSAart Bik 
7705c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7715c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
772e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7735c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
7745c0c51a9SNicolas Vasilache       rewriter.replaceOp(op, extracted);
7753145427dSRiver Riddle       return success();
7765c0c51a9SNicolas Vasilache     }
7775c0c51a9SNicolas Vasilache 
7789826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
7795c0c51a9SNicolas Vasilache     auto *context = op->getContext();
780e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7815c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7825c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7839826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7845c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7855c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7865c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
7870f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
7885c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7895c0c51a9SNicolas Vasilache     }
7905c0c51a9SNicolas Vasilache 
7915c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7925c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7935446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7941d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7955c0c51a9SNicolas Vasilache     extracted =
7965c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
7975c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, extracted);
7985c0c51a9SNicolas Vasilache 
7993145427dSRiver Riddle     return success();
8005c0c51a9SNicolas Vasilache   }
8015c0c51a9SNicolas Vasilache };
8025c0c51a9SNicolas Vasilache 
803681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
804681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
805681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
806681f929fSNicolas Vasilache ///
807681f929fSNicolas Vasilache /// Example:
808681f929fSNicolas Vasilache /// ```
809681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
810681f929fSNicolas Vasilache /// ```
811681f929fSNicolas Vasilache /// is converted to:
812681f929fSNicolas Vasilache /// ```
8133bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
814681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
815681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
816681f929fSNicolas Vasilache /// ```
817870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
818681f929fSNicolas Vasilache public:
819681f929fSNicolas Vasilache   explicit VectorFMAOp1DConversion(MLIRContext *context,
820681f929fSNicolas Vasilache                                    LLVMTypeConverter &typeConverter)
821870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
822681f929fSNicolas Vasilache                              typeConverter) {}
823681f929fSNicolas Vasilache 
8243145427dSRiver Riddle   LogicalResult
825681f929fSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
826681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8272d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
828681f929fSNicolas Vasilache     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
829681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
830681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8313145427dSRiver Riddle       return failure();
8323bffe602SBenjamin Kramer     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
8333bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8343145427dSRiver Riddle     return success();
835681f929fSNicolas Vasilache   }
836681f929fSNicolas Vasilache };
837681f929fSNicolas Vasilache 
838870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
839cd5dab8aSAart Bik public:
840cd5dab8aSAart Bik   explicit VectorInsertElementOpConversion(MLIRContext *context,
841cd5dab8aSAart Bik                                            LLVMTypeConverter &typeConverter)
842870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
843870c1fd4SAlex Zinenko                              context, typeConverter) {}
844cd5dab8aSAart Bik 
8453145427dSRiver Riddle   LogicalResult
846e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
847cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8482d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
849cd5dab8aSAart Bik     auto insertEltOp = cast<vector::InsertElementOp>(op);
850cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
8510f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType);
852cd5dab8aSAart Bik 
853cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
854cd5dab8aSAart Bik     if (!llvmType)
8553145427dSRiver Riddle       return failure();
856cd5dab8aSAart Bik 
857cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
858cd5dab8aSAart Bik         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
8593145427dSRiver Riddle     return success();
860cd5dab8aSAart Bik   }
861cd5dab8aSAart Bik };
862cd5dab8aSAart Bik 
863870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern {
8649826fe5cSAart Bik public:
8659826fe5cSAart Bik   explicit VectorInsertOpConversion(MLIRContext *context,
8669826fe5cSAart Bik                                     LLVMTypeConverter &typeConverter)
867870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
8689826fe5cSAart Bik                              typeConverter) {}
8699826fe5cSAart Bik 
8703145427dSRiver Riddle   LogicalResult
871e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
8729826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8739826fe5cSAart Bik     auto loc = op->getLoc();
8742d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8759826fe5cSAart Bik     auto insertOp = cast<vector::InsertOp>(op);
8769826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8779826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
8780f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(destVectorType);
8799826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8809826fe5cSAart Bik 
8819826fe5cSAart Bik     // Bail if result type cannot be lowered.
8829826fe5cSAart Bik     if (!llvmResultType)
8833145427dSRiver Riddle       return failure();
8849826fe5cSAart Bik 
8859826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8869826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
887e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8889826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8899826fe5cSAart Bik           positionArrayAttr);
8909826fe5cSAart Bik       rewriter.replaceOp(op, inserted);
8913145427dSRiver Riddle       return success();
8929826fe5cSAart Bik     }
8939826fe5cSAart Bik 
8949826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
8959826fe5cSAart Bik     auto *context = op->getContext();
896e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8979826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8989826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8999826fe5cSAart Bik     auto oneDVectorType = destVectorType;
9009826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9019826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
9029826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9039826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9049826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
9050f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
9069826fe5cSAart Bik           nMinusOnePositionAttrs);
9079826fe5cSAart Bik     }
9089826fe5cSAart Bik 
9099826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
9105446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
9111d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
912e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
9130f04384dSAlex Zinenko         loc, typeConverter.convertType(oneDVectorType), extracted,
9140f04384dSAlex Zinenko         adaptor.source(), constant);
9159826fe5cSAart Bik 
9169826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
9179826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9189826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9199826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9209826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
9219826fe5cSAart Bik                                                       adaptor.dest(), inserted,
9229826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
9239826fe5cSAart Bik     }
9249826fe5cSAart Bik 
9259826fe5cSAart Bik     rewriter.replaceOp(op, inserted);
9263145427dSRiver Riddle     return success();
9279826fe5cSAart Bik   }
9289826fe5cSAart Bik };
9299826fe5cSAart Bik 
930681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
931681f929fSNicolas Vasilache ///
932681f929fSNicolas Vasilache /// Example:
933681f929fSNicolas Vasilache /// ```
934681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
935681f929fSNicolas Vasilache /// ```
936681f929fSNicolas Vasilache /// is rewritten into:
937681f929fSNicolas Vasilache /// ```
938681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
939681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
940681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
941681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
942681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
943681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
944681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
945681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
946681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
947681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
948681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
949681f929fSNicolas Vasilache ///  // %r3 holds the final value.
950681f929fSNicolas Vasilache /// ```
951681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
952681f929fSNicolas Vasilache public:
953681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
954681f929fSNicolas Vasilache 
9553145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
956681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
957681f929fSNicolas Vasilache     auto vType = op.getVectorType();
958681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9593145427dSRiver Riddle       return failure();
960681f929fSNicolas Vasilache 
961681f929fSNicolas Vasilache     auto loc = op.getLoc();
962681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
963681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
964681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
965681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
966681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
967681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
968681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
969681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
970681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
971681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
972681f929fSNicolas Vasilache     }
973681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9743145427dSRiver Riddle     return success();
975681f929fSNicolas Vasilache   }
976681f929fSNicolas Vasilache };
977681f929fSNicolas Vasilache 
9782d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9792d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9802d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9812d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9822d515e49SNicolas Vasilache // rank.
9832d515e49SNicolas Vasilache //
9842d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9852d515e49SNicolas Vasilache // have different ranks. In this case:
9862d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9872d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9882d515e49SNicolas Vasilache //   destination subvector
9892d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9902d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9912d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9922d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9932d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9942d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9952d515e49SNicolas Vasilache public:
9962d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9972d515e49SNicolas Vasilache 
9983145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9992d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10002d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10012d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10022d515e49SNicolas Vasilache 
10032d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10043145427dSRiver Riddle       return failure();
10052d515e49SNicolas Vasilache 
10062d515e49SNicolas Vasilache     auto loc = op.getLoc();
10072d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10082d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10092d515e49SNicolas Vasilache     if (rankDiff == 0)
10103145427dSRiver Riddle       return failure();
10112d515e49SNicolas Vasilache 
10122d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
10132d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
10142d515e49SNicolas Vasilache     // on it.
10152d515e49SNicolas Vasilache     Value extracted =
10162d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
10172d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
10182d515e49SNicolas Vasilache                                                   /*dropFront=*/rankRest));
10192d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
10202d515e49SNicolas Vasilache     // ranks.
10212d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
10222d515e49SNicolas Vasilache         loc, op.source(), extracted,
10232d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
1024c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
10252d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
10262d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
10272d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
10282d515e49SNicolas Vasilache                        /*dropFront=*/rankRest));
10293145427dSRiver Riddle     return success();
10302d515e49SNicolas Vasilache   }
10312d515e49SNicolas Vasilache };
10322d515e49SNicolas Vasilache 
10332d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10342d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10352d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10362d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10372d515e49SNicolas Vasilache //   destination subvector
10382d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10392d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10402d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10412d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10422d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10432d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10442d515e49SNicolas Vasilache public:
1045*b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
1046*b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
1047*b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1048*b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1049*b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1050*b99bd771SRiver Riddle   }
10512d515e49SNicolas Vasilache 
10523145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10532d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10542d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10552d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10562d515e49SNicolas Vasilache 
10572d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10583145427dSRiver Riddle       return failure();
10592d515e49SNicolas Vasilache 
10602d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10612d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10622d515e49SNicolas Vasilache     if (rankDiff != 0)
10633145427dSRiver Riddle       return failure();
10642d515e49SNicolas Vasilache 
10652d515e49SNicolas Vasilache     if (srcType == dstType) {
10662d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10673145427dSRiver Riddle       return success();
10682d515e49SNicolas Vasilache     }
10692d515e49SNicolas Vasilache 
10702d515e49SNicolas Vasilache     int64_t offset =
10712d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10722d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10732d515e49SNicolas Vasilache     int64_t stride =
10742d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10752d515e49SNicolas Vasilache 
10762d515e49SNicolas Vasilache     auto loc = op.getLoc();
10772d515e49SNicolas Vasilache     Value res = op.dest();
10782d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10792d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10802d515e49SNicolas Vasilache          off += stride, ++idx) {
10812d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10822d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10832d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10842d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10852d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10862d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10872d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10882d515e49SNicolas Vasilache         // smaller rank.
1089bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10902d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10912d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10922d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10932d515e49SNicolas Vasilache       }
10942d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10952d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10962d515e49SNicolas Vasilache     }
10972d515e49SNicolas Vasilache 
10982d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10993145427dSRiver Riddle     return success();
11002d515e49SNicolas Vasilache   }
11012d515e49SNicolas Vasilache };
11022d515e49SNicolas Vasilache 
110330e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
110430e6033bSNicolas Vasilache /// static layout.
110530e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
110630e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
11072bf491c7SBenjamin Kramer   int64_t offset;
110830e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
110930e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
111030e6033bSNicolas Vasilache     return None;
111130e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
111230e6033bSNicolas Vasilache     return None;
111330e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
111430e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
111530e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
111630e6033bSNicolas Vasilache     return strides;
111730e6033bSNicolas Vasilache 
111830e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
111930e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
112030e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
112130e6033bSNicolas Vasilache   // layout.
11222bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
11232bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
112430e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
112530e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
112630e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
112730e6033bSNicolas Vasilache       return None;
112830e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
112930e6033bSNicolas Vasilache       return None;
11302bf491c7SBenjamin Kramer   }
113130e6033bSNicolas Vasilache   return strides;
11322bf491c7SBenjamin Kramer }
11332bf491c7SBenjamin Kramer 
1134870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
11355c0c51a9SNicolas Vasilache public:
11365c0c51a9SNicolas Vasilache   explicit VectorTypeCastOpConversion(MLIRContext *context,
11375c0c51a9SNicolas Vasilache                                       LLVMTypeConverter &typeConverter)
1138870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
11395c0c51a9SNicolas Vasilache                              typeConverter) {}
11405c0c51a9SNicolas Vasilache 
11413145427dSRiver Riddle   LogicalResult
1142e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11435c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11445c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
11455c0c51a9SNicolas Vasilache     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
11465c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11472bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11485c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
11492bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
11505c0c51a9SNicolas Vasilache 
11515c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11525c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11535c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11543145427dSRiver Riddle       return failure();
11555c0c51a9SNicolas Vasilache 
11565c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11572bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
11585c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
11593145427dSRiver Riddle       return failure();
11605c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11615c0c51a9SNicolas Vasilache 
11620f04384dSAlex Zinenko     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
11635c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11645c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11653145427dSRiver Riddle       return failure();
11665c0c51a9SNicolas Vasilache 
116730e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
116830e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
116930e6033bSNicolas Vasilache     if (!sourceStrides)
117030e6033bSNicolas Vasilache       return failure();
117130e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
117230e6033bSNicolas Vasilache     if (!targetStrides)
117330e6033bSNicolas Vasilache       return failure();
117430e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
117530e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
117630e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
117730e6033bSNicolas Vasilache         }))
11783145427dSRiver Riddle       return failure();
11795c0c51a9SNicolas Vasilache 
11805446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11815c0c51a9SNicolas Vasilache 
11825c0c51a9SNicolas Vasilache     // Create descriptor.
11835c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11843a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11855c0c51a9SNicolas Vasilache     // Set allocated ptr.
1186e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11875c0c51a9SNicolas Vasilache     allocated =
11885c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11895c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11905c0c51a9SNicolas Vasilache     // Set aligned ptr.
1191e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11925c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11935c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11945c0c51a9SNicolas Vasilache     // Fill offset 0.
11955c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11965c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11975c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11985c0c51a9SNicolas Vasilache 
11995c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
12005c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
12015c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
12025c0c51a9SNicolas Vasilache       auto sizeAttr =
12035c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
12045c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
12055c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
120630e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
120730e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
12085c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
12095c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
12105c0c51a9SNicolas Vasilache     }
12115c0c51a9SNicolas Vasilache 
12125c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, {desc});
12133145427dSRiver Riddle     return success();
12145c0c51a9SNicolas Vasilache   }
12155c0c51a9SNicolas Vasilache };
12165c0c51a9SNicolas Vasilache 
12178345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
12188345b86dSNicolas Vasilache /// sequence of:
1219060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1220060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1221060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1222060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1223060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
12248345b86dSNicolas Vasilache template <typename ConcreteOp>
12258345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern {
12268345b86dSNicolas Vasilache public:
12278345b86dSNicolas Vasilache   explicit VectorTransferConversion(MLIRContext *context,
1228060c9dd1Saartbik                                     LLVMTypeConverter &typeConv,
1229060c9dd1Saartbik                                     bool enableIndexOpt)
1230060c9dd1Saartbik       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1231060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
12328345b86dSNicolas Vasilache 
12338345b86dSNicolas Vasilache   LogicalResult
12348345b86dSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
12358345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12368345b86dSNicolas Vasilache     auto xferOp = cast<ConcreteOp>(op);
12378345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1238b2c79c50SNicolas Vasilache 
1239b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1240b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12418345b86dSNicolas Vasilache       return failure();
12425f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12435f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12445f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
12455f9e0466SNicolas Vasilache                                        op->getContext()))
12468345b86dSNicolas Vasilache       return failure();
12472bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
124830e6033bSNicolas Vasilache     auto strides = computeContiguousStrides(xferOp.getMemRefType());
124930e6033bSNicolas Vasilache     if (!strides)
12502bf491c7SBenjamin Kramer       return failure();
12518345b86dSNicolas Vasilache 
12528345b86dSNicolas Vasilache     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
12538345b86dSNicolas Vasilache 
12548345b86dSNicolas Vasilache     Location loc = op->getLoc();
12558345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
12568345b86dSNicolas Vasilache 
125768330ee0SThomas Raoux     if (auto memrefVectorElementType =
125868330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
125968330ee0SThomas Raoux       // Memref has vector element type.
126068330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
126168330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
126268330ee0SThomas Raoux         return failure();
12630de60b55SThomas Raoux #ifndef NDEBUG
126468330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
126568330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
126668330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
126768330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
126868330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
126968330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
127068330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
127168330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
127268330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
127368330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
127468330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
127568330ee0SThomas Raoux                "result shape.");
12760de60b55SThomas Raoux #endif // ifndef NDEBUG
127768330ee0SThomas Raoux     }
127868330ee0SThomas Raoux 
12798345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1280be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1281be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1282be16075bSWen-Heng (Jack) Chung     //    address space 0.
12838345b86dSNicolas Vasilache     // TODO: support alignment when possible.
12848345b86dSNicolas Vasilache     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
1285d3a98076SAlex Zinenko                                adaptor.indices(), rewriter);
12868345b86dSNicolas Vasilache     auto vecTy =
12878345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1288be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1289be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1290be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12918345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1292be16075bSWen-Heng (Jack) Chung     else
1293be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1294be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12958345b86dSNicolas Vasilache 
12961870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
12971870e787SNicolas Vasilache       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
12981870e787SNicolas Vasilache                                               xferOp, operands, vectorDataPtr);
12991870e787SNicolas Vasilache 
13008345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
13018345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
13028345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
13038345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1304060c9dd1Saartbik     //
1305060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1306060c9dd1Saartbik     //       dimensions here.
1307060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1308060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
13090c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
1310b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1311060c9dd1Saartbik     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1312060c9dd1Saartbik                                        vecWidth, dim, &off);
13138345b86dSNicolas Vasilache 
13148345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
13151870e787SNicolas Vasilache     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1316a99f62c4SAlex Zinenko                                        operands, vectorDataPtr, mask);
13178345b86dSNicolas Vasilache   }
1318060c9dd1Saartbik 
1319060c9dd1Saartbik private:
1320060c9dd1Saartbik   const bool enableIndexOptimizations;
13218345b86dSNicolas Vasilache };
13228345b86dSNicolas Vasilache 
1323870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern {
1324d9b500d3SAart Bik public:
1325d9b500d3SAart Bik   explicit VectorPrintOpConversion(MLIRContext *context,
1326d9b500d3SAart Bik                                    LLVMTypeConverter &typeConverter)
1327870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1328d9b500d3SAart Bik                              typeConverter) {}
1329d9b500d3SAart Bik 
1330d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1331d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1332d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1333d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1334d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1335d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1336d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1337d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1338d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1339d9b500d3SAart Bik   //
13409db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1341d9b500d3SAart Bik   //
13423145427dSRiver Riddle   LogicalResult
1343e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1344d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1345d9b500d3SAart Bik     auto printOp = cast<vector::PrintOp>(op);
13462d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1347d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1348d9b500d3SAart Bik 
13490f04384dSAlex Zinenko     if (typeConverter.convertType(printType) == nullptr)
13503145427dSRiver Riddle       return failure();
1351d9b500d3SAart Bik 
1352b8880f5fSAart Bik     // Make sure element type has runtime support.
1353b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1354d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1355d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1356d9b500d3SAart Bik     Operation *printer;
1357b8880f5fSAart Bik     if (eltType.isF32()) {
1358d9b500d3SAart Bik       printer = getPrintFloat(op);
1359b8880f5fSAart Bik     } else if (eltType.isF64()) {
1360d9b500d3SAart Bik       printer = getPrintDouble(op);
136154759cefSAart Bik     } else if (eltType.isIndex()) {
136254759cefSAart Bik       printer = getPrintU64(op);
1363b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1364b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1365b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1366b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1367b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1368b8880f5fSAart Bik       if (intTy.isUnsigned()) {
136954759cefSAart Bik         if (width <= 64) {
1370b8880f5fSAart Bik           if (width < 64)
1371b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1372b8880f5fSAart Bik           printer = getPrintU64(op);
1373b8880f5fSAart Bik         } else {
13743145427dSRiver Riddle           return failure();
1375b8880f5fSAart Bik         }
1376b8880f5fSAart Bik       } else {
1377b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
137854759cefSAart Bik         if (width <= 64) {
1379b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1380b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1381b8880f5fSAart Bik           if (width == 1)
138254759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
138354759cefSAart Bik           else if (width < 64)
1384b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1385b8880f5fSAart Bik           printer = getPrintI64(op);
1386b8880f5fSAart Bik         } else {
1387b8880f5fSAart Bik           return failure();
1388b8880f5fSAart Bik         }
1389b8880f5fSAart Bik       }
1390b8880f5fSAart Bik     } else {
1391b8880f5fSAart Bik       return failure();
1392b8880f5fSAart Bik     }
1393d9b500d3SAart Bik 
1394d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1395b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1396b8880f5fSAart Bik     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1397b8880f5fSAart Bik               conversion);
1398d9b500d3SAart Bik     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1399d9b500d3SAart Bik     rewriter.eraseOp(op);
14003145427dSRiver Riddle     return success();
1401d9b500d3SAart Bik   }
1402d9b500d3SAart Bik 
1403d9b500d3SAart Bik private:
1404b8880f5fSAart Bik   enum class PrintConversion {
140530e6033bSNicolas Vasilache     // clang-format off
1406b8880f5fSAart Bik     None,
1407b8880f5fSAart Bik     ZeroExt64,
1408b8880f5fSAart Bik     SignExt64
140930e6033bSNicolas Vasilache     // clang-format on
1410b8880f5fSAart Bik   };
1411b8880f5fSAart Bik 
1412d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1413e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1414b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1415d9b500d3SAart Bik     Location loc = op->getLoc();
1416d9b500d3SAart Bik     if (rank == 0) {
1417b8880f5fSAart Bik       switch (conversion) {
1418b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1419b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
1420b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1421b8880f5fSAart Bik         break;
1422b8880f5fSAart Bik       case PrintConversion::SignExt64:
1423b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
1424b8880f5fSAart Bik             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1425b8880f5fSAart Bik         break;
1426b8880f5fSAart Bik       case PrintConversion::None:
1427b8880f5fSAart Bik         break;
1428c9eeeb38Saartbik       }
1429d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1430d9b500d3SAart Bik       return;
1431d9b500d3SAart Bik     }
1432d9b500d3SAart Bik 
1433d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1434d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1435d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1436d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1437d9b500d3SAart Bik       auto reducedType =
1438d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
14390f04384dSAlex Zinenko       auto llvmType = typeConverter.convertType(
1440d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1441e62a6956SRiver Riddle       Value nestedVal =
14420f04384dSAlex Zinenko           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1443b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1444b8880f5fSAart Bik                 conversion);
1445d9b500d3SAart Bik       if (d != dim - 1)
1446d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1447d9b500d3SAart Bik     }
1448d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1449d9b500d3SAart Bik   }
1450d9b500d3SAart Bik 
1451d9b500d3SAart Bik   // Helper to emit a call.
1452d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1453d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
145408e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1455d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1456d9b500d3SAart Bik   }
1457d9b500d3SAart Bik 
1458d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
14595446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
14605446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1461d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1462d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1463d9b500d3SAart Bik     if (func)
1464d9b500d3SAart Bik       return func;
1465d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1466d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1467d9b500d3SAart Bik         op->getLoc(), name,
14685446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14695446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14705446ec85SAlex Zinenko             /*isVarArg=*/false));
1471d9b500d3SAart Bik   }
1472d9b500d3SAart Bik 
1473d9b500d3SAart Bik   // Helpers for method names.
1474e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
147554759cefSAart Bik     return getPrint(op, "printI64",
14765446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1477e52414b1Saartbik   }
1478b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
1479b8880f5fSAart Bik     return getPrint(op, "printU64",
1480b8880f5fSAart Bik                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1481b8880f5fSAart Bik   }
1482d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
148354759cefSAart Bik     return getPrint(op, "printF32",
14845446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1485d9b500d3SAart Bik   }
1486d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
148754759cefSAart Bik     return getPrint(op, "printF64",
14885446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1489d9b500d3SAart Bik   }
1490d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
149154759cefSAart Bik     return getPrint(op, "printOpen", {});
1492d9b500d3SAart Bik   }
1493d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
149454759cefSAart Bik     return getPrint(op, "printClose", {});
1495d9b500d3SAart Bik   }
1496d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
149754759cefSAart Bik     return getPrint(op, "printComma", {});
1498d9b500d3SAart Bik   }
1499d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
150054759cefSAart Bik     return getPrint(op, "printNewline", {});
1501d9b500d3SAart Bik   }
1502d9b500d3SAart Bik };
1503d9b500d3SAart Bik 
1504334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1505c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1506c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1507c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1508334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
150965678d93SNicolas Vasilache public:
1510*b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1511*b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1512*b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1513*b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1514*b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1515*b99bd771SRiver Riddle   }
151665678d93SNicolas Vasilache 
1517334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
151865678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
151965678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
152065678d93SNicolas Vasilache 
152165678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
152265678d93SNicolas Vasilache 
152365678d93SNicolas Vasilache     int64_t offset =
152465678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
152565678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
152665678d93SNicolas Vasilache     int64_t stride =
152765678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
152865678d93SNicolas Vasilache 
152965678d93SNicolas Vasilache     auto loc = op.getLoc();
153065678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
153135b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1532c3c95b9cSaartbik 
1533c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1534c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1535c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1536c3c95b9cSaartbik       offsets.reserve(size);
1537c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1538c3c95b9cSaartbik            off += stride)
1539c3c95b9cSaartbik         offsets.push_back(off);
1540c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1541c3c95b9cSaartbik                                              op.vector(),
1542c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1543c3c95b9cSaartbik       return success();
1544c3c95b9cSaartbik     }
1545c3c95b9cSaartbik 
1546c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
154765678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
154865678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
154965678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
155065678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
155165678d93SNicolas Vasilache          off += stride, ++idx) {
1552c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1553c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1554c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
155565678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
155665678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
155765678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
155865678d93SNicolas Vasilache     }
1559c3c95b9cSaartbik     rewriter.replaceOp(op, res);
15603145427dSRiver Riddle     return success();
156165678d93SNicolas Vasilache   }
156265678d93SNicolas Vasilache };
156365678d93SNicolas Vasilache 
1564df186507SBenjamin Kramer } // namespace
1565df186507SBenjamin Kramer 
15665c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15675c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1568ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1569060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
157065678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15718345b86dSNicolas Vasilache   // clang-format off
1572681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1573681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15742d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1575c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1576ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1577ceb1b327Saartbik       ctx, converter, reassociateFPReductions);
1578060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1579060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1580060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1581060c9dd1Saartbik       ctx, converter, enableIndexOptimizations);
15828345b86dSNicolas Vasilache   patterns
1583ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15848345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15858345b86dSNicolas Vasilache               VectorExtractOpConversion,
15868345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15878345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15888345b86dSNicolas Vasilache               VectorInsertOpConversion,
15898345b86dSNicolas Vasilache               VectorPrintOpConversion,
159019dbb230Saartbik               VectorTypeCastOpConversion,
159139379916Saartbik               VectorMaskedLoadOpConversion,
159239379916Saartbik               VectorMaskedStoreOpConversion,
159319dbb230Saartbik               VectorGatherOpConversion,
1594e8dcf5f8Saartbik               VectorScatterOpConversion,
1595e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1596e8dcf5f8Saartbik               VectorCompressStoreOpConversion>(ctx, converter);
15978345b86dSNicolas Vasilache   // clang-format on
15985c0c51a9SNicolas Vasilache }
15995c0c51a9SNicolas Vasilache 
160063b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
160163b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
160263b683a8SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
160363b683a8SNicolas Vasilache   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1604c295a65dSaartbik   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
160563b683a8SNicolas Vasilache }
160663b683a8SNicolas Vasilache 
16075c0c51a9SNicolas Vasilache namespace {
1608722f909fSRiver Riddle struct LowerVectorToLLVMPass
16091834ad4aSRiver Riddle     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
16101bfdf7c7Saartbik   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
16111bfdf7c7Saartbik     this->reassociateFPReductions = options.reassociateFPReductions;
1612060c9dd1Saartbik     this->enableIndexOptimizations = options.enableIndexOptimizations;
16131bfdf7c7Saartbik   }
1614722f909fSRiver Riddle   void runOnOperation() override;
16155c0c51a9SNicolas Vasilache };
16165c0c51a9SNicolas Vasilache } // namespace
16175c0c51a9SNicolas Vasilache 
1618722f909fSRiver Riddle void LowerVectorToLLVMPass::runOnOperation() {
1619078776a6Saartbik   // Perform progressive lowering of operations on slices and
1620b21c7999Saartbik   // all contraction operations. Also applies folding and DCE.
1621459cf6e5Saartbik   {
16225c0c51a9SNicolas Vasilache     OwningRewritePatternList patterns;
1623b1c688dbSaartbik     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1624459cf6e5Saartbik     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1625b21c7999Saartbik     populateVectorContractLoweringPatterns(patterns, &getContext());
1626a5b9316bSUday Bondhugula     applyPatternsAndFoldGreedily(getOperation(), patterns);
1627459cf6e5Saartbik   }
1628459cf6e5Saartbik 
1629459cf6e5Saartbik   // Convert to the LLVM IR dialect.
16305c0c51a9SNicolas Vasilache   LLVMTypeConverter converter(&getContext());
1631459cf6e5Saartbik   OwningRewritePatternList patterns;
163263b683a8SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1633060c9dd1Saartbik   populateVectorToLLVMConversionPatterns(
1634060c9dd1Saartbik       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
1635bbf3ef85SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
16365c0c51a9SNicolas Vasilache   populateStdToLLVMConversionPatterns(converter, patterns);
16375c0c51a9SNicolas Vasilache 
16382a00ae39STim Shen   LLVMConversionTarget target(getContext());
1639060c9dd1Saartbik   if (failed(applyPartialConversion(getOperation(), target, patterns)))
16405c0c51a9SNicolas Vasilache     signalPassFailure();
16415c0c51a9SNicolas Vasilache }
16425c0c51a9SNicolas Vasilache 
16431bfdf7c7Saartbik std::unique_ptr<OperationPass<ModuleOp>>
16441bfdf7c7Saartbik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
16451bfdf7c7Saartbik   return std::make_unique<LowerVectorToLLVMPass>(options);
16465c0c51a9SNicolas Vasilache }
1647