15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
25c0c51a9SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c0c51a9SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c0c51a9SNicolas Vasilache 
965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10870c1fd4SAlex Zinenko 
115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h"
1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h"
185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h"
195c0c51a9SNicolas Vasilache 
205c0c51a9SNicolas Vasilache using namespace mlir;
2165678d93SNicolas Vasilache using namespace mlir::vector;
225c0c51a9SNicolas Vasilache 
239826fe5cSAart Bik // Helper to reduce vector type by one rank at front.
249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) {
259826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
269826fe5cSAart Bik   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
279826fe5cSAart Bik }
289826fe5cSAart Bik 
299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back.
309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) {
319826fe5cSAart Bik   assert((tp.getRank() > 1) && "unlowerable vector type");
329826fe5cSAart Bik   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
339826fe5cSAart Bik }
349826fe5cSAart Bik 
351c81adf3SAart Bik // Helper that picks the proper sequence for inserting.
36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter,
370f04384dSAlex Zinenko                        LLVMTypeConverter &typeConverter, Location loc,
380f04384dSAlex Zinenko                        Value val1, Value val2, Type llvmType, int64_t rank,
390f04384dSAlex Zinenko                        int64_t pos) {
401c81adf3SAart Bik   if (rank == 1) {
411c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
421c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
430f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
441c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
451c81adf3SAart Bik     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
461c81adf3SAart Bik                                                   constant);
471c81adf3SAart Bik   }
481c81adf3SAart Bik   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
491c81adf3SAart Bik                                               rewriter.getI64ArrayAttr(pos));
501c81adf3SAart Bik }
511c81adf3SAart Bik 
522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting.
532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
542d515e49SNicolas Vasilache                        Value into, int64_t offset) {
552d515e49SNicolas Vasilache   auto vectorType = into.getType().cast<VectorType>();
562d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
572d515e49SNicolas Vasilache     return rewriter.create<InsertOp>(loc, from, into, offset);
582d515e49SNicolas Vasilache   return rewriter.create<vector::InsertElementOp>(
592d515e49SNicolas Vasilache       loc, vectorType, from, into,
602d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
612d515e49SNicolas Vasilache }
622d515e49SNicolas Vasilache 
631c81adf3SAart Bik // Helper that picks the proper sequence for extracting.
64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter,
650f04384dSAlex Zinenko                         LLVMTypeConverter &typeConverter, Location loc,
660f04384dSAlex Zinenko                         Value val, Type llvmType, int64_t rank, int64_t pos) {
671c81adf3SAart Bik   if (rank == 1) {
681c81adf3SAart Bik     auto idxType = rewriter.getIndexType();
691c81adf3SAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(
700f04384dSAlex Zinenko         loc, typeConverter.convertType(idxType),
711c81adf3SAart Bik         rewriter.getIntegerAttr(idxType, pos));
721c81adf3SAart Bik     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
731c81adf3SAart Bik                                                    constant);
741c81adf3SAart Bik   }
751c81adf3SAart Bik   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
761c81adf3SAart Bik                                                rewriter.getI64ArrayAttr(pos));
771c81adf3SAart Bik }
781c81adf3SAart Bik 
792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting.
802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
812d515e49SNicolas Vasilache                         int64_t offset) {
822d515e49SNicolas Vasilache   auto vectorType = vector.getType().cast<VectorType>();
832d515e49SNicolas Vasilache   if (vectorType.getRank() > 1)
842d515e49SNicolas Vasilache     return rewriter.create<ExtractOp>(loc, vector, offset);
852d515e49SNicolas Vasilache   return rewriter.create<vector::ExtractElementOp>(
862d515e49SNicolas Vasilache       loc, vectorType.getElementType(), vector,
872d515e49SNicolas Vasilache       rewriter.create<ConstantIndexOp>(loc, offset));
882d515e49SNicolas Vasilache }
892d515e49SNicolas Vasilache 
902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing.
922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
932d515e49SNicolas Vasilache                                               unsigned dropFront = 0,
942d515e49SNicolas Vasilache                                               unsigned dropBack = 0) {
952d515e49SNicolas Vasilache   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
962d515e49SNicolas Vasilache   auto range = arrayAttr.getAsRange<IntegerAttr>();
972d515e49SNicolas Vasilache   SmallVector<int64_t, 4> res;
982d515e49SNicolas Vasilache   res.reserve(arrayAttr.size() - dropFront - dropBack);
992d515e49SNicolas Vasilache   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
1002d515e49SNicolas Vasilache        it != eit; ++it)
1012d515e49SNicolas Vasilache     res.push_back((*it).getValue().getSExtValue());
1022d515e49SNicolas Vasilache   return res;
1032d515e49SNicolas Vasilache }
1042d515e49SNicolas Vasilache 
105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask:
106060c9dd1Saartbik //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
107060c9dd1Saartbik //
108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
109060c9dd1Saartbik //       much more compact, IR for this operation, but LLVM eventually
110060c9dd1Saartbik //       generates more elaborate instructions for this intrinsic since it
111060c9dd1Saartbik //       is very conservative on the boundary conditions.
112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
113060c9dd1Saartbik                                    Operation *op, bool enableIndexOptimizations,
114060c9dd1Saartbik                                    int64_t dim, Value b, Value *off = nullptr) {
115060c9dd1Saartbik   auto loc = op->getLoc();
116060c9dd1Saartbik   // If we can assume all indices fit in 32-bit, we perform the vector
117060c9dd1Saartbik   // comparison in 32-bit to get a higher degree of SIMD parallelism.
118060c9dd1Saartbik   // Otherwise we perform the vector comparison using 64-bit indices.
119060c9dd1Saartbik   Value indices;
120060c9dd1Saartbik   Type idxType;
121060c9dd1Saartbik   if (enableIndexOptimizations) {
1220c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1230c2a4d3cSBenjamin Kramer         loc, rewriter.getI32VectorAttr(
1240c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
125060c9dd1Saartbik     idxType = rewriter.getI32Type();
126060c9dd1Saartbik   } else {
1270c2a4d3cSBenjamin Kramer     indices = rewriter.create<ConstantOp>(
1280c2a4d3cSBenjamin Kramer         loc, rewriter.getI64VectorAttr(
1290c2a4d3cSBenjamin Kramer                  llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
130060c9dd1Saartbik     idxType = rewriter.getI64Type();
131060c9dd1Saartbik   }
132060c9dd1Saartbik   // Add in an offset if requested.
133060c9dd1Saartbik   if (off) {
134060c9dd1Saartbik     Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
135060c9dd1Saartbik     Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
136060c9dd1Saartbik     indices = rewriter.create<AddIOp>(loc, ov, indices);
137060c9dd1Saartbik   }
138060c9dd1Saartbik   // Construct the vector comparison.
139060c9dd1Saartbik   Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
140060c9dd1Saartbik   Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
141060c9dd1Saartbik   return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
142060c9dd1Saartbik }
143060c9dd1Saartbik 
14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref.
14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
14626c8f908SThomas Raoux                                  MemRefType memrefType, unsigned &align) {
14726c8f908SThomas Raoux   Type elementTy = typeConverter.convertType(memrefType.getElementType());
1485f9e0466SNicolas Vasilache   if (!elementTy)
1495f9e0466SNicolas Vasilache     return failure();
1505f9e0466SNicolas Vasilache 
151b2ab375dSAlex Zinenko   // TODO: this should use the MLIR data layout when it becomes available and
152b2ab375dSAlex Zinenko   // stop depending on translation.
15387a89e0fSAlex Zinenko   llvm::LLVMContext llvmContext;
15487a89e0fSAlex Zinenko   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
155c69c9e0fSAlex Zinenko               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
1565f9e0466SNicolas Vasilache   return success();
1575f9e0466SNicolas Vasilache }
1585f9e0466SNicolas Vasilache 
159e8dcf5f8Saartbik // Helper that returns the base address of a memref.
160b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
161e8dcf5f8Saartbik                              Value memref, MemRefType memRefType, Value &base) {
16219dbb230Saartbik   // Inspect stride and offset structure.
16319dbb230Saartbik   //
16419dbb230Saartbik   // TODO: flat memory only for now, generalize
16519dbb230Saartbik   //
16619dbb230Saartbik   int64_t offset;
16719dbb230Saartbik   SmallVector<int64_t, 4> strides;
16819dbb230Saartbik   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
16919dbb230Saartbik   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
17019dbb230Saartbik       offset != 0 || memRefType.getMemorySpace() != 0)
17119dbb230Saartbik     return failure();
172e8dcf5f8Saartbik   base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
173e8dcf5f8Saartbik   return success();
174e8dcf5f8Saartbik }
17519dbb230Saartbik 
176a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector.
177b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
178b98e25b6SBenjamin Kramer                                     Location loc, Value memref, Value indices,
179b98e25b6SBenjamin Kramer                                     MemRefType memRefType, VectorType vType,
180b98e25b6SBenjamin Kramer                                     Type iType, Value &ptrs) {
181e8dcf5f8Saartbik   Value base;
182e8dcf5f8Saartbik   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
183e8dcf5f8Saartbik     return failure();
1843a577f54SChristian Sigg   auto pType = MemRefDescriptor(memref).getElementPtrType();
185bd30a796SAlex Zinenko   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
1861485fd29Saartbik   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
18719dbb230Saartbik   return success();
18819dbb230Saartbik }
18919dbb230Saartbik 
190a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer
191a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be
192a57def30SAart Bik // used when source/dst memrefs are not on address space 0.
193a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
194a57def30SAart Bik                          Value ptr, MemRefType memRefType, Type vt) {
195bd30a796SAlex Zinenko   auto pType = LLVM::LLVMPointerType::get(vt);
196a57def30SAart Bik   if (memRefType.getMemorySpace() == 0)
197a57def30SAart Bik     return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
198a57def30SAart Bik   return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
199a57def30SAart Bik }
200a57def30SAart Bik 
2015f9e0466SNicolas Vasilache static LogicalResult
2025f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2035f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2045f9e0466SNicolas Vasilache                                  TransferReadOp xferOp,
2055f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
206affbc0cdSNicolas Vasilache   unsigned align;
20726c8f908SThomas Raoux   if (failed(getMemRefAlignment(
20826c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
209affbc0cdSNicolas Vasilache     return failure();
210affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
2115f9e0466SNicolas Vasilache   return success();
2125f9e0466SNicolas Vasilache }
2135f9e0466SNicolas Vasilache 
2145f9e0466SNicolas Vasilache static LogicalResult
2155f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2165f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2175f9e0466SNicolas Vasilache                             TransferReadOp xferOp, ArrayRef<Value> operands,
2185f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2195f9e0466SNicolas Vasilache   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
2205f9e0466SNicolas Vasilache   VectorType fillType = xferOp.getVectorType();
2215f9e0466SNicolas Vasilache   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
2225f9e0466SNicolas Vasilache   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
2235f9e0466SNicolas Vasilache 
2245f9e0466SNicolas Vasilache   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
2255f9e0466SNicolas Vasilache   if (!vecTy)
2265f9e0466SNicolas Vasilache     return failure();
2275f9e0466SNicolas Vasilache 
2285f9e0466SNicolas Vasilache   unsigned align;
22926c8f908SThomas Raoux   if (failed(getMemRefAlignment(
23026c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2315f9e0466SNicolas Vasilache     return failure();
2325f9e0466SNicolas Vasilache 
2335f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
2345f9e0466SNicolas Vasilache       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
2355f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2365f9e0466SNicolas Vasilache   return success();
2375f9e0466SNicolas Vasilache }
2385f9e0466SNicolas Vasilache 
2395f9e0466SNicolas Vasilache static LogicalResult
2405f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
2415f9e0466SNicolas Vasilache                                  LLVMTypeConverter &typeConverter, Location loc,
2425f9e0466SNicolas Vasilache                                  TransferWriteOp xferOp,
2435f9e0466SNicolas Vasilache                                  ArrayRef<Value> operands, Value dataPtr) {
244affbc0cdSNicolas Vasilache   unsigned align;
24526c8f908SThomas Raoux   if (failed(getMemRefAlignment(
24626c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
247affbc0cdSNicolas Vasilache     return failure();
2482d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
249affbc0cdSNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
250affbc0cdSNicolas Vasilache                                              align);
2515f9e0466SNicolas Vasilache   return success();
2525f9e0466SNicolas Vasilache }
2535f9e0466SNicolas Vasilache 
2545f9e0466SNicolas Vasilache static LogicalResult
2555f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
2565f9e0466SNicolas Vasilache                             LLVMTypeConverter &typeConverter, Location loc,
2575f9e0466SNicolas Vasilache                             TransferWriteOp xferOp, ArrayRef<Value> operands,
2585f9e0466SNicolas Vasilache                             Value dataPtr, Value mask) {
2595f9e0466SNicolas Vasilache   unsigned align;
26026c8f908SThomas Raoux   if (failed(getMemRefAlignment(
26126c8f908SThomas Raoux           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
2625f9e0466SNicolas Vasilache     return failure();
2635f9e0466SNicolas Vasilache 
2642d2c73c5SJacques Pienaar   auto adaptor = TransferWriteOpAdaptor(operands);
2655f9e0466SNicolas Vasilache   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
2665f9e0466SNicolas Vasilache       xferOp, adaptor.vector(), dataPtr, mask,
2675f9e0466SNicolas Vasilache       rewriter.getI32IntegerAttr(align));
2685f9e0466SNicolas Vasilache   return success();
2695f9e0466SNicolas Vasilache }
2705f9e0466SNicolas Vasilache 
2712d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
2722d2c73c5SJacques Pienaar                                                   ArrayRef<Value> operands) {
2732d2c73c5SJacques Pienaar   return TransferReadOpAdaptor(operands);
2745f9e0466SNicolas Vasilache }
2755f9e0466SNicolas Vasilache 
2762d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
2772d2c73c5SJacques Pienaar                                                    ArrayRef<Value> operands) {
2782d2c73c5SJacques Pienaar   return TransferWriteOpAdaptor(operands);
2795f9e0466SNicolas Vasilache }
2805f9e0466SNicolas Vasilache 
28190c01357SBenjamin Kramer namespace {
282e83b7b99Saartbik 
283*cf5c517cSDiego Caballero /// Conversion pattern for a vector.bitcast.
284*cf5c517cSDiego Caballero class VectorBitCastOpConversion
285*cf5c517cSDiego Caballero     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
286*cf5c517cSDiego Caballero public:
287*cf5c517cSDiego Caballero   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
288*cf5c517cSDiego Caballero 
289*cf5c517cSDiego Caballero   LogicalResult
290*cf5c517cSDiego Caballero   matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
291*cf5c517cSDiego Caballero                   ConversionPatternRewriter &rewriter) const override {
292*cf5c517cSDiego Caballero     // Only 1-D vectors can be lowered to LLVM.
293*cf5c517cSDiego Caballero     VectorType resultTy = bitCastOp.getType();
294*cf5c517cSDiego Caballero     if (resultTy.getRank() != 1)
295*cf5c517cSDiego Caballero       return failure();
296*cf5c517cSDiego Caballero     Type newResultTy = typeConverter->convertType(resultTy);
297*cf5c517cSDiego Caballero     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
298*cf5c517cSDiego Caballero                                                  operands[0]);
299*cf5c517cSDiego Caballero     return success();
300*cf5c517cSDiego Caballero   }
301*cf5c517cSDiego Caballero };
302*cf5c517cSDiego Caballero 
30363b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply.
30463b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply.
305563879b6SRahul Joshi class VectorMatmulOpConversion
306563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
30763b683a8SNicolas Vasilache public:
308563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
30963b683a8SNicolas Vasilache 
3103145427dSRiver Riddle   LogicalResult
311563879b6SRahul Joshi   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
31263b683a8SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
3132d2c73c5SJacques Pienaar     auto adaptor = vector::MatmulOpAdaptor(operands);
31463b683a8SNicolas Vasilache     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
315563879b6SRahul Joshi         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
316563879b6SRahul Joshi         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
317563879b6SRahul Joshi         matmulOp.lhs_columns(), matmulOp.rhs_columns());
3183145427dSRiver Riddle     return success();
31963b683a8SNicolas Vasilache   }
32063b683a8SNicolas Vasilache };
32163b683a8SNicolas Vasilache 
322c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose.
323c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose.
324563879b6SRahul Joshi class VectorFlatTransposeOpConversion
325563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
326c295a65dSaartbik public:
327563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
328c295a65dSaartbik 
329c295a65dSaartbik   LogicalResult
330563879b6SRahul Joshi   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
331c295a65dSaartbik                   ConversionPatternRewriter &rewriter) const override {
3322d2c73c5SJacques Pienaar     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
333c295a65dSaartbik     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
334dcec2ca5SChristian Sigg         transOp, typeConverter->convertType(transOp.res().getType()),
335c295a65dSaartbik         adaptor.matrix(), transOp.rows(), transOp.columns());
336c295a65dSaartbik     return success();
337c295a65dSaartbik   }
338c295a65dSaartbik };
339c295a65dSaartbik 
34039379916Saartbik /// Conversion pattern for a vector.maskedload.
341563879b6SRahul Joshi class VectorMaskedLoadOpConversion
342563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
34339379916Saartbik public:
344563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
34539379916Saartbik 
34639379916Saartbik   LogicalResult
347563879b6SRahul Joshi   matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
34839379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
349563879b6SRahul Joshi     auto loc = load->getLoc();
35039379916Saartbik     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
351a57def30SAart Bik     MemRefType memRefType = load.getMemRefType();
35239379916Saartbik 
35339379916Saartbik     // Resolve alignment.
35439379916Saartbik     unsigned align;
355a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
35639379916Saartbik       return failure();
35739379916Saartbik 
358a57def30SAart Bik     // Resolve address.
359dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(load.getResultVectorType());
360a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
361a57def30SAart Bik                                                adaptor.indices(), rewriter);
362a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
36339379916Saartbik 
36439379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
36539379916Saartbik         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
36639379916Saartbik         rewriter.getI32IntegerAttr(align));
36739379916Saartbik     return success();
36839379916Saartbik   }
36939379916Saartbik };
37039379916Saartbik 
37139379916Saartbik /// Conversion pattern for a vector.maskedstore.
372563879b6SRahul Joshi class VectorMaskedStoreOpConversion
373563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
37439379916Saartbik public:
375563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
37639379916Saartbik 
37739379916Saartbik   LogicalResult
378563879b6SRahul Joshi   matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
37939379916Saartbik                   ConversionPatternRewriter &rewriter) const override {
380563879b6SRahul Joshi     auto loc = store->getLoc();
38139379916Saartbik     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
382a57def30SAart Bik     MemRefType memRefType = store.getMemRefType();
38339379916Saartbik 
38439379916Saartbik     // Resolve alignment.
38539379916Saartbik     unsigned align;
386a57def30SAart Bik     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
38739379916Saartbik       return failure();
38839379916Saartbik 
389a57def30SAart Bik     // Resolve address.
390dcec2ca5SChristian Sigg     auto vtype = typeConverter->convertType(store.getValueVectorType());
391a57def30SAart Bik     Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
392a57def30SAart Bik                                                adaptor.indices(), rewriter);
393a57def30SAart Bik     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
39439379916Saartbik 
39539379916Saartbik     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
39639379916Saartbik         store, adaptor.value(), ptr, adaptor.mask(),
39739379916Saartbik         rewriter.getI32IntegerAttr(align));
39839379916Saartbik     return success();
39939379916Saartbik   }
40039379916Saartbik };
40139379916Saartbik 
40219dbb230Saartbik /// Conversion pattern for a vector.gather.
403563879b6SRahul Joshi class VectorGatherOpConversion
404563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::GatherOp> {
40519dbb230Saartbik public:
406563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
40719dbb230Saartbik 
40819dbb230Saartbik   LogicalResult
409563879b6SRahul Joshi   matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
41019dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
411563879b6SRahul Joshi     auto loc = gather->getLoc();
41219dbb230Saartbik     auto adaptor = vector::GatherOpAdaptor(operands);
41319dbb230Saartbik 
41419dbb230Saartbik     // Resolve alignment.
41519dbb230Saartbik     unsigned align;
41626c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
41726c8f908SThomas Raoux                                   align)))
41819dbb230Saartbik       return failure();
41919dbb230Saartbik 
42019dbb230Saartbik     // Get index ptrs.
42119dbb230Saartbik     VectorType vType = gather.getResultVectorType();
42219dbb230Saartbik     Type iType = gather.getIndicesVectorType().getElementType();
42319dbb230Saartbik     Value ptrs;
424e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
425e8dcf5f8Saartbik                               gather.getMemRefType(), vType, iType, ptrs)))
42619dbb230Saartbik       return failure();
42719dbb230Saartbik 
42819dbb230Saartbik     // Replace with the gather intrinsic.
42919dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
430dcec2ca5SChristian Sigg         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
4310c2a4d3cSBenjamin Kramer         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
43219dbb230Saartbik     return success();
43319dbb230Saartbik   }
43419dbb230Saartbik };
43519dbb230Saartbik 
43619dbb230Saartbik /// Conversion pattern for a vector.scatter.
437563879b6SRahul Joshi class VectorScatterOpConversion
438563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
43919dbb230Saartbik public:
440563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
44119dbb230Saartbik 
44219dbb230Saartbik   LogicalResult
443563879b6SRahul Joshi   matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
44419dbb230Saartbik                   ConversionPatternRewriter &rewriter) const override {
445563879b6SRahul Joshi     auto loc = scatter->getLoc();
44619dbb230Saartbik     auto adaptor = vector::ScatterOpAdaptor(operands);
44719dbb230Saartbik 
44819dbb230Saartbik     // Resolve alignment.
44919dbb230Saartbik     unsigned align;
45026c8f908SThomas Raoux     if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
45126c8f908SThomas Raoux                                   align)))
45219dbb230Saartbik       return failure();
45319dbb230Saartbik 
45419dbb230Saartbik     // Get index ptrs.
45519dbb230Saartbik     VectorType vType = scatter.getValueVectorType();
45619dbb230Saartbik     Type iType = scatter.getIndicesVectorType().getElementType();
45719dbb230Saartbik     Value ptrs;
458e8dcf5f8Saartbik     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
459e8dcf5f8Saartbik                               scatter.getMemRefType(), vType, iType, ptrs)))
46019dbb230Saartbik       return failure();
46119dbb230Saartbik 
46219dbb230Saartbik     // Replace with the scatter intrinsic.
46319dbb230Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
46419dbb230Saartbik         scatter, adaptor.value(), ptrs, adaptor.mask(),
46519dbb230Saartbik         rewriter.getI32IntegerAttr(align));
46619dbb230Saartbik     return success();
46719dbb230Saartbik   }
46819dbb230Saartbik };
46919dbb230Saartbik 
470e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload.
471563879b6SRahul Joshi class VectorExpandLoadOpConversion
472563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
473e8dcf5f8Saartbik public:
474563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
475e8dcf5f8Saartbik 
476e8dcf5f8Saartbik   LogicalResult
477563879b6SRahul Joshi   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
478e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
479563879b6SRahul Joshi     auto loc = expand->getLoc();
480e8dcf5f8Saartbik     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
481a57def30SAart Bik     MemRefType memRefType = expand.getMemRefType();
482e8dcf5f8Saartbik 
483a57def30SAart Bik     // Resolve address.
484a57def30SAart Bik     auto vtype = typeConverter->convertType(expand.getResultVectorType());
485a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
486a57def30SAart Bik                                            adaptor.indices(), rewriter);
487e8dcf5f8Saartbik 
488e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
489a57def30SAart Bik         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
490e8dcf5f8Saartbik     return success();
491e8dcf5f8Saartbik   }
492e8dcf5f8Saartbik };
493e8dcf5f8Saartbik 
494e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore.
495563879b6SRahul Joshi class VectorCompressStoreOpConversion
496563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
497e8dcf5f8Saartbik public:
498563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
499e8dcf5f8Saartbik 
500e8dcf5f8Saartbik   LogicalResult
501563879b6SRahul Joshi   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
502e8dcf5f8Saartbik                   ConversionPatternRewriter &rewriter) const override {
503563879b6SRahul Joshi     auto loc = compress->getLoc();
504e8dcf5f8Saartbik     auto adaptor = vector::CompressStoreOpAdaptor(operands);
505a57def30SAart Bik     MemRefType memRefType = compress.getMemRefType();
506e8dcf5f8Saartbik 
507a57def30SAart Bik     // Resolve address.
508a57def30SAart Bik     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
509a57def30SAart Bik                                            adaptor.indices(), rewriter);
510e8dcf5f8Saartbik 
511e8dcf5f8Saartbik     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
512563879b6SRahul Joshi         compress, adaptor.value(), ptr, adaptor.mask());
513e8dcf5f8Saartbik     return success();
514e8dcf5f8Saartbik   }
515e8dcf5f8Saartbik };
516e8dcf5f8Saartbik 
51719dbb230Saartbik /// Conversion pattern for all vector reductions.
518563879b6SRahul Joshi class VectorReductionOpConversion
519563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
520e83b7b99Saartbik public:
521563879b6SRahul Joshi   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
522060c9dd1Saartbik                                        bool reassociateFPRed)
523563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
524060c9dd1Saartbik         reassociateFPReductions(reassociateFPRed) {}
525e83b7b99Saartbik 
5263145427dSRiver Riddle   LogicalResult
527563879b6SRahul Joshi   matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
528e83b7b99Saartbik                   ConversionPatternRewriter &rewriter) const override {
529e83b7b99Saartbik     auto kind = reductionOp.kind();
530e83b7b99Saartbik     Type eltType = reductionOp.dest().getType();
531dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(eltType);
532e9628955SAart Bik     if (eltType.isIntOrIndex()) {
533e83b7b99Saartbik       // Integer reductions: add/mul/min/max/and/or/xor.
534e83b7b99Saartbik       if (kind == "add")
535322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
536563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
537e83b7b99Saartbik       else if (kind == "mul")
538322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
539563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
540e9628955SAart Bik       else if (kind == "min" &&
541e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
542322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
543563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
544e83b7b99Saartbik       else if (kind == "min")
545322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
546563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
547e9628955SAart Bik       else if (kind == "max" &&
548e9628955SAart Bik                (eltType.isIndex() || eltType.isUnsignedInteger()))
549322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
550563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
551e83b7b99Saartbik       else if (kind == "max")
552322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
553563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
554e83b7b99Saartbik       else if (kind == "and")
555322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
556563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
557e83b7b99Saartbik       else if (kind == "or")
558322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
559563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
560e83b7b99Saartbik       else if (kind == "xor")
561322d0afdSAmara Emerson         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
562563879b6SRahul Joshi             reductionOp, llvmType, operands[0]);
563e83b7b99Saartbik       else
5643145427dSRiver Riddle         return failure();
5653145427dSRiver Riddle       return success();
566dcec2ca5SChristian Sigg     }
567e83b7b99Saartbik 
568dcec2ca5SChristian Sigg     if (!eltType.isa<FloatType>())
569dcec2ca5SChristian Sigg       return failure();
570dcec2ca5SChristian Sigg 
571e83b7b99Saartbik     // Floating-point reductions: add/mul/min/max
572e83b7b99Saartbik     if (kind == "add") {
5730d924700Saartbik       // Optional accumulator (or zero).
5740d924700Saartbik       Value acc = operands.size() > 1 ? operands[1]
5750d924700Saartbik                                       : rewriter.create<LLVM::ConstantOp>(
576563879b6SRahul Joshi                                             reductionOp->getLoc(), llvmType,
5770d924700Saartbik                                             rewriter.getZeroAttr(eltType));
578322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
579563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
580ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
581e83b7b99Saartbik     } else if (kind == "mul") {
5820d924700Saartbik       // Optional accumulator (or one).
5830d924700Saartbik       Value acc = operands.size() > 1
5840d924700Saartbik                       ? operands[1]
5850d924700Saartbik                       : rewriter.create<LLVM::ConstantOp>(
586563879b6SRahul Joshi                             reductionOp->getLoc(), llvmType,
5870d924700Saartbik                             rewriter.getFloatAttr(eltType, 1.0));
588322d0afdSAmara Emerson       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
589563879b6SRahul Joshi           reductionOp, llvmType, acc, operands[0],
590ceb1b327Saartbik           rewriter.getBoolAttr(reassociateFPReductions));
591e83b7b99Saartbik     } else if (kind == "min")
592563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
593563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
594e83b7b99Saartbik     else if (kind == "max")
595563879b6SRahul Joshi       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
596563879b6SRahul Joshi           reductionOp, llvmType, operands[0]);
597e83b7b99Saartbik     else
5983145427dSRiver Riddle       return failure();
5993145427dSRiver Riddle     return success();
600e83b7b99Saartbik   }
601ceb1b327Saartbik 
602ceb1b327Saartbik private:
603ceb1b327Saartbik   const bool reassociateFPReductions;
604e83b7b99Saartbik };
605e83b7b99Saartbik 
606060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only).
607563879b6SRahul Joshi class VectorCreateMaskOpConversion
608563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
609060c9dd1Saartbik public:
610563879b6SRahul Joshi   explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
611060c9dd1Saartbik                                         bool enableIndexOpt)
612563879b6SRahul Joshi       : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
613060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
614060c9dd1Saartbik 
615060c9dd1Saartbik   LogicalResult
616563879b6SRahul Joshi   matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
617060c9dd1Saartbik                   ConversionPatternRewriter &rewriter) const override {
6189eb3e564SChris Lattner     auto dstType = op.getType();
619060c9dd1Saartbik     int64_t rank = dstType.getRank();
620060c9dd1Saartbik     if (rank == 1) {
621060c9dd1Saartbik       rewriter.replaceOp(
622060c9dd1Saartbik           op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
623060c9dd1Saartbik                                     dstType.getDimSize(0), operands[0]));
624060c9dd1Saartbik       return success();
625060c9dd1Saartbik     }
626060c9dd1Saartbik     return failure();
627060c9dd1Saartbik   }
628060c9dd1Saartbik 
629060c9dd1Saartbik private:
630060c9dd1Saartbik   const bool enableIndexOptimizations;
631060c9dd1Saartbik };
632060c9dd1Saartbik 
633563879b6SRahul Joshi class VectorShuffleOpConversion
634563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
6351c81adf3SAart Bik public:
636563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
6371c81adf3SAart Bik 
6383145427dSRiver Riddle   LogicalResult
639563879b6SRahul Joshi   matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
6401c81adf3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
641563879b6SRahul Joshi     auto loc = shuffleOp->getLoc();
6422d2c73c5SJacques Pienaar     auto adaptor = vector::ShuffleOpAdaptor(operands);
6431c81adf3SAart Bik     auto v1Type = shuffleOp.getV1VectorType();
6441c81adf3SAart Bik     auto v2Type = shuffleOp.getV2VectorType();
6451c81adf3SAart Bik     auto vectorType = shuffleOp.getVectorType();
646dcec2ca5SChristian Sigg     Type llvmType = typeConverter->convertType(vectorType);
6471c81adf3SAart Bik     auto maskArrayAttr = shuffleOp.mask();
6481c81adf3SAart Bik 
6491c81adf3SAart Bik     // Bail if result type cannot be lowered.
6501c81adf3SAart Bik     if (!llvmType)
6513145427dSRiver Riddle       return failure();
6521c81adf3SAart Bik 
6531c81adf3SAart Bik     // Get rank and dimension sizes.
6541c81adf3SAart Bik     int64_t rank = vectorType.getRank();
6551c81adf3SAart Bik     assert(v1Type.getRank() == rank);
6561c81adf3SAart Bik     assert(v2Type.getRank() == rank);
6571c81adf3SAart Bik     int64_t v1Dim = v1Type.getDimSize(0);
6581c81adf3SAart Bik 
6591c81adf3SAart Bik     // For rank 1, where both operands have *exactly* the same vector type,
6601c81adf3SAart Bik     // there is direct shuffle support in LLVM. Use it!
6611c81adf3SAart Bik     if (rank == 1 && v1Type == v2Type) {
662563879b6SRahul Joshi       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
6631c81adf3SAart Bik           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
664563879b6SRahul Joshi       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
6653145427dSRiver Riddle       return success();
666b36aaeafSAart Bik     }
667b36aaeafSAart Bik 
6681c81adf3SAart Bik     // For all other cases, insert the individual values individually.
669e62a6956SRiver Riddle     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
6701c81adf3SAart Bik     int64_t insPos = 0;
6711c81adf3SAart Bik     for (auto en : llvm::enumerate(maskArrayAttr)) {
6721c81adf3SAart Bik       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
673e62a6956SRiver Riddle       Value value = adaptor.v1();
6741c81adf3SAart Bik       if (extPos >= v1Dim) {
6751c81adf3SAart Bik         extPos -= v1Dim;
6761c81adf3SAart Bik         value = adaptor.v2();
677b36aaeafSAart Bik       }
678dcec2ca5SChristian Sigg       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
679dcec2ca5SChristian Sigg                                  llvmType, rank, extPos);
680dcec2ca5SChristian Sigg       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
6810f04384dSAlex Zinenko                          llvmType, rank, insPos++);
6821c81adf3SAart Bik     }
683563879b6SRahul Joshi     rewriter.replaceOp(shuffleOp, insert);
6843145427dSRiver Riddle     return success();
685b36aaeafSAart Bik   }
686b36aaeafSAart Bik };
687b36aaeafSAart Bik 
688563879b6SRahul Joshi class VectorExtractElementOpConversion
689563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
690cd5dab8aSAart Bik public:
691563879b6SRahul Joshi   using ConvertOpToLLVMPattern<
692563879b6SRahul Joshi       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
693cd5dab8aSAart Bik 
6943145427dSRiver Riddle   LogicalResult
695563879b6SRahul Joshi   matchAndRewrite(vector::ExtractElementOp extractEltOp,
696563879b6SRahul Joshi                   ArrayRef<Value> operands,
697cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
6982d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractElementOpAdaptor(operands);
699cd5dab8aSAart Bik     auto vectorType = extractEltOp.getVectorType();
700dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType.getElementType());
701cd5dab8aSAart Bik 
702cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
703cd5dab8aSAart Bik     if (!llvmType)
7043145427dSRiver Riddle       return failure();
705cd5dab8aSAart Bik 
706cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
707563879b6SRahul Joshi         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
7083145427dSRiver Riddle     return success();
709cd5dab8aSAart Bik   }
710cd5dab8aSAart Bik };
711cd5dab8aSAart Bik 
712563879b6SRahul Joshi class VectorExtractOpConversion
713563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
7145c0c51a9SNicolas Vasilache public:
715563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
7165c0c51a9SNicolas Vasilache 
7173145427dSRiver Riddle   LogicalResult
718563879b6SRahul Joshi   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
7195c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
720563879b6SRahul Joshi     auto loc = extractOp->getLoc();
7212d2c73c5SJacques Pienaar     auto adaptor = vector::ExtractOpAdaptor(operands);
7229826fe5cSAart Bik     auto vectorType = extractOp.getVectorType();
7232bdf33ccSRiver Riddle     auto resultType = extractOp.getResult().getType();
724dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(resultType);
7255c0c51a9SNicolas Vasilache     auto positionArrayAttr = extractOp.position();
7269826fe5cSAart Bik 
7279826fe5cSAart Bik     // Bail if result type cannot be lowered.
7289826fe5cSAart Bik     if (!llvmResultType)
7293145427dSRiver Riddle       return failure();
7309826fe5cSAart Bik 
7315c0c51a9SNicolas Vasilache     // One-shot extraction of vector from array (only requires extractvalue).
7325c0c51a9SNicolas Vasilache     if (resultType.isa<VectorType>()) {
733e62a6956SRiver Riddle       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
7345c0c51a9SNicolas Vasilache           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
735563879b6SRahul Joshi       rewriter.replaceOp(extractOp, extracted);
7363145427dSRiver Riddle       return success();
7375c0c51a9SNicolas Vasilache     }
7385c0c51a9SNicolas Vasilache 
7399826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
740563879b6SRahul Joshi     auto *context = extractOp->getContext();
741e62a6956SRiver Riddle     Value extracted = adaptor.vector();
7425c0c51a9SNicolas Vasilache     auto positionAttrs = positionArrayAttr.getValue();
7435c0c51a9SNicolas Vasilache     if (positionAttrs.size() > 1) {
7449826fe5cSAart Bik       auto oneDVectorType = reducedVectorTypeBack(vectorType);
7455c0c51a9SNicolas Vasilache       auto nMinusOnePositionAttrs =
7465c0c51a9SNicolas Vasilache           ArrayAttr::get(positionAttrs.drop_back(), context);
7475c0c51a9SNicolas Vasilache       extracted = rewriter.create<LLVM::ExtractValueOp>(
748dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
7495c0c51a9SNicolas Vasilache           nMinusOnePositionAttrs);
7505c0c51a9SNicolas Vasilache     }
7515c0c51a9SNicolas Vasilache 
7525c0c51a9SNicolas Vasilache     // Remaining extraction of element from 1-D LLVM vector
7535c0c51a9SNicolas Vasilache     auto position = positionAttrs.back().cast<IntegerAttr>();
7542230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
7551d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
7565c0c51a9SNicolas Vasilache     extracted =
7575c0c51a9SNicolas Vasilache         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
758563879b6SRahul Joshi     rewriter.replaceOp(extractOp, extracted);
7595c0c51a9SNicolas Vasilache 
7603145427dSRiver Riddle     return success();
7615c0c51a9SNicolas Vasilache   }
7625c0c51a9SNicolas Vasilache };
7635c0c51a9SNicolas Vasilache 
764681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector
765681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
766681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank.
767681f929fSNicolas Vasilache ///
768681f929fSNicolas Vasilache /// Example:
769681f929fSNicolas Vasilache /// ```
770681f929fSNicolas Vasilache ///  vector.fma %a, %a, %a : vector<8xf32>
771681f929fSNicolas Vasilache /// ```
772681f929fSNicolas Vasilache /// is converted to:
773681f929fSNicolas Vasilache /// ```
7743bffe602SBenjamin Kramer ///  llvm.intr.fmuladd %va, %va, %va:
775dd5165a9SAlex Zinenko ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
776dd5165a9SAlex Zinenko ///    -> !llvm."<8 x f32>">
777681f929fSNicolas Vasilache /// ```
778563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
779681f929fSNicolas Vasilache public:
780563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
781681f929fSNicolas Vasilache 
7823145427dSRiver Riddle   LogicalResult
783563879b6SRahul Joshi   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
784681f929fSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
7852d2c73c5SJacques Pienaar     auto adaptor = vector::FMAOpAdaptor(operands);
786681f929fSNicolas Vasilache     VectorType vType = fmaOp.getVectorType();
787681f929fSNicolas Vasilache     if (vType.getRank() != 1)
7883145427dSRiver Riddle       return failure();
789563879b6SRahul Joshi     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
7903bffe602SBenjamin Kramer                                                  adaptor.rhs(), adaptor.acc());
7913145427dSRiver Riddle     return success();
792681f929fSNicolas Vasilache   }
793681f929fSNicolas Vasilache };
794681f929fSNicolas Vasilache 
795563879b6SRahul Joshi class VectorInsertElementOpConversion
796563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
797cd5dab8aSAart Bik public:
798563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
799cd5dab8aSAart Bik 
8003145427dSRiver Riddle   LogicalResult
801563879b6SRahul Joshi   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
802cd5dab8aSAart Bik                   ConversionPatternRewriter &rewriter) const override {
8032d2c73c5SJacques Pienaar     auto adaptor = vector::InsertElementOpAdaptor(operands);
804cd5dab8aSAart Bik     auto vectorType = insertEltOp.getDestVectorType();
805dcec2ca5SChristian Sigg     auto llvmType = typeConverter->convertType(vectorType);
806cd5dab8aSAart Bik 
807cd5dab8aSAart Bik     // Bail if result type cannot be lowered.
808cd5dab8aSAart Bik     if (!llvmType)
8093145427dSRiver Riddle       return failure();
810cd5dab8aSAart Bik 
811cd5dab8aSAart Bik     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
812563879b6SRahul Joshi         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
813563879b6SRahul Joshi         adaptor.position());
8143145427dSRiver Riddle     return success();
815cd5dab8aSAart Bik   }
816cd5dab8aSAart Bik };
817cd5dab8aSAart Bik 
818563879b6SRahul Joshi class VectorInsertOpConversion
819563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::InsertOp> {
8209826fe5cSAart Bik public:
821563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
8229826fe5cSAart Bik 
8233145427dSRiver Riddle   LogicalResult
824563879b6SRahul Joshi   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
8259826fe5cSAart Bik                   ConversionPatternRewriter &rewriter) const override {
826563879b6SRahul Joshi     auto loc = insertOp->getLoc();
8272d2c73c5SJacques Pienaar     auto adaptor = vector::InsertOpAdaptor(operands);
8289826fe5cSAart Bik     auto sourceType = insertOp.getSourceType();
8299826fe5cSAart Bik     auto destVectorType = insertOp.getDestVectorType();
830dcec2ca5SChristian Sigg     auto llvmResultType = typeConverter->convertType(destVectorType);
8319826fe5cSAart Bik     auto positionArrayAttr = insertOp.position();
8329826fe5cSAart Bik 
8339826fe5cSAart Bik     // Bail if result type cannot be lowered.
8349826fe5cSAart Bik     if (!llvmResultType)
8353145427dSRiver Riddle       return failure();
8369826fe5cSAart Bik 
8379826fe5cSAart Bik     // One-shot insertion of a vector into an array (only requires insertvalue).
8389826fe5cSAart Bik     if (sourceType.isa<VectorType>()) {
839e62a6956SRiver Riddle       Value inserted = rewriter.create<LLVM::InsertValueOp>(
8409826fe5cSAart Bik           loc, llvmResultType, adaptor.dest(), adaptor.source(),
8419826fe5cSAart Bik           positionArrayAttr);
842563879b6SRahul Joshi       rewriter.replaceOp(insertOp, inserted);
8433145427dSRiver Riddle       return success();
8449826fe5cSAart Bik     }
8459826fe5cSAart Bik 
8469826fe5cSAart Bik     // Potential extraction of 1-D vector from array.
847563879b6SRahul Joshi     auto *context = insertOp->getContext();
848e62a6956SRiver Riddle     Value extracted = adaptor.dest();
8499826fe5cSAart Bik     auto positionAttrs = positionArrayAttr.getValue();
8509826fe5cSAart Bik     auto position = positionAttrs.back().cast<IntegerAttr>();
8519826fe5cSAart Bik     auto oneDVectorType = destVectorType;
8529826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8539826fe5cSAart Bik       oneDVectorType = reducedVectorTypeBack(destVectorType);
8549826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8559826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8569826fe5cSAart Bik       extracted = rewriter.create<LLVM::ExtractValueOp>(
857dcec2ca5SChristian Sigg           loc, typeConverter->convertType(oneDVectorType), extracted,
8589826fe5cSAart Bik           nMinusOnePositionAttrs);
8599826fe5cSAart Bik     }
8609826fe5cSAart Bik 
8619826fe5cSAart Bik     // Insertion of an element into a 1-D LLVM vector.
8622230bf99SAlex Zinenko     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
8631d47564aSAart Bik     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
864e62a6956SRiver Riddle     Value inserted = rewriter.create<LLVM::InsertElementOp>(
865dcec2ca5SChristian Sigg         loc, typeConverter->convertType(oneDVectorType), extracted,
8660f04384dSAlex Zinenko         adaptor.source(), constant);
8679826fe5cSAart Bik 
8689826fe5cSAart Bik     // Potential insertion of resulting 1-D vector into array.
8699826fe5cSAart Bik     if (positionAttrs.size() > 1) {
8709826fe5cSAart Bik       auto nMinusOnePositionAttrs =
8719826fe5cSAart Bik           ArrayAttr::get(positionAttrs.drop_back(), context);
8729826fe5cSAart Bik       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
8739826fe5cSAart Bik                                                       adaptor.dest(), inserted,
8749826fe5cSAart Bik                                                       nMinusOnePositionAttrs);
8759826fe5cSAart Bik     }
8769826fe5cSAart Bik 
877563879b6SRahul Joshi     rewriter.replaceOp(insertOp, inserted);
8783145427dSRiver Riddle     return success();
8799826fe5cSAart Bik   }
8809826fe5cSAart Bik };
8819826fe5cSAart Bik 
882681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
883681f929fSNicolas Vasilache ///
884681f929fSNicolas Vasilache /// Example:
885681f929fSNicolas Vasilache /// ```
886681f929fSNicolas Vasilache ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
887681f929fSNicolas Vasilache /// ```
888681f929fSNicolas Vasilache /// is rewritten into:
889681f929fSNicolas Vasilache /// ```
890681f929fSNicolas Vasilache ///  %r = splat %f0: vector<2x4xf32>
891681f929fSNicolas Vasilache ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
892681f929fSNicolas Vasilache ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
893681f929fSNicolas Vasilache ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
894681f929fSNicolas Vasilache ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
895681f929fSNicolas Vasilache ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
896681f929fSNicolas Vasilache ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
897681f929fSNicolas Vasilache ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
898681f929fSNicolas Vasilache ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
899681f929fSNicolas Vasilache ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
900681f929fSNicolas Vasilache ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
901681f929fSNicolas Vasilache ///  // %r3 holds the final value.
902681f929fSNicolas Vasilache /// ```
903681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
904681f929fSNicolas Vasilache public:
905681f929fSNicolas Vasilache   using OpRewritePattern<FMAOp>::OpRewritePattern;
906681f929fSNicolas Vasilache 
9073145427dSRiver Riddle   LogicalResult matchAndRewrite(FMAOp op,
908681f929fSNicolas Vasilache                                 PatternRewriter &rewriter) const override {
909681f929fSNicolas Vasilache     auto vType = op.getVectorType();
910681f929fSNicolas Vasilache     if (vType.getRank() < 2)
9113145427dSRiver Riddle       return failure();
912681f929fSNicolas Vasilache 
913681f929fSNicolas Vasilache     auto loc = op.getLoc();
914681f929fSNicolas Vasilache     auto elemType = vType.getElementType();
915681f929fSNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
916681f929fSNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
917681f929fSNicolas Vasilache     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
918681f929fSNicolas Vasilache     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
919681f929fSNicolas Vasilache       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
920681f929fSNicolas Vasilache       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
921681f929fSNicolas Vasilache       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
922681f929fSNicolas Vasilache       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
923681f929fSNicolas Vasilache       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
924681f929fSNicolas Vasilache     }
925681f929fSNicolas Vasilache     rewriter.replaceOp(op, desc);
9263145427dSRiver Riddle     return success();
927681f929fSNicolas Vasilache   }
928681f929fSNicolas Vasilache };
929681f929fSNicolas Vasilache 
9302d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly
9312d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern
9322d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to
9332d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same
9342d515e49SNicolas Vasilache // rank.
9352d515e49SNicolas Vasilache //
9362d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9372d515e49SNicolas Vasilache // have different ranks. In this case:
9382d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9392d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9402d515e49SNicolas Vasilache //   destination subvector
9412d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9422d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9432d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9442d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9452d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern
9462d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9472d515e49SNicolas Vasilache public:
9482d515e49SNicolas Vasilache   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
9492d515e49SNicolas Vasilache 
9503145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
9512d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
9522d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
9532d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
9542d515e49SNicolas Vasilache 
9552d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
9563145427dSRiver Riddle       return failure();
9572d515e49SNicolas Vasilache 
9582d515e49SNicolas Vasilache     auto loc = op.getLoc();
9592d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
9602d515e49SNicolas Vasilache     assert(rankDiff >= 0);
9612d515e49SNicolas Vasilache     if (rankDiff == 0)
9623145427dSRiver Riddle       return failure();
9632d515e49SNicolas Vasilache 
9642d515e49SNicolas Vasilache     int64_t rankRest = dstType.getRank() - rankDiff;
9652d515e49SNicolas Vasilache     // Extract / insert the subvector of matching rank and InsertStridedSlice
9662d515e49SNicolas Vasilache     // on it.
9672d515e49SNicolas Vasilache     Value extracted =
9682d515e49SNicolas Vasilache         rewriter.create<ExtractOp>(loc, op.dest(),
9692d515e49SNicolas Vasilache                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
970dcec2ca5SChristian Sigg                                                   /*dropBack=*/rankRest));
9712d515e49SNicolas Vasilache     // A different pattern will kick in for InsertStridedSlice with matching
9722d515e49SNicolas Vasilache     // ranks.
9732d515e49SNicolas Vasilache     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
9742d515e49SNicolas Vasilache         loc, op.source(), extracted,
9752d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
976c8fc76a9Saartbik         getI64SubArray(op.strides(), /*dropFront=*/0));
9772d515e49SNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOp>(
9782d515e49SNicolas Vasilache         op, stridedSliceInnerOp.getResult(), op.dest(),
9792d515e49SNicolas Vasilache         getI64SubArray(op.offsets(), /*dropFront=*/0,
980dcec2ca5SChristian Sigg                        /*dropBack=*/rankRest));
9813145427dSRiver Riddle     return success();
9822d515e49SNicolas Vasilache   }
9832d515e49SNicolas Vasilache };
9842d515e49SNicolas Vasilache 
9852d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors
9862d515e49SNicolas Vasilache // have the same rank. In this case, we reduce
9872d515e49SNicolas Vasilache //   1. the proper subvector is extracted from the destination vector
9882d515e49SNicolas Vasilache //   2. a new InsertStridedSlice op is created to insert the source in the
9892d515e49SNicolas Vasilache //   destination subvector
9902d515e49SNicolas Vasilache //   3. the destination subvector is inserted back in the proper place
9912d515e49SNicolas Vasilache //   4. the op is replaced by the result of step 3.
9922d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a
9932d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`.
9942d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern
9952d515e49SNicolas Vasilache     : public OpRewritePattern<InsertStridedSliceOp> {
9962d515e49SNicolas Vasilache public:
997b99bd771SRiver Riddle   VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
998b99bd771SRiver Riddle       : OpRewritePattern<InsertStridedSliceOp>(ctx) {
999b99bd771SRiver Riddle     // This pattern creates recursive InsertStridedSliceOp, but the recursion is
1000b99bd771SRiver Riddle     // bounded as the rank is strictly decreasing.
1001b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1002b99bd771SRiver Riddle   }
10032d515e49SNicolas Vasilache 
10043145427dSRiver Riddle   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
10052d515e49SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
10062d515e49SNicolas Vasilache     auto srcType = op.getSourceVectorType();
10072d515e49SNicolas Vasilache     auto dstType = op.getDestVectorType();
10082d515e49SNicolas Vasilache 
10092d515e49SNicolas Vasilache     if (op.offsets().getValue().empty())
10103145427dSRiver Riddle       return failure();
10112d515e49SNicolas Vasilache 
10122d515e49SNicolas Vasilache     int64_t rankDiff = dstType.getRank() - srcType.getRank();
10132d515e49SNicolas Vasilache     assert(rankDiff >= 0);
10142d515e49SNicolas Vasilache     if (rankDiff != 0)
10153145427dSRiver Riddle       return failure();
10162d515e49SNicolas Vasilache 
10172d515e49SNicolas Vasilache     if (srcType == dstType) {
10182d515e49SNicolas Vasilache       rewriter.replaceOp(op, op.source());
10193145427dSRiver Riddle       return success();
10202d515e49SNicolas Vasilache     }
10212d515e49SNicolas Vasilache 
10222d515e49SNicolas Vasilache     int64_t offset =
10232d515e49SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
10242d515e49SNicolas Vasilache     int64_t size = srcType.getShape().front();
10252d515e49SNicolas Vasilache     int64_t stride =
10262d515e49SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
10272d515e49SNicolas Vasilache 
10282d515e49SNicolas Vasilache     auto loc = op.getLoc();
10292d515e49SNicolas Vasilache     Value res = op.dest();
10302d515e49SNicolas Vasilache     // For each slice of the source vector along the most major dimension.
10312d515e49SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
10322d515e49SNicolas Vasilache          off += stride, ++idx) {
10332d515e49SNicolas Vasilache       // 1. extract the proper subvector (or element) from source
10342d515e49SNicolas Vasilache       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
10352d515e49SNicolas Vasilache       if (extractedSource.getType().isa<VectorType>()) {
10362d515e49SNicolas Vasilache         // 2. If we have a vector, extract the proper subvector from destination
10372d515e49SNicolas Vasilache         // Otherwise we are at the element level and no need to recurse.
10382d515e49SNicolas Vasilache         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
10392d515e49SNicolas Vasilache         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
10402d515e49SNicolas Vasilache         // smaller rank.
1041bd1ccfe6SRiver Riddle         extractedSource = rewriter.create<InsertStridedSliceOp>(
10422d515e49SNicolas Vasilache             loc, extractedSource, extractedDest,
10432d515e49SNicolas Vasilache             getI64SubArray(op.offsets(), /* dropFront=*/1),
10442d515e49SNicolas Vasilache             getI64SubArray(op.strides(), /* dropFront=*/1));
10452d515e49SNicolas Vasilache       }
10462d515e49SNicolas Vasilache       // 4. Insert the extractedSource into the res vector.
10472d515e49SNicolas Vasilache       res = insertOne(rewriter, loc, extractedSource, res, off);
10482d515e49SNicolas Vasilache     }
10492d515e49SNicolas Vasilache 
10502d515e49SNicolas Vasilache     rewriter.replaceOp(op, res);
10513145427dSRiver Riddle     return success();
10522d515e49SNicolas Vasilache   }
10532d515e49SNicolas Vasilache };
10542d515e49SNicolas Vasilache 
105530e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous
105630e6033bSNicolas Vasilache /// static layout.
105730e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>>
105830e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) {
10592bf491c7SBenjamin Kramer   int64_t offset;
106030e6033bSNicolas Vasilache   SmallVector<int64_t, 4> strides;
106130e6033bSNicolas Vasilache   if (failed(getStridesAndOffset(memRefType, strides, offset)))
106230e6033bSNicolas Vasilache     return None;
106330e6033bSNicolas Vasilache   if (!strides.empty() && strides.back() != 1)
106430e6033bSNicolas Vasilache     return None;
106530e6033bSNicolas Vasilache   // If no layout or identity layout, this is contiguous by definition.
106630e6033bSNicolas Vasilache   if (memRefType.getAffineMaps().empty() ||
106730e6033bSNicolas Vasilache       memRefType.getAffineMaps().front().isIdentity())
106830e6033bSNicolas Vasilache     return strides;
106930e6033bSNicolas Vasilache 
107030e6033bSNicolas Vasilache   // Otherwise, we must determine contiguity form shapes. This can only ever
107130e6033bSNicolas Vasilache   // work in static cases because MemRefType is underspecified to represent
107230e6033bSNicolas Vasilache   // contiguous dynamic shapes in other ways than with just empty/identity
107330e6033bSNicolas Vasilache   // layout.
10742bf491c7SBenjamin Kramer   auto sizes = memRefType.getShape();
10752bf491c7SBenjamin Kramer   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
107630e6033bSNicolas Vasilache     if (ShapedType::isDynamic(sizes[index + 1]) ||
107730e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
107830e6033bSNicolas Vasilache         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
107930e6033bSNicolas Vasilache       return None;
108030e6033bSNicolas Vasilache     if (strides[index] != strides[index + 1] * sizes[index + 1])
108130e6033bSNicolas Vasilache       return None;
10822bf491c7SBenjamin Kramer   }
108330e6033bSNicolas Vasilache   return strides;
10842bf491c7SBenjamin Kramer }
10852bf491c7SBenjamin Kramer 
1086563879b6SRahul Joshi class VectorTypeCastOpConversion
1087563879b6SRahul Joshi     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
10885c0c51a9SNicolas Vasilache public:
1089563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
10905c0c51a9SNicolas Vasilache 
10913145427dSRiver Riddle   LogicalResult
1092563879b6SRahul Joshi   matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
10935c0c51a9SNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
1094563879b6SRahul Joshi     auto loc = castOp->getLoc();
10955c0c51a9SNicolas Vasilache     MemRefType sourceMemRefType =
10962bdf33ccSRiver Riddle         castOp.getOperand().getType().cast<MemRefType>();
10979eb3e564SChris Lattner     MemRefType targetMemRefType = castOp.getType();
10985c0c51a9SNicolas Vasilache 
10995c0c51a9SNicolas Vasilache     // Only static shape casts supported atm.
11005c0c51a9SNicolas Vasilache     if (!sourceMemRefType.hasStaticShape() ||
11015c0c51a9SNicolas Vasilache         !targetMemRefType.hasStaticShape())
11023145427dSRiver Riddle       return failure();
11035c0c51a9SNicolas Vasilache 
11045c0c51a9SNicolas Vasilache     auto llvmSourceDescriptorTy =
11058de43b92SAlex Zinenko         operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
11068de43b92SAlex Zinenko     if (!llvmSourceDescriptorTy)
11073145427dSRiver Riddle       return failure();
11085c0c51a9SNicolas Vasilache     MemRefDescriptor sourceMemRef(operands[0]);
11095c0c51a9SNicolas Vasilache 
1110dcec2ca5SChristian Sigg     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
11118de43b92SAlex Zinenko                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
11128de43b92SAlex Zinenko     if (!llvmTargetDescriptorTy)
11133145427dSRiver Riddle       return failure();
11145c0c51a9SNicolas Vasilache 
111530e6033bSNicolas Vasilache     // Only contiguous source buffers supported atm.
111630e6033bSNicolas Vasilache     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
111730e6033bSNicolas Vasilache     if (!sourceStrides)
111830e6033bSNicolas Vasilache       return failure();
111930e6033bSNicolas Vasilache     auto targetStrides = computeContiguousStrides(targetMemRefType);
112030e6033bSNicolas Vasilache     if (!targetStrides)
112130e6033bSNicolas Vasilache       return failure();
112230e6033bSNicolas Vasilache     // Only support static strides for now, regardless of contiguity.
112330e6033bSNicolas Vasilache     if (llvm::any_of(*targetStrides, [](int64_t stride) {
112430e6033bSNicolas Vasilache           return ShapedType::isDynamicStrideOrOffset(stride);
112530e6033bSNicolas Vasilache         }))
11263145427dSRiver Riddle       return failure();
11275c0c51a9SNicolas Vasilache 
11282230bf99SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
11295c0c51a9SNicolas Vasilache 
11305c0c51a9SNicolas Vasilache     // Create descriptor.
11315c0c51a9SNicolas Vasilache     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
11323a577f54SChristian Sigg     Type llvmTargetElementTy = desc.getElementPtrType();
11335c0c51a9SNicolas Vasilache     // Set allocated ptr.
1134e62a6956SRiver Riddle     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
11355c0c51a9SNicolas Vasilache     allocated =
11365c0c51a9SNicolas Vasilache         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
11375c0c51a9SNicolas Vasilache     desc.setAllocatedPtr(rewriter, loc, allocated);
11385c0c51a9SNicolas Vasilache     // Set aligned ptr.
1139e62a6956SRiver Riddle     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
11405c0c51a9SNicolas Vasilache     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
11415c0c51a9SNicolas Vasilache     desc.setAlignedPtr(rewriter, loc, ptr);
11425c0c51a9SNicolas Vasilache     // Fill offset 0.
11435c0c51a9SNicolas Vasilache     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
11445c0c51a9SNicolas Vasilache     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
11455c0c51a9SNicolas Vasilache     desc.setOffset(rewriter, loc, zero);
11465c0c51a9SNicolas Vasilache 
11475c0c51a9SNicolas Vasilache     // Fill size and stride descriptors in memref.
11485c0c51a9SNicolas Vasilache     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
11495c0c51a9SNicolas Vasilache       int64_t index = indexedSize.index();
11505c0c51a9SNicolas Vasilache       auto sizeAttr =
11515c0c51a9SNicolas Vasilache           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
11525c0c51a9SNicolas Vasilache       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
11535c0c51a9SNicolas Vasilache       desc.setSize(rewriter, loc, index, size);
115430e6033bSNicolas Vasilache       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
115530e6033bSNicolas Vasilache                                                 (*targetStrides)[index]);
11565c0c51a9SNicolas Vasilache       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
11575c0c51a9SNicolas Vasilache       desc.setStride(rewriter, loc, index, stride);
11585c0c51a9SNicolas Vasilache     }
11595c0c51a9SNicolas Vasilache 
1160563879b6SRahul Joshi     rewriter.replaceOp(castOp, {desc});
11613145427dSRiver Riddle     return success();
11625c0c51a9SNicolas Vasilache   }
11635c0c51a9SNicolas Vasilache };
11645c0c51a9SNicolas Vasilache 
11658345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a
11668345b86dSNicolas Vasilache /// sequence of:
1167060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer.
1168060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1169060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1170060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound.
1171060c9dd1Saartbik /// 5. Rewrite op as a masked read or write.
11728345b86dSNicolas Vasilache template <typename ConcreteOp>
1173563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
11748345b86dSNicolas Vasilache public:
1175563879b6SRahul Joshi   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1176060c9dd1Saartbik                                     bool enableIndexOpt)
1177563879b6SRahul Joshi       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1178060c9dd1Saartbik         enableIndexOptimizations(enableIndexOpt) {}
11798345b86dSNicolas Vasilache 
11808345b86dSNicolas Vasilache   LogicalResult
1181563879b6SRahul Joshi   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
11828345b86dSNicolas Vasilache                   ConversionPatternRewriter &rewriter) const override {
11838345b86dSNicolas Vasilache     auto adaptor = getTransferOpAdapter(xferOp, operands);
1184b2c79c50SNicolas Vasilache 
1185b2c79c50SNicolas Vasilache     if (xferOp.getVectorType().getRank() > 1 ||
1186b2c79c50SNicolas Vasilache         llvm::size(xferOp.indices()) == 0)
11878345b86dSNicolas Vasilache       return failure();
11885f9e0466SNicolas Vasilache     if (xferOp.permutation_map() !=
11895f9e0466SNicolas Vasilache         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
11905f9e0466SNicolas Vasilache                                        xferOp.getVectorType().getRank(),
1191563879b6SRahul Joshi                                        xferOp->getContext()))
11928345b86dSNicolas Vasilache       return failure();
119326c8f908SThomas Raoux     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
119426c8f908SThomas Raoux     if (!memRefType)
119526c8f908SThomas Raoux       return failure();
11962bf491c7SBenjamin Kramer     // Only contiguous source tensors supported atm.
119726c8f908SThomas Raoux     auto strides = computeContiguousStrides(memRefType);
119830e6033bSNicolas Vasilache     if (!strides)
11992bf491c7SBenjamin Kramer       return failure();
12008345b86dSNicolas Vasilache 
1201563879b6SRahul Joshi     auto toLLVMTy = [&](Type t) {
1202563879b6SRahul Joshi       return this->getTypeConverter()->convertType(t);
1203563879b6SRahul Joshi     };
12048345b86dSNicolas Vasilache 
1205563879b6SRahul Joshi     Location loc = xferOp->getLoc();
12068345b86dSNicolas Vasilache 
120768330ee0SThomas Raoux     if (auto memrefVectorElementType =
120826c8f908SThomas Raoux             memRefType.getElementType().template dyn_cast<VectorType>()) {
120968330ee0SThomas Raoux       // Memref has vector element type.
121068330ee0SThomas Raoux       if (memrefVectorElementType.getElementType() !=
121168330ee0SThomas Raoux           xferOp.getVectorType().getElementType())
121268330ee0SThomas Raoux         return failure();
12130de60b55SThomas Raoux #ifndef NDEBUG
121468330ee0SThomas Raoux       // Check that memref vector type is a suffix of 'vectorType.
121568330ee0SThomas Raoux       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
121668330ee0SThomas Raoux       unsigned resultVecRank = xferOp.getVectorType().getRank();
121768330ee0SThomas Raoux       assert(memrefVecEltRank <= resultVecRank);
121868330ee0SThomas Raoux       // TODO: Move this to isSuffix in Vector/Utils.h.
121968330ee0SThomas Raoux       unsigned rankOffset = resultVecRank - memrefVecEltRank;
122068330ee0SThomas Raoux       auto memrefVecEltShape = memrefVectorElementType.getShape();
122168330ee0SThomas Raoux       auto resultVecShape = xferOp.getVectorType().getShape();
122268330ee0SThomas Raoux       for (unsigned i = 0; i < memrefVecEltRank; ++i)
122368330ee0SThomas Raoux         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
122468330ee0SThomas Raoux                "memref vector element shape should match suffix of vector "
122568330ee0SThomas Raoux                "result shape.");
12260de60b55SThomas Raoux #endif // ifndef NDEBUG
122768330ee0SThomas Raoux     }
122868330ee0SThomas Raoux 
12298345b86dSNicolas Vasilache     // 1. Get the source/dst address as an LLVM vector pointer.
1230a57def30SAart Bik     VectorType vtp = xferOp.getVectorType();
1231563879b6SRahul Joshi     Value dataPtr = this->getStridedElementPtr(
123226c8f908SThomas Raoux         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1233a57def30SAart Bik     Value vectorDataPtr =
1234a57def30SAart Bik         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
12358345b86dSNicolas Vasilache 
12361870e787SNicolas Vasilache     if (!xferOp.isMaskedDim(0))
1237563879b6SRahul Joshi       return replaceTransferOpWithLoadOrStore(rewriter,
1238563879b6SRahul Joshi                                               *this->getTypeConverter(), loc,
1239563879b6SRahul Joshi                                               xferOp, operands, vectorDataPtr);
12401870e787SNicolas Vasilache 
12418345b86dSNicolas Vasilache     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
12428345b86dSNicolas Vasilache     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
12438345b86dSNicolas Vasilache     // 4. Let dim the memref dimension, compute the vector comparison mask:
12448345b86dSNicolas Vasilache     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1245060c9dd1Saartbik     //
1246060c9dd1Saartbik     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1247060c9dd1Saartbik     //       dimensions here.
1248bd30a796SAlex Zinenko     unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
1249060c9dd1Saartbik     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
12500c2a4d3cSBenjamin Kramer     Value off = xferOp.indices()[lastIndex];
125126c8f908SThomas Raoux     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1252563879b6SRahul Joshi     Value mask = buildVectorComparison(
1253563879b6SRahul Joshi         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
12548345b86dSNicolas Vasilache 
12558345b86dSNicolas Vasilache     // 5. Rewrite as a masked read / write.
1256563879b6SRahul Joshi     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1257dcec2ca5SChristian Sigg                                        xferOp, operands, vectorDataPtr, mask);
12588345b86dSNicolas Vasilache   }
1259060c9dd1Saartbik 
1260060c9dd1Saartbik private:
1261060c9dd1Saartbik   const bool enableIndexOptimizations;
12628345b86dSNicolas Vasilache };
12638345b86dSNicolas Vasilache 
1264563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1265d9b500d3SAart Bik public:
1266563879b6SRahul Joshi   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1267d9b500d3SAart Bik 
1268d9b500d3SAart Bik   // Proof-of-concept lowering implementation that relies on a small
1269d9b500d3SAart Bik   // runtime support library, which only needs to provide a few
1270d9b500d3SAart Bik   // printing methods (single value for all data types, opening/closing
1271d9b500d3SAart Bik   // bracket, comma, newline). The lowering fully unrolls a vector
1272d9b500d3SAart Bik   // in terms of these elementary printing operations. The advantage
1273d9b500d3SAart Bik   // of this approach is that the library can remain unaware of all
1274d9b500d3SAart Bik   // low-level implementation details of vectors while still supporting
1275d9b500d3SAart Bik   // output of any shaped and dimensioned vector. Due to full unrolling,
1276d9b500d3SAart Bik   // this approach is less suited for very large vectors though.
1277d9b500d3SAart Bik   //
12789db53a18SRiver Riddle   // TODO: rely solely on libc in future? something else?
1279d9b500d3SAart Bik   //
12803145427dSRiver Riddle   LogicalResult
1281563879b6SRahul Joshi   matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
1282d9b500d3SAart Bik                   ConversionPatternRewriter &rewriter) const override {
12832d2c73c5SJacques Pienaar     auto adaptor = vector::PrintOpAdaptor(operands);
1284d9b500d3SAart Bik     Type printType = printOp.getPrintType();
1285d9b500d3SAart Bik 
1286dcec2ca5SChristian Sigg     if (typeConverter->convertType(printType) == nullptr)
12873145427dSRiver Riddle       return failure();
1288d9b500d3SAart Bik 
1289b8880f5fSAart Bik     // Make sure element type has runtime support.
1290b8880f5fSAart Bik     PrintConversion conversion = PrintConversion::None;
1291d9b500d3SAart Bik     VectorType vectorType = printType.dyn_cast<VectorType>();
1292d9b500d3SAart Bik     Type eltType = vectorType ? vectorType.getElementType() : printType;
1293d9b500d3SAart Bik     Operation *printer;
1294b8880f5fSAart Bik     if (eltType.isF32()) {
1295563879b6SRahul Joshi       printer = getPrintFloat(printOp);
1296b8880f5fSAart Bik     } else if (eltType.isF64()) {
1297563879b6SRahul Joshi       printer = getPrintDouble(printOp);
129854759cefSAart Bik     } else if (eltType.isIndex()) {
1299563879b6SRahul Joshi       printer = getPrintU64(printOp);
1300b8880f5fSAart Bik     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1301b8880f5fSAart Bik       // Integers need a zero or sign extension on the operand
1302b8880f5fSAart Bik       // (depending on the source type) as well as a signed or
1303b8880f5fSAart Bik       // unsigned print method. Up to 64-bit is supported.
1304b8880f5fSAart Bik       unsigned width = intTy.getWidth();
1305b8880f5fSAart Bik       if (intTy.isUnsigned()) {
130654759cefSAart Bik         if (width <= 64) {
1307b8880f5fSAart Bik           if (width < 64)
1308b8880f5fSAart Bik             conversion = PrintConversion::ZeroExt64;
1309563879b6SRahul Joshi           printer = getPrintU64(printOp);
1310b8880f5fSAart Bik         } else {
13113145427dSRiver Riddle           return failure();
1312b8880f5fSAart Bik         }
1313b8880f5fSAart Bik       } else {
1314b8880f5fSAart Bik         assert(intTy.isSignless() || intTy.isSigned());
131554759cefSAart Bik         if (width <= 64) {
1316b8880f5fSAart Bik           // Note that we *always* zero extend booleans (1-bit integers),
1317b8880f5fSAart Bik           // so that true/false is printed as 1/0 rather than -1/0.
1318b8880f5fSAart Bik           if (width == 1)
131954759cefSAart Bik             conversion = PrintConversion::ZeroExt64;
132054759cefSAart Bik           else if (width < 64)
1321b8880f5fSAart Bik             conversion = PrintConversion::SignExt64;
1322563879b6SRahul Joshi           printer = getPrintI64(printOp);
1323b8880f5fSAart Bik         } else {
1324b8880f5fSAart Bik           return failure();
1325b8880f5fSAart Bik         }
1326b8880f5fSAart Bik       }
1327b8880f5fSAart Bik     } else {
1328b8880f5fSAart Bik       return failure();
1329b8880f5fSAart Bik     }
1330d9b500d3SAart Bik 
1331d9b500d3SAart Bik     // Unroll vector into elementary print calls.
1332b8880f5fSAart Bik     int64_t rank = vectorType ? vectorType.getRank() : 0;
1333563879b6SRahul Joshi     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
1334b8880f5fSAart Bik               conversion);
1335563879b6SRahul Joshi     emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
1336563879b6SRahul Joshi     rewriter.eraseOp(printOp);
13373145427dSRiver Riddle     return success();
1338d9b500d3SAart Bik   }
1339d9b500d3SAart Bik 
1340d9b500d3SAart Bik private:
1341b8880f5fSAart Bik   enum class PrintConversion {
134230e6033bSNicolas Vasilache     // clang-format off
1343b8880f5fSAart Bik     None,
1344b8880f5fSAart Bik     ZeroExt64,
1345b8880f5fSAart Bik     SignExt64
134630e6033bSNicolas Vasilache     // clang-format on
1347b8880f5fSAart Bik   };
1348b8880f5fSAart Bik 
1349d9b500d3SAart Bik   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1350e62a6956SRiver Riddle                  Value value, VectorType vectorType, Operation *printer,
1351b8880f5fSAart Bik                  int64_t rank, PrintConversion conversion) const {
1352d9b500d3SAart Bik     Location loc = op->getLoc();
1353d9b500d3SAart Bik     if (rank == 0) {
1354b8880f5fSAart Bik       switch (conversion) {
1355b8880f5fSAart Bik       case PrintConversion::ZeroExt64:
1356b8880f5fSAart Bik         value = rewriter.create<ZeroExtendIOp>(
13572230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1358b8880f5fSAart Bik         break;
1359b8880f5fSAart Bik       case PrintConversion::SignExt64:
1360b8880f5fSAart Bik         value = rewriter.create<SignExtendIOp>(
13612230bf99SAlex Zinenko             loc, value, IntegerType::get(rewriter.getContext(), 64));
1362b8880f5fSAart Bik         break;
1363b8880f5fSAart Bik       case PrintConversion::None:
1364b8880f5fSAart Bik         break;
1365c9eeeb38Saartbik       }
1366d9b500d3SAart Bik       emitCall(rewriter, loc, printer, value);
1367d9b500d3SAart Bik       return;
1368d9b500d3SAart Bik     }
1369d9b500d3SAart Bik 
1370d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintOpen(op));
1371d9b500d3SAart Bik     Operation *printComma = getPrintComma(op);
1372d9b500d3SAart Bik     int64_t dim = vectorType.getDimSize(0);
1373d9b500d3SAart Bik     for (int64_t d = 0; d < dim; ++d) {
1374d9b500d3SAart Bik       auto reducedType =
1375d9b500d3SAart Bik           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1376dcec2ca5SChristian Sigg       auto llvmType = typeConverter->convertType(
1377d9b500d3SAart Bik           rank > 1 ? reducedType : vectorType.getElementType());
1378dcec2ca5SChristian Sigg       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1379dcec2ca5SChristian Sigg                                    llvmType, rank, d);
1380b8880f5fSAart Bik       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1381b8880f5fSAart Bik                 conversion);
1382d9b500d3SAart Bik       if (d != dim - 1)
1383d9b500d3SAart Bik         emitCall(rewriter, loc, printComma);
1384d9b500d3SAart Bik     }
1385d9b500d3SAart Bik     emitCall(rewriter, loc, getPrintClose(op));
1386d9b500d3SAart Bik   }
1387d9b500d3SAart Bik 
1388d9b500d3SAart Bik   // Helper to emit a call.
1389d9b500d3SAart Bik   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1390d9b500d3SAart Bik                        Operation *ref, ValueRange params = ValueRange()) {
139108e4f078SRahul Joshi     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1392d9b500d3SAart Bik                                   rewriter.getSymbolRefAttr(ref), params);
1393d9b500d3SAart Bik   }
1394d9b500d3SAart Bik 
1395d9b500d3SAart Bik   // Helper for printer method declaration (first hit) and lookup.
13965446ec85SAlex Zinenko   static Operation *getPrint(Operation *op, StringRef name,
1397c69c9e0fSAlex Zinenko                              ArrayRef<Type> params) {
1398d9b500d3SAart Bik     auto module = op->getParentOfType<ModuleOp>();
1399d9b500d3SAart Bik     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1400d9b500d3SAart Bik     if (func)
1401d9b500d3SAart Bik       return func;
1402d9b500d3SAart Bik     OpBuilder moduleBuilder(module.getBodyRegion());
1403d9b500d3SAart Bik     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1404d9b500d3SAart Bik         op->getLoc(), name,
14057ed9cfc7SAlex Zinenko         LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
14067ed9cfc7SAlex Zinenko                                     params));
1407d9b500d3SAart Bik   }
1408d9b500d3SAart Bik 
1409d9b500d3SAart Bik   // Helpers for method names.
1410e52414b1Saartbik   Operation *getPrintI64(Operation *op) const {
14112230bf99SAlex Zinenko     return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
1412e52414b1Saartbik   }
1413b8880f5fSAart Bik   Operation *getPrintU64(Operation *op) const {
14142230bf99SAlex Zinenko     return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
1415b8880f5fSAart Bik   }
1416d9b500d3SAart Bik   Operation *getPrintFloat(Operation *op) const {
1417dd5165a9SAlex Zinenko     return getPrint(op, "printF32", Float32Type::get(op->getContext()));
1418d9b500d3SAart Bik   }
1419d9b500d3SAart Bik   Operation *getPrintDouble(Operation *op) const {
1420dd5165a9SAlex Zinenko     return getPrint(op, "printF64", Float64Type::get(op->getContext()));
1421d9b500d3SAart Bik   }
1422d9b500d3SAart Bik   Operation *getPrintOpen(Operation *op) const {
142354759cefSAart Bik     return getPrint(op, "printOpen", {});
1424d9b500d3SAart Bik   }
1425d9b500d3SAart Bik   Operation *getPrintClose(Operation *op) const {
142654759cefSAart Bik     return getPrint(op, "printClose", {});
1427d9b500d3SAart Bik   }
1428d9b500d3SAart Bik   Operation *getPrintComma(Operation *op) const {
142954759cefSAart Bik     return getPrint(op, "printComma", {});
1430d9b500d3SAart Bik   }
1431d9b500d3SAart Bik   Operation *getPrintNewline(Operation *op) const {
143254759cefSAart Bik     return getPrint(op, "printNewline", {});
1433d9b500d3SAart Bik   }
1434d9b500d3SAart Bik };
1435d9b500d3SAart Bik 
1436334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either:
1437c3c95b9cSaartbik ///   1. express single offset extract as a direct shuffle.
1438c3c95b9cSaartbik ///   2. extract + lower rank strided_slice + insert for the n-D case.
1439c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion
1440334a4159SReid Tatge     : public OpRewritePattern<ExtractStridedSliceOp> {
144165678d93SNicolas Vasilache public:
1442b99bd771SRiver Riddle   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1443b99bd771SRiver Riddle       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1444b99bd771SRiver Riddle     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1445b99bd771SRiver Riddle     // is bounded as the rank is strictly decreasing.
1446b99bd771SRiver Riddle     setHasBoundedRewriteRecursion();
1447b99bd771SRiver Riddle   }
144865678d93SNicolas Vasilache 
1449334a4159SReid Tatge   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
145065678d93SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
14519eb3e564SChris Lattner     auto dstType = op.getType();
145265678d93SNicolas Vasilache 
145365678d93SNicolas Vasilache     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
145465678d93SNicolas Vasilache 
145565678d93SNicolas Vasilache     int64_t offset =
145665678d93SNicolas Vasilache         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
145765678d93SNicolas Vasilache     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
145865678d93SNicolas Vasilache     int64_t stride =
145965678d93SNicolas Vasilache         op.strides().getValue().front().cast<IntegerAttr>().getInt();
146065678d93SNicolas Vasilache 
146165678d93SNicolas Vasilache     auto loc = op.getLoc();
146265678d93SNicolas Vasilache     auto elemType = dstType.getElementType();
146335b68527SLei Zhang     assert(elemType.isSignlessIntOrIndexOrFloat());
1464c3c95b9cSaartbik 
1465c3c95b9cSaartbik     // Single offset can be more efficiently shuffled.
1466c3c95b9cSaartbik     if (op.offsets().getValue().size() == 1) {
1467c3c95b9cSaartbik       SmallVector<int64_t, 4> offsets;
1468c3c95b9cSaartbik       offsets.reserve(size);
1469c3c95b9cSaartbik       for (int64_t off = offset, e = offset + size * stride; off < e;
1470c3c95b9cSaartbik            off += stride)
1471c3c95b9cSaartbik         offsets.push_back(off);
1472c3c95b9cSaartbik       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1473c3c95b9cSaartbik                                              op.vector(),
1474c3c95b9cSaartbik                                              rewriter.getI64ArrayAttr(offsets));
1475c3c95b9cSaartbik       return success();
1476c3c95b9cSaartbik     }
1477c3c95b9cSaartbik 
1478c3c95b9cSaartbik     // Extract/insert on a lower ranked extract strided slice op.
147965678d93SNicolas Vasilache     Value zero = rewriter.create<ConstantOp>(loc, elemType,
148065678d93SNicolas Vasilache                                              rewriter.getZeroAttr(elemType));
148165678d93SNicolas Vasilache     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
148265678d93SNicolas Vasilache     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
148365678d93SNicolas Vasilache          off += stride, ++idx) {
1484c3c95b9cSaartbik       Value one = extractOne(rewriter, loc, op.vector(), off);
1485c3c95b9cSaartbik       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1486c3c95b9cSaartbik           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
148765678d93SNicolas Vasilache           getI64SubArray(op.sizes(), /* dropFront=*/1),
148865678d93SNicolas Vasilache           getI64SubArray(op.strides(), /* dropFront=*/1));
148965678d93SNicolas Vasilache       res = insertOne(rewriter, loc, extracted, res, idx);
149065678d93SNicolas Vasilache     }
1491c3c95b9cSaartbik     rewriter.replaceOp(op, res);
14923145427dSRiver Riddle     return success();
149365678d93SNicolas Vasilache   }
149465678d93SNicolas Vasilache };
149565678d93SNicolas Vasilache 
1496df186507SBenjamin Kramer } // namespace
1497df186507SBenjamin Kramer 
14985c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM.
14995c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns(
1500ceb1b327Saartbik     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1501060c9dd1Saartbik     bool reassociateFPReductions, bool enableIndexOptimizations) {
150265678d93SNicolas Vasilache   MLIRContext *ctx = converter.getDialect()->getContext();
15038345b86dSNicolas Vasilache   // clang-format off
1504681f929fSNicolas Vasilache   patterns.insert<VectorFMAOpNDRewritePattern,
1505681f929fSNicolas Vasilache                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
15062d515e49SNicolas Vasilache                   VectorInsertStridedSliceOpSameRankRewritePattern,
1507c3c95b9cSaartbik                   VectorExtractStridedSliceOpConversion>(ctx);
1508ceb1b327Saartbik   patterns.insert<VectorReductionOpConversion>(
1509563879b6SRahul Joshi       converter, reassociateFPReductions);
1510060c9dd1Saartbik   patterns.insert<VectorCreateMaskOpConversion,
1511060c9dd1Saartbik                   VectorTransferConversion<TransferReadOp>,
1512060c9dd1Saartbik                   VectorTransferConversion<TransferWriteOp>>(
1513563879b6SRahul Joshi       converter, enableIndexOptimizations);
15148345b86dSNicolas Vasilache   patterns
1515*cf5c517cSDiego Caballero       .insert<VectorBitCastOpConversion,
1516*cf5c517cSDiego Caballero               VectorShuffleOpConversion,
15178345b86dSNicolas Vasilache               VectorExtractElementOpConversion,
15188345b86dSNicolas Vasilache               VectorExtractOpConversion,
15198345b86dSNicolas Vasilache               VectorFMAOp1DConversion,
15208345b86dSNicolas Vasilache               VectorInsertElementOpConversion,
15218345b86dSNicolas Vasilache               VectorInsertOpConversion,
15228345b86dSNicolas Vasilache               VectorPrintOpConversion,
152319dbb230Saartbik               VectorTypeCastOpConversion,
152439379916Saartbik               VectorMaskedLoadOpConversion,
152539379916Saartbik               VectorMaskedStoreOpConversion,
152619dbb230Saartbik               VectorGatherOpConversion,
1527e8dcf5f8Saartbik               VectorScatterOpConversion,
1528e8dcf5f8Saartbik               VectorExpandLoadOpConversion,
1529563879b6SRahul Joshi               VectorCompressStoreOpConversion>(converter);
15308345b86dSNicolas Vasilache   // clang-format on
15315c0c51a9SNicolas Vasilache }
15325c0c51a9SNicolas Vasilache 
153363b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns(
153463b683a8SNicolas Vasilache     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1535563879b6SRahul Joshi   patterns.insert<VectorMatmulOpConversion>(converter);
1536563879b6SRahul Joshi   patterns.insert<VectorFlatTransposeOpConversion>(converter);
153763b683a8SNicolas Vasilache }
1538