15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
111834ad4aSRiver Riddle #include "../PassDetail.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
135c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
178345b86dSNicolas Vasilache #include "mlir/IR/AffineMap.h"
185c0c51a9SNicolas Vasilache #include "mlir/IR/Attributes.h"
195c0c51a9SNicolas Vasilache #include "mlir/IR/Builders.h"
205c0c51a9SNicolas Vasilache #include "mlir/IR/MLIRContext.h"
215c0c51a9SNicolas Vasilache #include "mlir/IR/Module.h"
225c0c51a9SNicolas Vasilache #include "mlir/IR/Operation.h"
235c0c51a9SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
245c0c51a9SNicolas Vasilache #include "mlir/IR/StandardTypes.h"
255c0c51a9SNicolas Vasilache #include "mlir/IR/Types.h"
26ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
275c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
285c0c51a9SNicolas Vasilache #include "mlir/Transforms/Passes.h"
295c0c51a9SNicolas Vasilache #include "llvm/IR/DerivedTypes.h"
305c0c51a9SNicolas Vasilache #include "llvm/IR/Module.h"
315c0c51a9SNicolas Vasilache #include "llvm/IR/Type.h"
325c0c51a9SNicolas Vasilache #include "llvm/Support/Allocator.h"
335c0c51a9SNicolas Vasilache #include "llvm/Support/ErrorHandling.h"
345c0c51a9SNicolas Vasilache 
355c0c51a9SNicolas Vasilache using namespace mlir;
3665678d93SNicolas Vasilache using namespace mlir::vector;
375c0c51a9SNicolas Vasilache 
389826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
399826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
409826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
419826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
429826fe5cSAart Bik }
439826fe5cSAart Bik 
449826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
459826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
469826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
479826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
489826fe5cSAart Bik }
499826fe5cSAart Bik 
501c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
51e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
520f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
530f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
540f04384dSAlex Zinenko                        int64_t pos) {
551c81adf3SAart Bik   if (rank == 1) {
561c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
571c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
580f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
591c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
601c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
611c81adf3SAart Bik                                                   constant);
621c81adf3SAart Bik   }
631c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
641c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
651c81adf3SAart Bik }
661c81adf3SAart Bik 
672d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
682d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
692d515e49SNicolas Vasilache                        Value into, int64_t offset) {
702d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
712d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
722d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
732d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
742d515e49SNicolas Vasilache       loc, vectorType, from, into,
752d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
762d515e49SNicolas Vasilache }
772d515e49SNicolas Vasilache 
781c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
79e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
800f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
810f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
821c81adf3SAart Bik   if (rank == 1) {
831c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
841c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
850f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
861c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
871c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
881c81adf3SAart Bik                                                    constant);
891c81adf3SAart Bik   }
901c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
911c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
921c81adf3SAart Bik }
931c81adf3SAart Bik 
942d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
952d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
962d515e49SNicolas Vasilache                         int64_t offset) {
972d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
982d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
992d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
1002d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
1012d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
1022d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
1052d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
1069db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
1072d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
1082d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
1092d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
1102d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
1112d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
1122d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
1132d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
1142d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1152d515e49SNicolas Vasilache        it != eit; ++it)
1162d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1172d515e49SNicolas Vasilache   return res;
1182d515e49SNicolas Vasilache }
1192d515e49SNicolas Vasilache 
120*060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
121*060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
122*060c9dd1Saartbik //
123*060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
124*060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
125*060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
126*060c9dd1Saartbik //       is very conservative on the boundary conditions.
127*060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
128*060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
129*060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
130*060c9dd1Saartbik   auto loc = op->getLoc();
131*060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
132*060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
133*060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
134*060c9dd1Saartbik   Value indices;
135*060c9dd1Saartbik   Type idxType;
136*060c9dd1Saartbik   if (enableIndexOptimizations) {
137*060c9dd1Saartbik     SmallVector<int32_t, 4> values(dim);
138*060c9dd1Saartbik     for (int64_t d = 0; d < dim; d++)
139*060c9dd1Saartbik       values[d] = d;
140*060c9dd1Saartbik     indices =
141*060c9dd1Saartbik         rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(values));
142*060c9dd1Saartbik     idxType = rewriter.getI32Type();
143*060c9dd1Saartbik   } else {
144*060c9dd1Saartbik     SmallVector<int64_t, 4> values(dim);
145*060c9dd1Saartbik     for (int64_t d = 0; d < dim; d++)
146*060c9dd1Saartbik       values[d] = d;
147*060c9dd1Saartbik     indices =
148*060c9dd1Saartbik         rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
149*060c9dd1Saartbik     idxType = rewriter.getI64Type();
150*060c9dd1Saartbik   }
151*060c9dd1Saartbik   // Add in an offset if requested.
152*060c9dd1Saartbik   if (off) {
153*060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
154*060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
155*060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
156*060c9dd1Saartbik   }
157*060c9dd1Saartbik   // Construct the vector comparison.
158*060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
159*060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
160*060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
161*060c9dd1Saartbik }
162*060c9dd1Saartbik 
16319dbb230Saartbik // Helper that returns data layout alignment of an operation with memref.
16419dbb230Saartbik template <typename T>
16519dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
16619dbb230Saartbik                                  unsigned &align) {
1675f9e0466SNicolas Vasilache   Type elementTy =
16819dbb230Saartbik       typeConverter.convertType(op.getMemRefType().getElementType());
1695f9e0466SNicolas Vasilache   if (!elementTy)
1705f9e0466SNicolas Vasilache     return failure();
1715f9e0466SNicolas Vasilache 
172b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
173b2ab375dSAlex Zinenko   // stop depending on translation.
17487a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
17587a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
176b2ab375dSAlex Zinenko               .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
177168213f9SAlex Zinenko                                      typeConverter.getDataLayout());
1785f9e0466SNicolas Vasilache   return success();
1795f9e0466SNicolas Vasilache }
1805f9e0466SNicolas Vasilache 
181e8dcf5f8Saartbik // Helper that returns the base address of a memref.
182b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
183e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
18419dbb230Saartbik   // Inspect stride and offset structure.
18519dbb230Saartbik   //
18619dbb230Saartbik   // TODO: flat memory only for now, generalize
18719dbb230Saartbik   //
18819dbb230Saartbik   int64_t offset;
18919dbb230Saartbik   SmallVector<int64_t, 4> strides;
19019dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
19119dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
19219dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
19319dbb230Saartbik     return failure();
194e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
195e8dcf5f8Saartbik   return success();
196e8dcf5f8Saartbik }
19719dbb230Saartbik 
198e8dcf5f8Saartbik // Helper that returns a pointer given a memref base.
199b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
200b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
201b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Value &ptr) {
202e8dcf5f8Saartbik   Value base;
203e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
204e8dcf5f8Saartbik     return failure();
205e8dcf5f8Saartbik   auto pType = MemRefDescriptor(memref).getElementType();
206e8dcf5f8Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
207e8dcf5f8Saartbik   return success();
208e8dcf5f8Saartbik }
209e8dcf5f8Saartbik 
21039379916Saartbik // Helper that returns a bit-casted pointer given a memref base.
211b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
212b98e25b6SBenjamin Kramer                                 Location loc, Value memref,
213b98e25b6SBenjamin Kramer                                 MemRefType memRefType, Type type, Value &ptr) {
21439379916Saartbik   Value base;
21539379916Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
21639379916Saartbik     return failure();
21739379916Saartbik   auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
21839379916Saartbik   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
21939379916Saartbik   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
22039379916Saartbik   return success();
22139379916Saartbik }
22239379916Saartbik 
223e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index
224e8dcf5f8Saartbik // vector.
225b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
226b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
227b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
228b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
229e8dcf5f8Saartbik   Value base;
230e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
231e8dcf5f8Saartbik     return failure();
232e8dcf5f8Saartbik   auto pType = MemRefDescriptor(memref).getElementType();
233e8dcf5f8Saartbik   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
2341485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
23519dbb230Saartbik   return success();
23619dbb230Saartbik }
23719dbb230Saartbik 
2385f9e0466SNicolas Vasilache static LogicalResult
2395f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2405f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2415f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2425f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
243affbc0cdSNicolas Vasilache   unsigned align;
24419dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
245affbc0cdSNicolas Vasilache     return failure();
246affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2475f9e0466SNicolas Vasilache   return success();
2485f9e0466SNicolas Vasilache }
2495f9e0466SNicolas Vasilache 
2505f9e0466SNicolas Vasilache static LogicalResult
2515f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2525f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2535f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2545f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2555f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2565f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2575f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2585f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2595f9e0466SNicolas Vasilache 
2605f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2615f9e0466SNicolas Vasilache   if (!vecTy)
2625f9e0466SNicolas Vasilache     return failure();
2635f9e0466SNicolas Vasilache 
2645f9e0466SNicolas Vasilache   unsigned align;
26519dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2665f9e0466SNicolas Vasilache     return failure();
2675f9e0466SNicolas Vasilache 
2685f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2695f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2705f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2715f9e0466SNicolas Vasilache   return success();
2725f9e0466SNicolas Vasilache }
2735f9e0466SNicolas Vasilache 
2745f9e0466SNicolas Vasilache static LogicalResult
2755f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2765f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2775f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2785f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
279affbc0cdSNicolas Vasilache   unsigned align;
28019dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
281affbc0cdSNicolas Vasilache     return failure();
2822d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
283affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
284affbc0cdSNicolas Vasilache                                              align);
2855f9e0466SNicolas Vasilache   return success();
2865f9e0466SNicolas Vasilache }
2875f9e0466SNicolas Vasilache 
2885f9e0466SNicolas Vasilache static LogicalResult
2895f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2905f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2915f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2925f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2935f9e0466SNicolas Vasilache   unsigned align;
29419dbb230Saartbik   if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
2955f9e0466SNicolas Vasilache     return failure();
2965f9e0466SNicolas Vasilache 
2972d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2985f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2995f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
3005f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
3015f9e0466SNicolas Vasilache   return success();
3025f9e0466SNicolas Vasilache }
3035f9e0466SNicolas Vasilache 
3042d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
3052d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
3062d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
3075f9e0466SNicolas Vasilache }
3085f9e0466SNicolas Vasilache 
3092d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
3102d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
3112d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
3125f9e0466SNicolas Vasilache }
3135f9e0466SNicolas Vasilache 
31490c01357SBenjamin Kramer namespace {
315e83b7b99Saartbik 
31663b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
31763b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
31863b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern {
31963b683a8SNicolas Vasilache public:
32063b683a8SNicolas Vasilache   explicit VectorMatmulOpConversion(MLIRContext *context,
32163b683a8SNicolas Vasilache                                     LLVMTypeConverter &typeConverter)
32263b683a8SNicolas Vasilache       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
32363b683a8SNicolas Vasilache                              typeConverter) {}
32463b683a8SNicolas Vasilache 
3253145427dSRiver Riddle   LogicalResult
32663b683a8SNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
32763b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
32863b683a8SNicolas Vasilache     auto matmulOp = cast<vector::MatmulOp>(op);
3292d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
33063b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
33163b683a8SNicolas Vasilache         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
33263b683a8SNicolas Vasilache         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
33363b683a8SNicolas Vasilache         matmulOp.rhs_columns());
3343145427dSRiver Riddle     return success();
33563b683a8SNicolas Vasilache   }
33663b683a8SNicolas Vasilache };
33763b683a8SNicolas Vasilache 
338c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
339c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
340c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
341c295a65dSaartbik public:
342c295a65dSaartbik   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
343c295a65dSaartbik                                            LLVMTypeConverter &typeConverter)
344c295a65dSaartbik       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
345c295a65dSaartbik                              context, typeConverter) {}
346c295a65dSaartbik 
347c295a65dSaartbik   LogicalResult
348c295a65dSaartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
349c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
350c295a65dSaartbik     auto transOp = cast<vector::FlatTransposeOp>(op);
3512d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
352c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
353c295a65dSaartbik         transOp, typeConverter.convertType(transOp.res().getType()),
354c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
355c295a65dSaartbik     return success();
356c295a65dSaartbik   }
357c295a65dSaartbik };
358c295a65dSaartbik 
35939379916Saartbik /// Conversion pattern for a vector.maskedload.
36039379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
36139379916Saartbik public:
36239379916Saartbik   explicit VectorMaskedLoadOpConversion(MLIRContext *context,
36339379916Saartbik                                         LLVMTypeConverter &typeConverter)
36439379916Saartbik       : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
36539379916Saartbik                              typeConverter) {}
36639379916Saartbik 
36739379916Saartbik   LogicalResult
36839379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
36939379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
37039379916Saartbik     auto loc = op->getLoc();
37139379916Saartbik     auto load = cast<vector::MaskedLoadOp>(op);
37239379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
37339379916Saartbik 
37439379916Saartbik     // Resolve alignment.
37539379916Saartbik     unsigned align;
37639379916Saartbik     if (failed(getMemRefAlignment(typeConverter, load, align)))
37739379916Saartbik       return failure();
37839379916Saartbik 
37939379916Saartbik     auto vtype = typeConverter.convertType(load.getResultVectorType());
38039379916Saartbik     Value ptr;
38139379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
38239379916Saartbik                           vtype, ptr)))
38339379916Saartbik       return failure();
38439379916Saartbik 
38539379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
38639379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
38739379916Saartbik         rewriter.getI32IntegerAttr(align));
38839379916Saartbik     return success();
38939379916Saartbik   }
39039379916Saartbik };
39139379916Saartbik 
39239379916Saartbik /// Conversion pattern for a vector.maskedstore.
39339379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
39439379916Saartbik public:
39539379916Saartbik   explicit VectorMaskedStoreOpConversion(MLIRContext *context,
39639379916Saartbik                                          LLVMTypeConverter &typeConverter)
39739379916Saartbik       : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
39839379916Saartbik                              typeConverter) {}
39939379916Saartbik 
40039379916Saartbik   LogicalResult
40139379916Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40239379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
40339379916Saartbik     auto loc = op->getLoc();
40439379916Saartbik     auto store = cast<vector::MaskedStoreOp>(op);
40539379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
40639379916Saartbik 
40739379916Saartbik     // Resolve alignment.
40839379916Saartbik     unsigned align;
40939379916Saartbik     if (failed(getMemRefAlignment(typeConverter, store, align)))
41039379916Saartbik       return failure();
41139379916Saartbik 
41239379916Saartbik     auto vtype = typeConverter.convertType(store.getValueVectorType());
41339379916Saartbik     Value ptr;
41439379916Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
41539379916Saartbik                           vtype, ptr)))
41639379916Saartbik       return failure();
41739379916Saartbik 
41839379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
41939379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
42039379916Saartbik         rewriter.getI32IntegerAttr(align));
42139379916Saartbik     return success();
42239379916Saartbik   }
42339379916Saartbik };
42439379916Saartbik 
42519dbb230Saartbik /// Conversion pattern for a vector.gather.
42619dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern {
42719dbb230Saartbik public:
42819dbb230Saartbik   explicit VectorGatherOpConversion(MLIRContext *context,
42919dbb230Saartbik                                     LLVMTypeConverter &typeConverter)
43019dbb230Saartbik       : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
43119dbb230Saartbik                              typeConverter) {}
43219dbb230Saartbik 
43319dbb230Saartbik   LogicalResult
43419dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
43519dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
43619dbb230Saartbik     auto loc = op->getLoc();
43719dbb230Saartbik     auto gather = cast<vector::GatherOp>(op);
43819dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
43919dbb230Saartbik 
44019dbb230Saartbik     // Resolve alignment.
44119dbb230Saartbik     unsigned align;
44219dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, gather, align)))
44319dbb230Saartbik       return failure();
44419dbb230Saartbik 
44519dbb230Saartbik     // Get index ptrs.
44619dbb230Saartbik     VectorType vType = gather.getResultVectorType();
44719dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
44819dbb230Saartbik     Value ptrs;
449e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
450e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
45119dbb230Saartbik       return failure();
45219dbb230Saartbik 
45319dbb230Saartbik     // Replace with the gather intrinsic.
45419dbb230Saartbik     ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
45519dbb230Saartbik                                                           : adaptor.pass_thru();
45619dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
45719dbb230Saartbik         gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
45819dbb230Saartbik         rewriter.getI32IntegerAttr(align));
45919dbb230Saartbik     return success();
46019dbb230Saartbik   }
46119dbb230Saartbik };
46219dbb230Saartbik 
46319dbb230Saartbik /// Conversion pattern for a vector.scatter.
46419dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern {
46519dbb230Saartbik public:
46619dbb230Saartbik   explicit VectorScatterOpConversion(MLIRContext *context,
46719dbb230Saartbik                                      LLVMTypeConverter &typeConverter)
46819dbb230Saartbik       : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
46919dbb230Saartbik                              typeConverter) {}
47019dbb230Saartbik 
47119dbb230Saartbik   LogicalResult
47219dbb230Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
47319dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
47419dbb230Saartbik     auto loc = op->getLoc();
47519dbb230Saartbik     auto scatter = cast<vector::ScatterOp>(op);
47619dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
47719dbb230Saartbik 
47819dbb230Saartbik     // Resolve alignment.
47919dbb230Saartbik     unsigned align;
48019dbb230Saartbik     if (failed(getMemRefAlignment(typeConverter, scatter, align)))
48119dbb230Saartbik       return failure();
48219dbb230Saartbik 
48319dbb230Saartbik     // Get index ptrs.
48419dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
48519dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
48619dbb230Saartbik     Value ptrs;
487e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
488e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
48919dbb230Saartbik       return failure();
49019dbb230Saartbik 
49119dbb230Saartbik     // Replace with the scatter intrinsic.
49219dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
49319dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
49419dbb230Saartbik         rewriter.getI32IntegerAttr(align));
49519dbb230Saartbik     return success();
49619dbb230Saartbik   }
49719dbb230Saartbik };
49819dbb230Saartbik 
499e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
500e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
501e8dcf5f8Saartbik public:
502e8dcf5f8Saartbik   explicit VectorExpandLoadOpConversion(MLIRContext *context,
503e8dcf5f8Saartbik                                         LLVMTypeConverter &typeConverter)
504e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
505e8dcf5f8Saartbik                              typeConverter) {}
506e8dcf5f8Saartbik 
507e8dcf5f8Saartbik   LogicalResult
508e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
509e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
510e8dcf5f8Saartbik     auto loc = op->getLoc();
511e8dcf5f8Saartbik     auto expand = cast<vector::ExpandLoadOp>(op);
512e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
513e8dcf5f8Saartbik 
514e8dcf5f8Saartbik     Value ptr;
515e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
516e8dcf5f8Saartbik                           ptr)))
517e8dcf5f8Saartbik       return failure();
518e8dcf5f8Saartbik 
519e8dcf5f8Saartbik     auto vType = expand.getResultVectorType();
520e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
521e8dcf5f8Saartbik         op, typeConverter.convertType(vType), ptr, adaptor.mask(),
522e8dcf5f8Saartbik         adaptor.pass_thru());
523e8dcf5f8Saartbik     return success();
524e8dcf5f8Saartbik   }
525e8dcf5f8Saartbik };
526e8dcf5f8Saartbik 
527e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
528e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
529e8dcf5f8Saartbik public:
530e8dcf5f8Saartbik   explicit VectorCompressStoreOpConversion(MLIRContext *context,
531e8dcf5f8Saartbik                                            LLVMTypeConverter &typeConverter)
532e8dcf5f8Saartbik       : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
533e8dcf5f8Saartbik                              context, typeConverter) {}
534e8dcf5f8Saartbik 
535e8dcf5f8Saartbik   LogicalResult
536e8dcf5f8Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
537e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
538e8dcf5f8Saartbik     auto loc = op->getLoc();
539e8dcf5f8Saartbik     auto compress = cast<vector::CompressStoreOp>(op);
540e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
541e8dcf5f8Saartbik 
542e8dcf5f8Saartbik     Value ptr;
543e8dcf5f8Saartbik     if (failed(getBasePtr(rewriter, loc, adaptor.base(),
544e8dcf5f8Saartbik                           compress.getMemRefType(), ptr)))
545e8dcf5f8Saartbik       return failure();
546e8dcf5f8Saartbik 
547e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
548e8dcf5f8Saartbik         op, adaptor.value(), ptr, adaptor.mask());
549e8dcf5f8Saartbik     return success();
550e8dcf5f8Saartbik   }
551e8dcf5f8Saartbik };
552e8dcf5f8Saartbik 
55319dbb230Saartbik /// Conversion pattern for all vector reductions.
554870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern {
555e83b7b99Saartbik public:
556e83b7b99Saartbik   explicit VectorReductionOpConversion(MLIRContext *context,
557ceb1b327Saartbik                                        LLVMTypeConverter &typeConverter,
558*060c9dd1Saartbik                                        bool reassociateFPRed)
559870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
560ceb1b327Saartbik                              typeConverter),
561*060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
562e83b7b99Saartbik 
5633145427dSRiver Riddle   LogicalResult
564e83b7b99Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
565e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
566e83b7b99Saartbik     auto reductionOp = cast<vector::ReductionOp>(op);
567e83b7b99Saartbik     auto kind = reductionOp.kind();
568e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
5690f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(eltType);
57035b68527SLei Zhang     if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
571e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
572e83b7b99Saartbik       if (kind == "add")
573e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
574e83b7b99Saartbik             op, llvmType, operands[0]);
575e83b7b99Saartbik       else if (kind == "mul")
576e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
577e83b7b99Saartbik             op, llvmType, operands[0]);
578e83b7b99Saartbik       else if (kind == "min")
579e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
580e83b7b99Saartbik             op, llvmType, operands[0]);
581e83b7b99Saartbik       else if (kind == "max")
582e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
583e83b7b99Saartbik             op, llvmType, operands[0]);
584e83b7b99Saartbik       else if (kind == "and")
585e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
586e83b7b99Saartbik             op, llvmType, operands[0]);
587e83b7b99Saartbik       else if (kind == "or")
588e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
589e83b7b99Saartbik             op, llvmType, operands[0]);
590e83b7b99Saartbik       else if (kind == "xor")
591e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
592e83b7b99Saartbik             op, llvmType, operands[0]);
593e83b7b99Saartbik       else
5943145427dSRiver Riddle         return failure();
5953145427dSRiver Riddle       return success();
596e83b7b99Saartbik 
597e83b7b99Saartbik     } else if (eltType.isF32() || eltType.isF64()) {
598e83b7b99Saartbik       // Floating-point reductions: add/mul/min/max
599e83b7b99Saartbik       if (kind == "add") {
6000d924700Saartbik         // Optional accumulator (or zero).
6010d924700Saartbik         Value acc = operands.size() > 1 ? operands[1]
6020d924700Saartbik                                         : rewriter.create<LLVM::ConstantOp>(
6030d924700Saartbik                                               op->getLoc(), llvmType,
6040d924700Saartbik                                               rewriter.getZeroAttr(eltType));
605e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
606ceb1b327Saartbik             op, llvmType, acc, operands[0],
607ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
608e83b7b99Saartbik       } else if (kind == "mul") {
6090d924700Saartbik         // Optional accumulator (or one).
6100d924700Saartbik         Value acc = operands.size() > 1
6110d924700Saartbik                         ? operands[1]
6120d924700Saartbik                         : rewriter.create<LLVM::ConstantOp>(
6130d924700Saartbik                               op->getLoc(), llvmType,
6140d924700Saartbik                               rewriter.getFloatAttr(eltType, 1.0));
615e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
616ceb1b327Saartbik             op, llvmType, acc, operands[0],
617ceb1b327Saartbik             rewriter.getBoolAttr(reassociateFPReductions));
618e83b7b99Saartbik       } else if (kind == "min")
619e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
620e83b7b99Saartbik             op, llvmType, operands[0]);
621e83b7b99Saartbik       else if (kind == "max")
622e83b7b99Saartbik         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
623e83b7b99Saartbik             op, llvmType, operands[0]);
624e83b7b99Saartbik       else
6253145427dSRiver Riddle         return failure();
6263145427dSRiver Riddle       return success();
627e83b7b99Saartbik     }
6283145427dSRiver Riddle     return failure();
629e83b7b99Saartbik   }
630ceb1b327Saartbik 
631ceb1b327Saartbik private:
632ceb1b327Saartbik   const bool reassociateFPReductions;
633e83b7b99Saartbik };
634e83b7b99Saartbik 
635*060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
636*060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
637*060c9dd1Saartbik public:
638*060c9dd1Saartbik   explicit VectorCreateMaskOpConversion(MLIRContext *context,
639*060c9dd1Saartbik                                         LLVMTypeConverter &typeConverter,
640*060c9dd1Saartbik                                         bool enableIndexOpt)
641*060c9dd1Saartbik       : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
642*060c9dd1Saartbik                              typeConverter),
643*060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
644*060c9dd1Saartbik 
645*060c9dd1Saartbik   LogicalResult
646*060c9dd1Saartbik   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
647*060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
648*060c9dd1Saartbik     auto dstType = op->getResult(0).getType().cast<VectorType>();
649*060c9dd1Saartbik     int64_t rank = dstType.getRank();
650*060c9dd1Saartbik     if (rank == 1) {
651*060c9dd1Saartbik       rewriter.replaceOp(
652*060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
653*060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
654*060c9dd1Saartbik       return success();
655*060c9dd1Saartbik     }
656*060c9dd1Saartbik     return failure();
657*060c9dd1Saartbik   }
658*060c9dd1Saartbik 
659*060c9dd1Saartbik private:
660*060c9dd1Saartbik   const bool enableIndexOptimizations;
661*060c9dd1Saartbik };
662*060c9dd1Saartbik 
663870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern {
6641c81adf3SAart Bik public:
6651c81adf3SAart Bik   explicit VectorShuffleOpConversion(MLIRContext *context,
6661c81adf3SAart Bik                                      LLVMTypeConverter &typeConverter)
667870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
6681c81adf3SAart Bik                              typeConverter) {}
6691c81adf3SAart Bik 
6703145427dSRiver Riddle   LogicalResult
671e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
6721c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
6731c81adf3SAart Bik     auto loc = op->getLoc();
6742d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6751c81adf3SAart Bik     auto shuffleOp = cast<vector::ShuffleOp>(op);
6761c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6771c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6781c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
6790f04384dSAlex Zinenko     Type llvmType = typeConverter.convertType(vectorType);
6801c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6811c81adf3SAart Bik 
6821c81adf3SAart Bik     // Bail if result type cannot be lowered.
6831c81adf3SAart Bik     if (!llvmType)
6843145427dSRiver Riddle       return failure();
6851c81adf3SAart Bik 
6861c81adf3SAart Bik     // Get rank and dimension sizes.
6871c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6881c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6891c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6901c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6911c81adf3SAart Bik 
6921c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6931c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6941c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
695e62a6956SRiver Riddle       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
6961c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
6971c81adf3SAart Bik       rewriter.replaceOp(op, shuffle);
6983145427dSRiver Riddle       return success();
699b36aaeafSAart Bik     }
700b36aaeafSAart Bik 
7011c81adf3SAart Bik     // For all other cases, insert the individual values individually.
702e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
7031c81adf3SAart Bik     int64_t insPos = 0;
7041c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
7051c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
706e62a6956SRiver Riddle       Value value = adaptor.v1();
7071c81adf3SAart Bik       if (extPos >= v1Dim) {
7081c81adf3SAart Bik         extPos -= v1Dim;
7091c81adf3SAart Bik         value = adaptor.v2();
710b36aaeafSAart Bik       }
7110f04384dSAlex Zinenko       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
7120f04384dSAlex Zinenko                                  rank, extPos);
7130f04384dSAlex Zinenko       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
7140f04384dSAlex Zinenko                          llvmType, rank, insPos++);
7151c81adf3SAart Bik     }
7161c81adf3SAart Bik     rewriter.replaceOp(op, insert);
7173145427dSRiver Riddle     return success();
718b36aaeafSAart Bik   }
719b36aaeafSAart Bik };
720b36aaeafSAart Bik 
721870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
722cd5dab8aSAart Bik public:
723cd5dab8aSAart Bik   explicit VectorExtractElementOpConversion(MLIRContext *context,
724cd5dab8aSAart Bik                                             LLVMTypeConverter &typeConverter)
725870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
726870c1fd4SAlex Zinenko                              context, typeConverter) {}
727cd5dab8aSAart Bik 
7283145427dSRiver Riddle   LogicalResult
729e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
730cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
7312d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
732cd5dab8aSAart Bik     auto extractEltOp = cast<vector::ExtractElementOp>(op);
733cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
7340f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType.getElementType());
735cd5dab8aSAart Bik 
736cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
737cd5dab8aSAart Bik     if (!llvmType)
7383145427dSRiver Riddle       return failure();
739cd5dab8aSAart Bik 
740cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
741cd5dab8aSAart Bik         op, llvmType, adaptor.vector(), adaptor.position());
7423145427dSRiver Riddle     return success();
743cd5dab8aSAart Bik   }
744cd5dab8aSAart Bik };
745cd5dab8aSAart Bik 
746870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern {
7475c0c51a9SNicolas Vasilache public:
7489826fe5cSAart Bik   explicit VectorExtractOpConversion(MLIRContext *context,
7495c0c51a9SNicolas Vasilache                                      LLVMTypeConverter &typeConverter)
750870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
7515c0c51a9SNicolas Vasilache                              typeConverter) {}
7525c0c51a9SNicolas Vasilache 
7533145427dSRiver Riddle   LogicalResult
754e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7555c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7565c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
7572d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
758d37f2725SAart Bik     auto extractOp = cast<vector::ExtractOp>(op);
7599826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7602bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
7610f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(resultType);
7625c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7639826fe5cSAart Bik 
7649826fe5cSAart Bik     // Bail if result type cannot be lowered.
7659826fe5cSAart Bik     if (!llvmResultType)
7663145427dSRiver Riddle       return failure();
7679826fe5cSAart Bik 
7685c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7695c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
770e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7715c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
7725c0c51a9SNicolas Vasilache       rewriter.replaceOp(op, extracted);
7733145427dSRiver Riddle       return success();
7745c0c51a9SNicolas Vasilache     }
7755c0c51a9SNicolas Vasilache 
7769826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
7775c0c51a9SNicolas Vasilache     auto *context = op->getContext();
778e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7795c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7805c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7819826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7825c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7835c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7845c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
7850f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
7865c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7875c0c51a9SNicolas Vasilache     }
7885c0c51a9SNicolas Vasilache 
7895c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7905c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7915446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
7921d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7935c0c51a9SNicolas Vasilache     extracted =
7945c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
7955c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, extracted);
7965c0c51a9SNicolas Vasilache 
7973145427dSRiver Riddle     return success();
7985c0c51a9SNicolas Vasilache   }
7995c0c51a9SNicolas Vasilache };
8005c0c51a9SNicolas Vasilache 
801681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
802681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
803681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
804681f929fSNicolas Vasilache ///
805681f929fSNicolas Vasilache /// Example:
806681f929fSNicolas Vasilache /// ```
807681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
808681f929fSNicolas Vasilache /// ```
809681f929fSNicolas Vasilache /// is converted to:
810681f929fSNicolas Vasilache /// ```
8113bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
812681f929fSNicolas Vasilache ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
813681f929fSNicolas Vasilache ///    -> !llvm<"<8 x float>">
814681f929fSNicolas Vasilache /// ```
815870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
816681f929fSNicolas Vasilache public:
817681f929fSNicolas Vasilache   explicit VectorFMAOp1DConversion(MLIRContext *context,
818681f929fSNicolas Vasilache                                    LLVMTypeConverter &typeConverter)
819870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
820681f929fSNicolas Vasilache                              typeConverter) {}
821681f929fSNicolas Vasilache 
8223145427dSRiver Riddle   LogicalResult
823681f929fSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
824681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
8252d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
826681f929fSNicolas Vasilache     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
827681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
828681f929fSNicolas Vasilache     if (vType.getRank() != 1)
8293145427dSRiver Riddle       return failure();
8303bffe602SBenjamin Kramer     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
8313bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
8323145427dSRiver Riddle     return success();
833681f929fSNicolas Vasilache   }
834681f929fSNicolas Vasilache };
835681f929fSNicolas Vasilache 
836870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
837cd5dab8aSAart Bik public:
838cd5dab8aSAart Bik   explicit VectorInsertElementOpConversion(MLIRContext *context,
839cd5dab8aSAart Bik                                            LLVMTypeConverter &typeConverter)
840870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
841870c1fd4SAlex Zinenko                              context, typeConverter) {}
842cd5dab8aSAart Bik 
8433145427dSRiver Riddle   LogicalResult
844e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
845cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8462d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
847cd5dab8aSAart Bik     auto insertEltOp = cast<vector::InsertElementOp>(op);
848cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
8490f04384dSAlex Zinenko     auto llvmType = typeConverter.convertType(vectorType);
850cd5dab8aSAart Bik 
851cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
852cd5dab8aSAart Bik     if (!llvmType)
8533145427dSRiver Riddle       return failure();
854cd5dab8aSAart Bik 
855cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
856cd5dab8aSAart Bik         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
8573145427dSRiver Riddle     return success();
858cd5dab8aSAart Bik   }
859cd5dab8aSAart Bik };
860cd5dab8aSAart Bik 
861870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern {
8629826fe5cSAart Bik public:
8639826fe5cSAart Bik   explicit VectorInsertOpConversion(MLIRContext *context,
8649826fe5cSAart Bik                                     LLVMTypeConverter &typeConverter)
865870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
8669826fe5cSAart Bik                              typeConverter) {}
8679826fe5cSAart Bik 
8683145427dSRiver Riddle   LogicalResult
869e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
8709826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8719826fe5cSAart Bik     auto loc = op->getLoc();
8722d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8739826fe5cSAart Bik     auto insertOp = cast<vector::InsertOp>(op);
8749826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8759826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
8760f04384dSAlex Zinenko     auto llvmResultType = typeConverter.convertType(destVectorType);
8779826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8789826fe5cSAart Bik 
8799826fe5cSAart Bik     // Bail if result type cannot be lowered.
8809826fe5cSAart Bik     if (!llvmResultType)
8813145427dSRiver Riddle       return failure();
8829826fe5cSAart Bik 
8839826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8849826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
885e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8869826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8879826fe5cSAart Bik           positionArrayAttr);
8889826fe5cSAart Bik       rewriter.replaceOp(op, inserted);
8893145427dSRiver Riddle       return success();
8909826fe5cSAart Bik     }
8919826fe5cSAart Bik 
8929826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
8939826fe5cSAart Bik     auto *context = op->getContext();
894e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8959826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8969826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8979826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8989826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8999826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
9009826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9019826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9029826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
9030f04384dSAlex Zinenko           loc, typeConverter.convertType(oneDVectorType), extracted,
9049826fe5cSAart Bik           nMinusOnePositionAttrs);
9059826fe5cSAart Bik     }
9069826fe5cSAart Bik 
9079826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
9085446ec85SAlex Zinenko     auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
9091d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
910e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
9110f04384dSAlex Zinenko         loc, typeConverter.convertType(oneDVectorType), extracted,
9120f04384dSAlex Zinenko         adaptor.source(), constant);
9139826fe5cSAart Bik 
9149826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
9159826fe5cSAart Bik     if (positionAttrs.size() > 1) {
9169826fe5cSAart Bik       auto nMinusOnePositionAttrs =
9179826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
9189826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
9199826fe5cSAart Bik                                                       adaptor.dest(), inserted,
9209826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
9219826fe5cSAart Bik     }
9229826fe5cSAart Bik 
9239826fe5cSAart Bik     rewriter.replaceOp(op, inserted);
9243145427dSRiver Riddle     return success();
9259826fe5cSAart Bik   }
9269826fe5cSAart Bik };
9279826fe5cSAart Bik 
928681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
929681f929fSNicolas Vasilache ///
930681f929fSNicolas Vasilache /// Example:
931681f929fSNicolas Vasilache /// ```
932681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
933681f929fSNicolas Vasilache /// ```
934681f929fSNicolas Vasilache /// is rewritten into:
935681f929fSNicolas Vasilache /// ```
936681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
937681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
938681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
939681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
940681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
941681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
942681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
943681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
944681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
945681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
946681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
947681f929fSNicolas Vasilache ///  // %r3 holds the final value.
948681f929fSNicolas Vasilache /// ```
949681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
950681f929fSNicolas Vasilache public:
951681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
952681f929fSNicolas Vasilache 
9533145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
954681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
955681f929fSNicolas Vasilache     auto vType = op.getVectorType();
956681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9573145427dSRiver Riddle       return failure();
958681f929fSNicolas Vasilache 
959681f929fSNicolas Vasilache     auto loc = op.getLoc();
960681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
961681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
962681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
963681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
964681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
965681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
966681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
967681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
968681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
969681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
970681f929fSNicolas Vasilache     }
971681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9723145427dSRiver Riddle     return success();
973681f929fSNicolas Vasilache   }
974681f929fSNicolas Vasilache };
975681f929fSNicolas Vasilache 
9762d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9772d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9782d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9792d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9802d515e49SNicolas Vasilache // rank.
9812d515e49SNicolas Vasilache //
9822d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9832d515e49SNicolas Vasilache // have different ranks. In this case:
9842d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9852d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9862d515e49SNicolas Vasilache //   destination subvector
9872d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9882d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9892d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9902d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9912d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9922d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9932d515e49SNicolas Vasilache public:
9942d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9952d515e49SNicolas Vasilache 
9963145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9972d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9982d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9992d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10002d515e49SNicolas Vasilache 
10012d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10023145427dSRiver Riddle       return failure();
10032d515e49SNicolas Vasilache 
10042d515e49SNicolas Vasilache     auto loc = op.getLoc();
10052d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10062d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10072d515e49SNicolas Vasilache     if (rankDiff == 0)
10083145427dSRiver Riddle       return failure();
10092d515e49SNicolas Vasilache 
10102d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
10112d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
10122d515e49SNicolas Vasilache     // on it.
10132d515e49SNicolas Vasilache     Value extracted =
10142d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
10152d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
10162d515e49SNicolas Vasilache                                                   /*dropFront=*/rankRest));
10172d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
10182d515e49SNicolas Vasilache     // ranks.
10192d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
10202d515e49SNicolas Vasilache         loc, op.source(), extracted,
10212d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
1022c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
10232d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
10242d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
10252d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
10262d515e49SNicolas Vasilache                        /*dropFront=*/rankRest));
10273145427dSRiver Riddle     return success();
10282d515e49SNicolas Vasilache   }
10292d515e49SNicolas Vasilache };
10302d515e49SNicolas Vasilache 
10312d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
10322d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
10332d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
10342d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
10352d515e49SNicolas Vasilache //   destination subvector
10362d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
10372d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
10382d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
10392d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
10402d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
10412d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
10422d515e49SNicolas Vasilache public:
10432d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
10442d515e49SNicolas Vasilache 
10453145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10462d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10472d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10482d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10492d515e49SNicolas Vasilache 
10502d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10513145427dSRiver Riddle       return failure();
10522d515e49SNicolas Vasilache 
10532d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10542d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10552d515e49SNicolas Vasilache     if (rankDiff != 0)
10563145427dSRiver Riddle       return failure();
10572d515e49SNicolas Vasilache 
10582d515e49SNicolas Vasilache     if (srcType == dstType) {
10592d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10603145427dSRiver Riddle       return success();
10612d515e49SNicolas Vasilache     }
10622d515e49SNicolas Vasilache 
10632d515e49SNicolas Vasilache     int64_t offset =
10642d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10652d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10662d515e49SNicolas Vasilache     int64_t stride =
10672d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10682d515e49SNicolas Vasilache 
10692d515e49SNicolas Vasilache     auto loc = op.getLoc();
10702d515e49SNicolas Vasilache     Value res = op.dest();
10712d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10722d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10732d515e49SNicolas Vasilache          off += stride, ++idx) {
10742d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10752d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10762d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10772d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10782d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10792d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10802d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10812d515e49SNicolas Vasilache         // smaller rank.
1082bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10832d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10842d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10852d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10862d515e49SNicolas Vasilache       }
10872d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10882d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10892d515e49SNicolas Vasilache     }
10902d515e49SNicolas Vasilache 
10912d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10923145427dSRiver Riddle     return success();
10932d515e49SNicolas Vasilache   }
1094bd1ccfe6SRiver Riddle   /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
1095bd1ccfe6SRiver Riddle   /// bounded as the rank is strictly decreasing.
1096bd1ccfe6SRiver Riddle   bool hasBoundedRewriteRecursion() const final { return true; }
10972d515e49SNicolas Vasilache };
10982d515e49SNicolas Vasilache 
10992bf491c7SBenjamin Kramer /// Returns true if the memory underlying `memRefType` has a contiguous layout.
11002bf491c7SBenjamin Kramer /// Strides are written to `strides`.
11012bf491c7SBenjamin Kramer static bool isContiguous(MemRefType memRefType,
11022bf491c7SBenjamin Kramer                          SmallVectorImpl<int64_t> &strides) {
11032bf491c7SBenjamin Kramer   int64_t offset;
11042bf491c7SBenjamin Kramer   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
11052bf491c7SBenjamin Kramer   bool isContiguous = (strides.back() == 1);
11062bf491c7SBenjamin Kramer   if (isContiguous) {
11072bf491c7SBenjamin Kramer     auto sizes = memRefType.getShape();
11082bf491c7SBenjamin Kramer     for (int index = 0, e = strides.size() - 2; index < e; ++index) {
11092bf491c7SBenjamin Kramer       if (strides[index] != strides[index + 1] * sizes[index + 1]) {
11102bf491c7SBenjamin Kramer         isContiguous = false;
11112bf491c7SBenjamin Kramer         break;
11122bf491c7SBenjamin Kramer       }
11132bf491c7SBenjamin Kramer     }
11142bf491c7SBenjamin Kramer   }
11152bf491c7SBenjamin Kramer   return succeeded(successStrides) && isContiguous;
11162bf491c7SBenjamin Kramer }
11172bf491c7SBenjamin Kramer 
1118870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
11195c0c51a9SNicolas Vasilache public:
11205c0c51a9SNicolas Vasilache   explicit VectorTypeCastOpConversion(MLIRContext *context,
11215c0c51a9SNicolas Vasilache                                       LLVMTypeConverter &typeConverter)
1122870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
11235c0c51a9SNicolas Vasilache                              typeConverter) {}
11245c0c51a9SNicolas Vasilache 
11253145427dSRiver Riddle   LogicalResult
1126e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
11275c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11285c0c51a9SNicolas Vasilache     auto loc = op->getLoc();
11295c0c51a9SNicolas Vasilache     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
11305c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
11312bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
11325c0c51a9SNicolas Vasilache     MemRefType targetMemRefType =
11332bdf33ccSRiver Riddle         castOp.getResult().getType().cast<MemRefType>();
11345c0c51a9SNicolas Vasilache 
11355c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11365c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11375c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11383145427dSRiver Riddle       return failure();
11395c0c51a9SNicolas Vasilache 
11405c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11412bdf33ccSRiver Riddle         operands[0].getType().dyn_cast<LLVM::LLVMType>();
11425c0c51a9SNicolas Vasilache     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
11433145427dSRiver Riddle       return failure();
11445c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11455c0c51a9SNicolas Vasilache 
11460f04384dSAlex Zinenko     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
11475c0c51a9SNicolas Vasilache                                       .dyn_cast_or_null<LLVM::LLVMType>();
11485c0c51a9SNicolas Vasilache     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
11493145427dSRiver Riddle       return failure();
11505c0c51a9SNicolas Vasilache 
11515c0c51a9SNicolas Vasilache     // Only contiguous source tensors supported atm.
11522bf491c7SBenjamin Kramer     SmallVector<int64_t, 4> strides;
11532bf491c7SBenjamin Kramer     if (!isContiguous(sourceMemRefType, strides))
11543145427dSRiver Riddle       return failure();
11555c0c51a9SNicolas Vasilache 
11565446ec85SAlex Zinenko     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
11575c0c51a9SNicolas Vasilache 
11585c0c51a9SNicolas Vasilache     // Create descriptor.
11595c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11605c0c51a9SNicolas Vasilache     Type llvmTargetElementTy = desc.getElementType();
11615c0c51a9SNicolas Vasilache     // Set allocated ptr.
1162e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11635c0c51a9SNicolas Vasilache     allocated =
11645c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11655c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11665c0c51a9SNicolas Vasilache     // Set aligned ptr.
1167e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11685c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11695c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11705c0c51a9SNicolas Vasilache     // Fill offset 0.
11715c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11725c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11735c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11745c0c51a9SNicolas Vasilache 
11755c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11765c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11775c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11785c0c51a9SNicolas Vasilache       auto sizeAttr =
11795c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11805c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11815c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
11825c0c51a9SNicolas Vasilache       auto strideAttr =
11835c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
11845c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11855c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11865c0c51a9SNicolas Vasilache     }
11875c0c51a9SNicolas Vasilache 
11885c0c51a9SNicolas Vasilache     rewriter.replaceOp(op, {desc});
11893145427dSRiver Riddle     return success();
11905c0c51a9SNicolas Vasilache   }
11915c0c51a9SNicolas Vasilache };
11925c0c51a9SNicolas Vasilache 
11938345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11948345b86dSNicolas Vasilache /// sequence of:
1195*060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1196*060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1197*060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1198*060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1199*060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
12008345b86dSNicolas Vasilache template <typename ConcreteOp>
12018345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern {
12028345b86dSNicolas Vasilache public:
12038345b86dSNicolas Vasilache   explicit VectorTransferConversion(MLIRContext *context,
1204*060c9dd1Saartbik                                     LLVMTypeConverter &typeConv,
1205*060c9dd1Saartbik                                     bool enableIndexOpt)
1206*060c9dd1Saartbik       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1207*060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
12088345b86dSNicolas Vasilache 
12098345b86dSNicolas Vasilache   LogicalResult
12108345b86dSNicolas Vasilache   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
12118345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
12128345b86dSNicolas Vasilache     auto xferOp = cast<ConcreteOp>(op);
12138345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1214b2c79c50SNicolas Vasilache 
1215b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1216b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
12178345b86dSNicolas Vasilache       return failure();
12185f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
12195f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
12205f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
12215f9e0466SNicolas Vasilache                                        op->getContext()))
12228345b86dSNicolas Vasilache       return failure();
12232bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
12242bf491c7SBenjamin Kramer     SmallVector<int64_t, 4> strides;
12252bf491c7SBenjamin Kramer     if (!isContiguous(xferOp.getMemRefType(), strides))
12262bf491c7SBenjamin Kramer       return failure();
12278345b86dSNicolas Vasilache 
12288345b86dSNicolas Vasilache     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
12298345b86dSNicolas Vasilache 
12308345b86dSNicolas Vasilache     Location loc = op->getLoc();
12318345b86dSNicolas Vasilache     MemRefType memRefType = xferOp.getMemRefType();
12328345b86dSNicolas Vasilache 
123368330ee0SThomas Raoux     if (auto memrefVectorElementType =
123468330ee0SThomas Raoux             memRefType.getElementType().dyn_cast<VectorType>()) {
123568330ee0SThomas Raoux       // Memref has vector element type.
123668330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
123768330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
123868330ee0SThomas Raoux         return failure();
12390de60b55SThomas Raoux #ifndef NDEBUG
124068330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
124168330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
124268330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
124368330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
124468330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
124568330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
124668330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
124768330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
124868330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
124968330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
125068330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
125168330ee0SThomas Raoux                "result shape.");
12520de60b55SThomas Raoux #endif // ifndef NDEBUG
125368330ee0SThomas Raoux     }
125468330ee0SThomas Raoux 
12558345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1256be16075bSWen-Heng (Jack) Chung     //    The vector pointer would always be on address space 0, therefore
1257be16075bSWen-Heng (Jack) Chung     //    addrspacecast shall be used when source/dst memrefs are not on
1258be16075bSWen-Heng (Jack) Chung     //    address space 0.
12598345b86dSNicolas Vasilache     // TODO: support alignment when possible.
12608345b86dSNicolas Vasilache     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
1261d3a98076SAlex Zinenko                                adaptor.indices(), rewriter);
12628345b86dSNicolas Vasilache     auto vecTy =
12638345b86dSNicolas Vasilache         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1264be16075bSWen-Heng (Jack) Chung     Value vectorDataPtr;
1265be16075bSWen-Heng (Jack) Chung     if (memRefType.getMemorySpace() == 0)
1266be16075bSWen-Heng (Jack) Chung       vectorDataPtr =
12678345b86dSNicolas Vasilache           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1268be16075bSWen-Heng (Jack) Chung     else
1269be16075bSWen-Heng (Jack) Chung       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1270be16075bSWen-Heng (Jack) Chung           loc, vecTy.getPointerTo(), dataPtr);
12718345b86dSNicolas Vasilache 
12721870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
12731870e787SNicolas Vasilache       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
12741870e787SNicolas Vasilache                                               xferOp, operands, vectorDataPtr);
12751870e787SNicolas Vasilache 
12768345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12778345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12788345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12798345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1280*060c9dd1Saartbik     //
1281*060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1282*060c9dd1Saartbik     //       dimensions here.
1283*060c9dd1Saartbik     unsigned vecWidth = vecTy.getVectorNumElements();
1284*060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
1285*060c9dd1Saartbik     Value off = *(xferOp.indices().begin() + lastIndex);
1286b2c79c50SNicolas Vasilache     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1287*060c9dd1Saartbik     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1288*060c9dd1Saartbik                                        vecWidth, dim, &off);
12898345b86dSNicolas Vasilache 
12908345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
12911870e787SNicolas Vasilache     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1292a99f62c4SAlex Zinenko                                        operands, vectorDataPtr, mask);
12938345b86dSNicolas Vasilache   }
1294*060c9dd1Saartbik 
1295*060c9dd1Saartbik private:
1296*060c9dd1Saartbik   const bool enableIndexOptimizations;
12978345b86dSNicolas Vasilache };
12988345b86dSNicolas Vasilache 
1299870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern {
1300d9b500d3SAart Bik public:
1301d9b500d3SAart Bik   explicit VectorPrintOpConversion(MLIRContext *context,
1302d9b500d3SAart Bik                                    LLVMTypeConverter &typeConverter)
1303870c1fd4SAlex Zinenko       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1304d9b500d3SAart Bik                              typeConverter) {}
1305d9b500d3SAart Bik 
1306d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1307d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1308d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1309d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1310d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1311d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1312d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1313d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1314d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1315d9b500d3SAart Bik   //
13169db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1317d9b500d3SAart Bik   //
13183145427dSRiver Riddle   LogicalResult
1319e62a6956SRiver Riddle   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1320d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
1321d9b500d3SAart Bik     auto printOp = cast<vector::PrintOp>(op);
13222d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1323d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1324d9b500d3SAart Bik 
13250f04384dSAlex Zinenko     if (typeConverter.convertType(printType) == nullptr)
13263145427dSRiver Riddle       return failure();
1327d9b500d3SAart Bik 
1328d9b500d3SAart Bik     // Make sure element type has runtime support (currently just Float/Double).
1329d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1330d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1331d9b500d3SAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1332d9b500d3SAart Bik     Operation *printer;
1333c9eeeb38Saartbik     if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
1334e52414b1Saartbik       printer = getPrintI32(op);
133535b68527SLei Zhang     else if (eltType.isSignlessInteger(64))
1336e52414b1Saartbik       printer = getPrintI64(op);
1337e52414b1Saartbik     else if (eltType.isF32())
1338d9b500d3SAart Bik       printer = getPrintFloat(op);
1339d9b500d3SAart Bik     else if (eltType.isF64())
1340d9b500d3SAart Bik       printer = getPrintDouble(op);
1341d9b500d3SAart Bik     else
13423145427dSRiver Riddle       return failure();
1343d9b500d3SAart Bik 
1344d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1345d9b500d3SAart Bik     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
1346d9b500d3SAart Bik     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1347d9b500d3SAart Bik     rewriter.eraseOp(op);
13483145427dSRiver Riddle     return success();
1349d9b500d3SAart Bik   }
1350d9b500d3SAart Bik 
1351d9b500d3SAart Bik private:
1352d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1353e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1354d9b500d3SAart Bik                  int64_t rank) const {
1355d9b500d3SAart Bik     Location loc = op->getLoc();
1356d9b500d3SAart Bik     if (rank == 0) {
13575446ec85SAlex Zinenko       if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
1358c9eeeb38Saartbik         // Convert i1 (bool) to i32 so we can use the print_i32 method.
1359c9eeeb38Saartbik         // This avoids the need for a print_i1 method with an unclear ABI.
13605446ec85SAlex Zinenko         auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
1361c9eeeb38Saartbik         auto trueVal = rewriter.create<ConstantOp>(
1362c9eeeb38Saartbik             loc, i32Type, rewriter.getI32IntegerAttr(1));
1363c9eeeb38Saartbik         auto falseVal = rewriter.create<ConstantOp>(
1364c9eeeb38Saartbik             loc, i32Type, rewriter.getI32IntegerAttr(0));
1365c9eeeb38Saartbik         value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
1366c9eeeb38Saartbik       }
1367d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1368d9b500d3SAart Bik       return;
1369d9b500d3SAart Bik     }
1370d9b500d3SAart Bik 
1371d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1372d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1373d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1374d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1375d9b500d3SAart Bik       auto reducedType =
1376d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
13770f04384dSAlex Zinenko       auto llvmType = typeConverter.convertType(
1378d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1379e62a6956SRiver Riddle       Value nestedVal =
13800f04384dSAlex Zinenko           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1381d9b500d3SAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
1382d9b500d3SAart Bik       if (d != dim - 1)
1383d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1384d9b500d3SAart Bik     }
1385d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1386d9b500d3SAart Bik   }
1387d9b500d3SAart Bik 
1388d9b500d3SAart Bik   // Helper to emit a call.
1389d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1390d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
1391d9b500d3SAart Bik     rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
1392d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1393d9b500d3SAart Bik   }
1394d9b500d3SAart Bik 
1395d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
13965446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
13975446ec85SAlex Zinenko                              ArrayRef<LLVM::LLVMType> params) {
1398d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1399d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1400d9b500d3SAart Bik     if (func)
1401d9b500d3SAart Bik       return func;
1402d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1403d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1404d9b500d3SAart Bik         op->getLoc(), name,
14055446ec85SAlex Zinenko         LLVM::LLVMType::getFunctionTy(
14065446ec85SAlex Zinenko             LLVM::LLVMType::getVoidTy(op->getContext()), params,
14075446ec85SAlex Zinenko             /*isVarArg=*/false));
1408d9b500d3SAart Bik   }
1409d9b500d3SAart Bik 
1410d9b500d3SAart Bik   // Helpers for method names.
1411e52414b1Saartbik   Operation *getPrintI32(Operation *op) const {
14125446ec85SAlex Zinenko     return getPrint(op, "print_i32",
14135446ec85SAlex Zinenko                     LLVM::LLVMType::getInt32Ty(op->getContext()));
1414e52414b1Saartbik   }
1415e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
14165446ec85SAlex Zinenko     return getPrint(op, "print_i64",
14175446ec85SAlex Zinenko                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1418e52414b1Saartbik   }
1419d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
14205446ec85SAlex Zinenko     return getPrint(op, "print_f32",
14215446ec85SAlex Zinenko                     LLVM::LLVMType::getFloatTy(op->getContext()));
1422d9b500d3SAart Bik   }
1423d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
14245446ec85SAlex Zinenko     return getPrint(op, "print_f64",
14255446ec85SAlex Zinenko                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1426d9b500d3SAart Bik   }
1427d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
14285446ec85SAlex Zinenko     return getPrint(op, "print_open", {});
1429d9b500d3SAart Bik   }
1430d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
14315446ec85SAlex Zinenko     return getPrint(op, "print_close", {});
1432d9b500d3SAart Bik   }
1433d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
14345446ec85SAlex Zinenko     return getPrint(op, "print_comma", {});
1435d9b500d3SAart Bik   }
1436d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
14375446ec85SAlex Zinenko     return getPrint(op, "print_newline", {});
1438d9b500d3SAart Bik   }
1439d9b500d3SAart Bik };
1440d9b500d3SAart Bik 
1441334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1442c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1443c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1444c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1445334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
144665678d93SNicolas Vasilache public:
1447334a4159SReid Tatge   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
144865678d93SNicolas Vasilache 
1449334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
145065678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
145165678d93SNicolas Vasilache     auto dstType = op.getResult().getType().cast<VectorType>();
145265678d93SNicolas Vasilache 
145365678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
145465678d93SNicolas Vasilache 
145565678d93SNicolas Vasilache     int64_t offset =
145665678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
145765678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
145865678d93SNicolas Vasilache     int64_t stride =
145965678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
146065678d93SNicolas Vasilache 
146165678d93SNicolas Vasilache     auto loc = op.getLoc();
146265678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
146335b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1464c3c95b9cSaartbik 
1465c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1466c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1467c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1468c3c95b9cSaartbik       offsets.reserve(size);
1469c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1470c3c95b9cSaartbik            off += stride)
1471c3c95b9cSaartbik         offsets.push_back(off);
1472c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1473c3c95b9cSaartbik                                              op.vector(),
1474c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1475c3c95b9cSaartbik       return success();
1476c3c95b9cSaartbik     }
1477c3c95b9cSaartbik 
1478c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
147965678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
148065678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
148165678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
148265678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
148365678d93SNicolas Vasilache          off += stride, ++idx) {
1484c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1485c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1486c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
148765678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
148865678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
148965678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
149065678d93SNicolas Vasilache     }
1491c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14923145427dSRiver Riddle     return success();
149365678d93SNicolas Vasilache   }
1494334a4159SReid Tatge   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1495bd1ccfe6SRiver Riddle   /// bounded as the rank is strictly decreasing.
1496bd1ccfe6SRiver Riddle   bool hasBoundedRewriteRecursion() const final { return true; }
149765678d93SNicolas Vasilache };
149865678d93SNicolas Vasilache 
1499df186507SBenjamin Kramer } // namespace
1500df186507SBenjamin Kramer 
15015c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
15025c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1503ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1504*060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
150565678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15068345b86dSNicolas Vasilache   // clang-format off
1507681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1508681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15092d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1510c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1511ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1512ceb1b327Saartbik       ctx, converter, reassociateFPReductions);
1513*060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1514*060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1515*060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1516*060c9dd1Saartbik       ctx, converter, enableIndexOptimizations);
15178345b86dSNicolas Vasilache   patterns
1518ceb1b327Saartbik       .insert<VectorShuffleOpConversion,
15198345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15208345b86dSNicolas Vasilache               VectorExtractOpConversion,
15218345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15228345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15238345b86dSNicolas Vasilache               VectorInsertOpConversion,
15248345b86dSNicolas Vasilache               VectorPrintOpConversion,
152519dbb230Saartbik               VectorTypeCastOpConversion,
152639379916Saartbik               VectorMaskedLoadOpConversion,
152739379916Saartbik               VectorMaskedStoreOpConversion,
152819dbb230Saartbik               VectorGatherOpConversion,
1529e8dcf5f8Saartbik               VectorScatterOpConversion,
1530e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1531e8dcf5f8Saartbik               VectorCompressStoreOpConversion>(ctx, converter);
15328345b86dSNicolas Vasilache   // clang-format on
15335c0c51a9SNicolas Vasilache }
15345c0c51a9SNicolas Vasilache 
153563b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
153663b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
153763b683a8SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
153863b683a8SNicolas Vasilache   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1539c295a65dSaartbik   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
154063b683a8SNicolas Vasilache }
154163b683a8SNicolas Vasilache 
15425c0c51a9SNicolas Vasilache namespace {
1543722f909fSRiver Riddle struct LowerVectorToLLVMPass
15441834ad4aSRiver Riddle     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
15451bfdf7c7Saartbik   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
15461bfdf7c7Saartbik     this->reassociateFPReductions = options.reassociateFPReductions;
1547*060c9dd1Saartbik     this->enableIndexOptimizations = options.enableIndexOptimizations;
15481bfdf7c7Saartbik   }
1549722f909fSRiver Riddle   void runOnOperation() override;
15505c0c51a9SNicolas Vasilache };
15515c0c51a9SNicolas Vasilache } // namespace
15525c0c51a9SNicolas Vasilache 
1553722f909fSRiver Riddle void LowerVectorToLLVMPass::runOnOperation() {
1554078776a6Saartbik   // Perform progressive lowering of operations on slices and
1555b21c7999Saartbik   // all contraction operations. Also applies folding and DCE.
1556459cf6e5Saartbik   {
15575c0c51a9SNicolas Vasilache     OwningRewritePatternList patterns;
1558b1c688dbSaartbik     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1559459cf6e5Saartbik     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1560b21c7999Saartbik     populateVectorContractLoweringPatterns(patterns, &getContext());
1561a5b9316bSUday Bondhugula     applyPatternsAndFoldGreedily(getOperation(), patterns);
1562459cf6e5Saartbik   }
1563459cf6e5Saartbik 
1564459cf6e5Saartbik   // Convert to the LLVM IR dialect.
15655c0c51a9SNicolas Vasilache   LLVMTypeConverter converter(&getContext());
1566459cf6e5Saartbik   OwningRewritePatternList patterns;
156763b683a8SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1568*060c9dd1Saartbik   populateVectorToLLVMConversionPatterns(
1569*060c9dd1Saartbik       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
1570bbf3ef85SNicolas Vasilache   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
15715c0c51a9SNicolas Vasilache   populateStdToLLVMConversionPatterns(converter, patterns);
15725c0c51a9SNicolas Vasilache 
15732a00ae39STim Shen   LLVMConversionTarget target(getContext());
1574*060c9dd1Saartbik   if (failed(applyPartialConversion(getOperation(), target, patterns)))
15755c0c51a9SNicolas Vasilache     signalPassFailure();
15765c0c51a9SNicolas Vasilache }
15775c0c51a9SNicolas Vasilache 
15781bfdf7c7Saartbik std::unique_ptr<OperationPass<ModuleOp>>
15791bfdf7c7Saartbik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
15801bfdf7c7Saartbik   return std::make_unique<LowerVectorToLLVMPass>(options);
15815c0c51a9SNicolas Vasilache }
1582