15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 195c0c51a9SNicolas Vasilache 205c0c51a9SNicolas Vasilache using namespace mlir; 2165678d93SNicolas Vasilache using namespace mlir::vector; 225c0c51a9SNicolas Vasilache 239826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 259826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 269826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 279826fe5cSAart Bik } 289826fe5cSAart Bik 299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 319826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 329826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 339826fe5cSAart Bik } 349826fe5cSAart Bik 351c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 370f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 380f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 390f04384dSAlex Zinenko int64_t pos) { 401c81adf3SAart Bik if (rank == 1) { 411c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 421c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 430f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 441c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 451c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 461c81adf3SAart Bik constant); 471c81adf3SAart Bik } 481c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 491c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 501c81adf3SAart Bik } 511c81adf3SAart Bik 522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 542d515e49SNicolas Vasilache Value into, int64_t offset) { 552d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 562d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 572d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 582d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 592d515e49SNicolas Vasilache loc, vectorType, from, into, 602d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 612d515e49SNicolas Vasilache } 622d515e49SNicolas Vasilache 631c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 650f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 660f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 671c81adf3SAart Bik if (rank == 1) { 681c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 691c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 700f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 711c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 721c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 731c81adf3SAart Bik constant); 741c81adf3SAart Bik } 751c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 761c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 771c81adf3SAart Bik } 781c81adf3SAart Bik 792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 812d515e49SNicolas Vasilache int64_t offset) { 822d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 832d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 842d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 852d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 862d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 872d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 882d515e49SNicolas Vasilache } 892d515e49SNicolas Vasilache 902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 932d515e49SNicolas Vasilache unsigned dropFront = 0, 942d515e49SNicolas Vasilache unsigned dropBack = 0) { 952d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 962d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 972d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 982d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 992d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1002d515e49SNicolas Vasilache it != eit; ++it) 1012d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1022d515e49SNicolas Vasilache return res; 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 106060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107060c9dd1Saartbik // 108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 110060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 111060c9dd1Saartbik // is very conservative on the boundary conditions. 112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 114060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 115060c9dd1Saartbik auto loc = op->getLoc(); 116060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 117060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 118060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 119060c9dd1Saartbik Value indices; 120060c9dd1Saartbik Type idxType; 121060c9dd1Saartbik if (enableIndexOptimizations) { 1220c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1230c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1240c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125060c9dd1Saartbik idxType = rewriter.getI32Type(); 126060c9dd1Saartbik } else { 1270c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1280c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1290c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130060c9dd1Saartbik idxType = rewriter.getI64Type(); 131060c9dd1Saartbik } 132060c9dd1Saartbik // Add in an offset if requested. 133060c9dd1Saartbik if (off) { 134060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 137060c9dd1Saartbik } 138060c9dd1Saartbik // Construct the vector comparison. 139060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142060c9dd1Saartbik } 143060c9dd1Saartbik 14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref. 14519dbb230Saartbik template <typename T> 14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, 14719dbb230Saartbik unsigned &align) { 1485f9e0466SNicolas Vasilache Type elementTy = 14919dbb230Saartbik typeConverter.convertType(op.getMemRefType().getElementType()); 1505f9e0466SNicolas Vasilache if (!elementTy) 1515f9e0466SNicolas Vasilache return failure(); 1525f9e0466SNicolas Vasilache 153b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 154b2ab375dSAlex Zinenko // stop depending on translation. 15587a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 15687a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 157b2ab375dSAlex Zinenko .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(), 158168213f9SAlex Zinenko typeConverter.getDataLayout()); 1595f9e0466SNicolas Vasilache return success(); 1605f9e0466SNicolas Vasilache } 1615f9e0466SNicolas Vasilache 162e8dcf5f8Saartbik // Helper that returns the base address of a memref. 163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 164e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 16519dbb230Saartbik // Inspect stride and offset structure. 16619dbb230Saartbik // 16719dbb230Saartbik // TODO: flat memory only for now, generalize 16819dbb230Saartbik // 16919dbb230Saartbik int64_t offset; 17019dbb230Saartbik SmallVector<int64_t, 4> strides; 17119dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 17219dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 17319dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 17419dbb230Saartbik return failure(); 175e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 176e8dcf5f8Saartbik return success(); 177e8dcf5f8Saartbik } 17819dbb230Saartbik 179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base. 180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 181b98e25b6SBenjamin Kramer Location loc, Value memref, 182b98e25b6SBenjamin Kramer MemRefType memRefType, Value &ptr) { 183e8dcf5f8Saartbik Value base; 184e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 185e8dcf5f8Saartbik return failure(); 1863a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 187e8dcf5f8Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 188e8dcf5f8Saartbik return success(); 189e8dcf5f8Saartbik } 190e8dcf5f8Saartbik 19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base. 192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 193b98e25b6SBenjamin Kramer Location loc, Value memref, 194b98e25b6SBenjamin Kramer MemRefType memRefType, Type type, Value &ptr) { 19539379916Saartbik Value base; 19639379916Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 19739379916Saartbik return failure(); 19839379916Saartbik auto pType = type.template cast<LLVM::LLVMType>().getPointerTo(); 19939379916Saartbik base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 20039379916Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 20139379916Saartbik return success(); 20239379916Saartbik } 20339379916Saartbik 204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index 205e8dcf5f8Saartbik // vector. 206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 207b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 208b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 209b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 210e8dcf5f8Saartbik Value base; 211e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 212e8dcf5f8Saartbik return failure(); 2133a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 214e8dcf5f8Saartbik auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); 2151485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 21619dbb230Saartbik return success(); 21719dbb230Saartbik } 21819dbb230Saartbik 2195f9e0466SNicolas Vasilache static LogicalResult 2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2215f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2225f9e0466SNicolas Vasilache TransferReadOp xferOp, 2235f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 224affbc0cdSNicolas Vasilache unsigned align; 22519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 226affbc0cdSNicolas Vasilache return failure(); 227affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2285f9e0466SNicolas Vasilache return success(); 2295f9e0466SNicolas Vasilache } 2305f9e0466SNicolas Vasilache 2315f9e0466SNicolas Vasilache static LogicalResult 2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2335f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2345f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2355f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2365f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2375f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2385f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2395f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2405f9e0466SNicolas Vasilache 2415f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2425f9e0466SNicolas Vasilache if (!vecTy) 2435f9e0466SNicolas Vasilache return failure(); 2445f9e0466SNicolas Vasilache 2455f9e0466SNicolas Vasilache unsigned align; 24619dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2475f9e0466SNicolas Vasilache return failure(); 2485f9e0466SNicolas Vasilache 2495f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2505f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2515f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2525f9e0466SNicolas Vasilache return success(); 2535f9e0466SNicolas Vasilache } 2545f9e0466SNicolas Vasilache 2555f9e0466SNicolas Vasilache static LogicalResult 2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2575f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2585f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2595f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 260affbc0cdSNicolas Vasilache unsigned align; 26119dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 262affbc0cdSNicolas Vasilache return failure(); 2632d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 264affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 265affbc0cdSNicolas Vasilache align); 2665f9e0466SNicolas Vasilache return success(); 2675f9e0466SNicolas Vasilache } 2685f9e0466SNicolas Vasilache 2695f9e0466SNicolas Vasilache static LogicalResult 2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2715f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2725f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2735f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2745f9e0466SNicolas Vasilache unsigned align; 27519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2765f9e0466SNicolas Vasilache return failure(); 2775f9e0466SNicolas Vasilache 2782d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2795f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2805f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2815f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2825f9e0466SNicolas Vasilache return success(); 2835f9e0466SNicolas Vasilache } 2845f9e0466SNicolas Vasilache 2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2862d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2872d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2885f9e0466SNicolas Vasilache } 2895f9e0466SNicolas Vasilache 2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2912d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2922d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2935f9e0466SNicolas Vasilache } 2945f9e0466SNicolas Vasilache 29590c01357SBenjamin Kramer namespace { 296e83b7b99Saartbik 29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 29963b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern { 30063b683a8SNicolas Vasilache public: 30163b683a8SNicolas Vasilache explicit VectorMatmulOpConversion(MLIRContext *context, 30263b683a8SNicolas Vasilache LLVMTypeConverter &typeConverter) 30363b683a8SNicolas Vasilache : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, 30463b683a8SNicolas Vasilache typeConverter) {} 30563b683a8SNicolas Vasilache 3063145427dSRiver Riddle LogicalResult 30763b683a8SNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 30863b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 30963b683a8SNicolas Vasilache auto matmulOp = cast<vector::MatmulOp>(op); 3102d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 31163b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 312*dcec2ca5SChristian Sigg op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(), 31363b683a8SNicolas Vasilache adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), 31463b683a8SNicolas Vasilache matmulOp.rhs_columns()); 3153145427dSRiver Riddle return success(); 31663b683a8SNicolas Vasilache } 31763b683a8SNicolas Vasilache }; 31863b683a8SNicolas Vasilache 319c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 320c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 321c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { 322c295a65dSaartbik public: 323c295a65dSaartbik explicit VectorFlatTransposeOpConversion(MLIRContext *context, 324c295a65dSaartbik LLVMTypeConverter &typeConverter) 325c295a65dSaartbik : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), 326c295a65dSaartbik context, typeConverter) {} 327c295a65dSaartbik 328c295a65dSaartbik LogicalResult 329c295a65dSaartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 330c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 331c295a65dSaartbik auto transOp = cast<vector::FlatTransposeOp>(op); 3322d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 333c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 334*dcec2ca5SChristian Sigg transOp, typeConverter->convertType(transOp.res().getType()), 335c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 336c295a65dSaartbik return success(); 337c295a65dSaartbik } 338c295a65dSaartbik }; 339c295a65dSaartbik 34039379916Saartbik /// Conversion pattern for a vector.maskedload. 34139379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { 34239379916Saartbik public: 34339379916Saartbik explicit VectorMaskedLoadOpConversion(MLIRContext *context, 34439379916Saartbik LLVMTypeConverter &typeConverter) 34539379916Saartbik : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, 34639379916Saartbik typeConverter) {} 34739379916Saartbik 34839379916Saartbik LogicalResult 34939379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 35039379916Saartbik ConversionPatternRewriter &rewriter) const override { 35139379916Saartbik auto loc = op->getLoc(); 35239379916Saartbik auto load = cast<vector::MaskedLoadOp>(op); 35339379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 35439379916Saartbik 35539379916Saartbik // Resolve alignment. 35639379916Saartbik unsigned align; 357*dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), load, align))) 35839379916Saartbik return failure(); 35939379916Saartbik 360*dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(load.getResultVectorType()); 36139379916Saartbik Value ptr; 36239379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 36339379916Saartbik vtype, ptr))) 36439379916Saartbik return failure(); 36539379916Saartbik 36639379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 36739379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 36839379916Saartbik rewriter.getI32IntegerAttr(align)); 36939379916Saartbik return success(); 37039379916Saartbik } 37139379916Saartbik }; 37239379916Saartbik 37339379916Saartbik /// Conversion pattern for a vector.maskedstore. 37439379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { 37539379916Saartbik public: 37639379916Saartbik explicit VectorMaskedStoreOpConversion(MLIRContext *context, 37739379916Saartbik LLVMTypeConverter &typeConverter) 37839379916Saartbik : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, 37939379916Saartbik typeConverter) {} 38039379916Saartbik 38139379916Saartbik LogicalResult 38239379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 38339379916Saartbik ConversionPatternRewriter &rewriter) const override { 38439379916Saartbik auto loc = op->getLoc(); 38539379916Saartbik auto store = cast<vector::MaskedStoreOp>(op); 38639379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 38739379916Saartbik 38839379916Saartbik // Resolve alignment. 38939379916Saartbik unsigned align; 390*dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), store, align))) 39139379916Saartbik return failure(); 39239379916Saartbik 393*dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(store.getValueVectorType()); 39439379916Saartbik Value ptr; 39539379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 39639379916Saartbik vtype, ptr))) 39739379916Saartbik return failure(); 39839379916Saartbik 39939379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 40039379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 40139379916Saartbik rewriter.getI32IntegerAttr(align)); 40239379916Saartbik return success(); 40339379916Saartbik } 40439379916Saartbik }; 40539379916Saartbik 40619dbb230Saartbik /// Conversion pattern for a vector.gather. 40719dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern { 40819dbb230Saartbik public: 40919dbb230Saartbik explicit VectorGatherOpConversion(MLIRContext *context, 41019dbb230Saartbik LLVMTypeConverter &typeConverter) 41119dbb230Saartbik : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, 41219dbb230Saartbik typeConverter) {} 41319dbb230Saartbik 41419dbb230Saartbik LogicalResult 41519dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 41619dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 41719dbb230Saartbik auto loc = op->getLoc(); 41819dbb230Saartbik auto gather = cast<vector::GatherOp>(op); 41919dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 42019dbb230Saartbik 42119dbb230Saartbik // Resolve alignment. 42219dbb230Saartbik unsigned align; 423*dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), gather, align))) 42419dbb230Saartbik return failure(); 42519dbb230Saartbik 42619dbb230Saartbik // Get index ptrs. 42719dbb230Saartbik VectorType vType = gather.getResultVectorType(); 42819dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 42919dbb230Saartbik Value ptrs; 430e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 431e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 43219dbb230Saartbik return failure(); 43319dbb230Saartbik 43419dbb230Saartbik // Replace with the gather intrinsic. 43519dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 436*dcec2ca5SChristian Sigg gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 4370c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 43819dbb230Saartbik return success(); 43919dbb230Saartbik } 44019dbb230Saartbik }; 44119dbb230Saartbik 44219dbb230Saartbik /// Conversion pattern for a vector.scatter. 44319dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern { 44419dbb230Saartbik public: 44519dbb230Saartbik explicit VectorScatterOpConversion(MLIRContext *context, 44619dbb230Saartbik LLVMTypeConverter &typeConverter) 44719dbb230Saartbik : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, 44819dbb230Saartbik typeConverter) {} 44919dbb230Saartbik 45019dbb230Saartbik LogicalResult 45119dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 45219dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 45319dbb230Saartbik auto loc = op->getLoc(); 45419dbb230Saartbik auto scatter = cast<vector::ScatterOp>(op); 45519dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 45619dbb230Saartbik 45719dbb230Saartbik // Resolve alignment. 45819dbb230Saartbik unsigned align; 459*dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align))) 46019dbb230Saartbik return failure(); 46119dbb230Saartbik 46219dbb230Saartbik // Get index ptrs. 46319dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 46419dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 46519dbb230Saartbik Value ptrs; 466e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 467e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 46819dbb230Saartbik return failure(); 46919dbb230Saartbik 47019dbb230Saartbik // Replace with the scatter intrinsic. 47119dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 47219dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 47319dbb230Saartbik rewriter.getI32IntegerAttr(align)); 47419dbb230Saartbik return success(); 47519dbb230Saartbik } 47619dbb230Saartbik }; 47719dbb230Saartbik 478e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 479e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { 480e8dcf5f8Saartbik public: 481e8dcf5f8Saartbik explicit VectorExpandLoadOpConversion(MLIRContext *context, 482e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 483e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context, 484e8dcf5f8Saartbik typeConverter) {} 485e8dcf5f8Saartbik 486e8dcf5f8Saartbik LogicalResult 487e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 488e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 489e8dcf5f8Saartbik auto loc = op->getLoc(); 490e8dcf5f8Saartbik auto expand = cast<vector::ExpandLoadOp>(op); 491e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 492e8dcf5f8Saartbik 493e8dcf5f8Saartbik Value ptr; 494e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 495e8dcf5f8Saartbik ptr))) 496e8dcf5f8Saartbik return failure(); 497e8dcf5f8Saartbik 498e8dcf5f8Saartbik auto vType = expand.getResultVectorType(); 499e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 500*dcec2ca5SChristian Sigg op, typeConverter->convertType(vType), ptr, adaptor.mask(), 501e8dcf5f8Saartbik adaptor.pass_thru()); 502e8dcf5f8Saartbik return success(); 503e8dcf5f8Saartbik } 504e8dcf5f8Saartbik }; 505e8dcf5f8Saartbik 506e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 507e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern { 508e8dcf5f8Saartbik public: 509e8dcf5f8Saartbik explicit VectorCompressStoreOpConversion(MLIRContext *context, 510e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 511e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(), 512e8dcf5f8Saartbik context, typeConverter) {} 513e8dcf5f8Saartbik 514e8dcf5f8Saartbik LogicalResult 515e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 516e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 517e8dcf5f8Saartbik auto loc = op->getLoc(); 518e8dcf5f8Saartbik auto compress = cast<vector::CompressStoreOp>(op); 519e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 520e8dcf5f8Saartbik 521e8dcf5f8Saartbik Value ptr; 522e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), 523e8dcf5f8Saartbik compress.getMemRefType(), ptr))) 524e8dcf5f8Saartbik return failure(); 525e8dcf5f8Saartbik 526e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 527e8dcf5f8Saartbik op, adaptor.value(), ptr, adaptor.mask()); 528e8dcf5f8Saartbik return success(); 529e8dcf5f8Saartbik } 530e8dcf5f8Saartbik }; 531e8dcf5f8Saartbik 53219dbb230Saartbik /// Conversion pattern for all vector reductions. 533870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern { 534e83b7b99Saartbik public: 535e83b7b99Saartbik explicit VectorReductionOpConversion(MLIRContext *context, 536ceb1b327Saartbik LLVMTypeConverter &typeConverter, 537060c9dd1Saartbik bool reassociateFPRed) 538870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, 539ceb1b327Saartbik typeConverter), 540060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 541e83b7b99Saartbik 5423145427dSRiver Riddle LogicalResult 543e83b7b99Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 544e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 545e83b7b99Saartbik auto reductionOp = cast<vector::ReductionOp>(op); 546e83b7b99Saartbik auto kind = reductionOp.kind(); 547e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 548*dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(eltType); 549e9628955SAart Bik if (eltType.isIntOrIndex()) { 550e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 551e83b7b99Saartbik if (kind == "add") 552322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 553e83b7b99Saartbik op, llvmType, operands[0]); 554e83b7b99Saartbik else if (kind == "mul") 555322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 556e83b7b99Saartbik op, llvmType, operands[0]); 557e9628955SAart Bik else if (kind == "min" && 558e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 559322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 560e9628955SAart Bik op, llvmType, operands[0]); 561e83b7b99Saartbik else if (kind == "min") 562322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 563e83b7b99Saartbik op, llvmType, operands[0]); 564e9628955SAart Bik else if (kind == "max" && 565e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 566322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 567e9628955SAart Bik op, llvmType, operands[0]); 568e83b7b99Saartbik else if (kind == "max") 569322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 570e83b7b99Saartbik op, llvmType, operands[0]); 571e83b7b99Saartbik else if (kind == "and") 572322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 573e83b7b99Saartbik op, llvmType, operands[0]); 574e83b7b99Saartbik else if (kind == "or") 575322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 576e83b7b99Saartbik op, llvmType, operands[0]); 577e83b7b99Saartbik else if (kind == "xor") 578322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 579e83b7b99Saartbik op, llvmType, operands[0]); 580e83b7b99Saartbik else 5813145427dSRiver Riddle return failure(); 5823145427dSRiver Riddle return success(); 583*dcec2ca5SChristian Sigg } 584e83b7b99Saartbik 585*dcec2ca5SChristian Sigg if (!eltType.isa<FloatType>()) 586*dcec2ca5SChristian Sigg return failure(); 587*dcec2ca5SChristian Sigg 588e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 589e83b7b99Saartbik if (kind == "add") { 5900d924700Saartbik // Optional accumulator (or zero). 5910d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5920d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 5930d924700Saartbik op->getLoc(), llvmType, 5940d924700Saartbik rewriter.getZeroAttr(eltType)); 595322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 596ceb1b327Saartbik op, llvmType, acc, operands[0], 597ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 598e83b7b99Saartbik } else if (kind == "mul") { 5990d924700Saartbik // Optional accumulator (or one). 6000d924700Saartbik Value acc = operands.size() > 1 6010d924700Saartbik ? operands[1] 6020d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 6030d924700Saartbik op->getLoc(), llvmType, 6040d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 605322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 606ceb1b327Saartbik op, llvmType, acc, operands[0], 607ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 608e83b7b99Saartbik } else if (kind == "min") 609*dcec2ca5SChristian Sigg rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType, 610*dcec2ca5SChristian Sigg operands[0]); 611e83b7b99Saartbik else if (kind == "max") 612*dcec2ca5SChristian Sigg rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType, 613*dcec2ca5SChristian Sigg operands[0]); 614e83b7b99Saartbik else 6153145427dSRiver Riddle return failure(); 6163145427dSRiver Riddle return success(); 617e83b7b99Saartbik } 618ceb1b327Saartbik 619ceb1b327Saartbik private: 620ceb1b327Saartbik const bool reassociateFPReductions; 621e83b7b99Saartbik }; 622e83b7b99Saartbik 623060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 624060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern { 625060c9dd1Saartbik public: 626060c9dd1Saartbik explicit VectorCreateMaskOpConversion(MLIRContext *context, 627060c9dd1Saartbik LLVMTypeConverter &typeConverter, 628060c9dd1Saartbik bool enableIndexOpt) 629060c9dd1Saartbik : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context, 630060c9dd1Saartbik typeConverter), 631060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 632060c9dd1Saartbik 633060c9dd1Saartbik LogicalResult 634060c9dd1Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 635060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 636060c9dd1Saartbik auto dstType = op->getResult(0).getType().cast<VectorType>(); 637060c9dd1Saartbik int64_t rank = dstType.getRank(); 638060c9dd1Saartbik if (rank == 1) { 639060c9dd1Saartbik rewriter.replaceOp( 640060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 641060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 642060c9dd1Saartbik return success(); 643060c9dd1Saartbik } 644060c9dd1Saartbik return failure(); 645060c9dd1Saartbik } 646060c9dd1Saartbik 647060c9dd1Saartbik private: 648060c9dd1Saartbik const bool enableIndexOptimizations; 649060c9dd1Saartbik }; 650060c9dd1Saartbik 651870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern { 6521c81adf3SAart Bik public: 6531c81adf3SAart Bik explicit VectorShuffleOpConversion(MLIRContext *context, 6541c81adf3SAart Bik LLVMTypeConverter &typeConverter) 655870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, 6561c81adf3SAart Bik typeConverter) {} 6571c81adf3SAart Bik 6583145427dSRiver Riddle LogicalResult 659e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 6601c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 6611c81adf3SAart Bik auto loc = op->getLoc(); 6622d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6631c81adf3SAart Bik auto shuffleOp = cast<vector::ShuffleOp>(op); 6641c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6651c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6661c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 667*dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(vectorType); 6681c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6691c81adf3SAart Bik 6701c81adf3SAart Bik // Bail if result type cannot be lowered. 6711c81adf3SAart Bik if (!llvmType) 6723145427dSRiver Riddle return failure(); 6731c81adf3SAart Bik 6741c81adf3SAart Bik // Get rank and dimension sizes. 6751c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6761c81adf3SAart Bik assert(v1Type.getRank() == rank); 6771c81adf3SAart Bik assert(v2Type.getRank() == rank); 6781c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6791c81adf3SAart Bik 6801c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6811c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6821c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 683e62a6956SRiver Riddle Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 6841c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 6851c81adf3SAart Bik rewriter.replaceOp(op, shuffle); 6863145427dSRiver Riddle return success(); 687b36aaeafSAart Bik } 688b36aaeafSAart Bik 6891c81adf3SAart Bik // For all other cases, insert the individual values individually. 690e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6911c81adf3SAart Bik int64_t insPos = 0; 6921c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6931c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 694e62a6956SRiver Riddle Value value = adaptor.v1(); 6951c81adf3SAart Bik if (extPos >= v1Dim) { 6961c81adf3SAart Bik extPos -= v1Dim; 6971c81adf3SAart Bik value = adaptor.v2(); 698b36aaeafSAart Bik } 699*dcec2ca5SChristian Sigg Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 700*dcec2ca5SChristian Sigg llvmType, rank, extPos); 701*dcec2ca5SChristian Sigg insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 7020f04384dSAlex Zinenko llvmType, rank, insPos++); 7031c81adf3SAart Bik } 7041c81adf3SAart Bik rewriter.replaceOp(op, insert); 7053145427dSRiver Riddle return success(); 706b36aaeafSAart Bik } 707b36aaeafSAart Bik }; 708b36aaeafSAart Bik 709870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern { 710cd5dab8aSAart Bik public: 711cd5dab8aSAart Bik explicit VectorExtractElementOpConversion(MLIRContext *context, 712cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 713870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), 714870c1fd4SAlex Zinenko context, typeConverter) {} 715cd5dab8aSAart Bik 7163145427dSRiver Riddle LogicalResult 717e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 718cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7192d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 720cd5dab8aSAart Bik auto extractEltOp = cast<vector::ExtractElementOp>(op); 721cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 722*dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType.getElementType()); 723cd5dab8aSAart Bik 724cd5dab8aSAart Bik // Bail if result type cannot be lowered. 725cd5dab8aSAart Bik if (!llvmType) 7263145427dSRiver Riddle return failure(); 727cd5dab8aSAart Bik 728cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 729cd5dab8aSAart Bik op, llvmType, adaptor.vector(), adaptor.position()); 7303145427dSRiver Riddle return success(); 731cd5dab8aSAart Bik } 732cd5dab8aSAart Bik }; 733cd5dab8aSAart Bik 734870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern { 7355c0c51a9SNicolas Vasilache public: 7369826fe5cSAart Bik explicit VectorExtractOpConversion(MLIRContext *context, 7375c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 738870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, 7395c0c51a9SNicolas Vasilache typeConverter) {} 7405c0c51a9SNicolas Vasilache 7413145427dSRiver Riddle LogicalResult 742e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 7435c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7445c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 7452d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 746d37f2725SAart Bik auto extractOp = cast<vector::ExtractOp>(op); 7479826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7482bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 749*dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(resultType); 7505c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7519826fe5cSAart Bik 7529826fe5cSAart Bik // Bail if result type cannot be lowered. 7539826fe5cSAart Bik if (!llvmResultType) 7543145427dSRiver Riddle return failure(); 7559826fe5cSAart Bik 7565c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7575c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 758e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7595c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 7605c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7613145427dSRiver Riddle return success(); 7625c0c51a9SNicolas Vasilache } 7635c0c51a9SNicolas Vasilache 7649826fe5cSAart Bik // Potential extraction of 1-D vector from array. 7655c0c51a9SNicolas Vasilache auto *context = op->getContext(); 766e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7675c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7685c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7699826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7705c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7715c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7725c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 773*dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 7745c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7755c0c51a9SNicolas Vasilache } 7765c0c51a9SNicolas Vasilache 7775c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7785c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7795446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 7801d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7815c0c51a9SNicolas Vasilache extracted = 7825c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 7835c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7845c0c51a9SNicolas Vasilache 7853145427dSRiver Riddle return success(); 7865c0c51a9SNicolas Vasilache } 7875c0c51a9SNicolas Vasilache }; 7885c0c51a9SNicolas Vasilache 789681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 790681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 791681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 792681f929fSNicolas Vasilache /// 793681f929fSNicolas Vasilache /// Example: 794681f929fSNicolas Vasilache /// ``` 795681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 796681f929fSNicolas Vasilache /// ``` 797681f929fSNicolas Vasilache /// is converted to: 798681f929fSNicolas Vasilache /// ``` 7993bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 800681f929fSNicolas Vasilache /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 801681f929fSNicolas Vasilache /// -> !llvm<"<8 x float>"> 802681f929fSNicolas Vasilache /// ``` 803870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern { 804681f929fSNicolas Vasilache public: 805681f929fSNicolas Vasilache explicit VectorFMAOp1DConversion(MLIRContext *context, 806681f929fSNicolas Vasilache LLVMTypeConverter &typeConverter) 807870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, 808681f929fSNicolas Vasilache typeConverter) {} 809681f929fSNicolas Vasilache 8103145427dSRiver Riddle LogicalResult 811681f929fSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 812681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 8132d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 814681f929fSNicolas Vasilache vector::FMAOp fmaOp = cast<vector::FMAOp>(op); 815681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 816681f929fSNicolas Vasilache if (vType.getRank() != 1) 8173145427dSRiver Riddle return failure(); 8183bffe602SBenjamin Kramer rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(), 8193bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 8203145427dSRiver Riddle return success(); 821681f929fSNicolas Vasilache } 822681f929fSNicolas Vasilache }; 823681f929fSNicolas Vasilache 824870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern { 825cd5dab8aSAart Bik public: 826cd5dab8aSAart Bik explicit VectorInsertElementOpConversion(MLIRContext *context, 827cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 828870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), 829870c1fd4SAlex Zinenko context, typeConverter) {} 830cd5dab8aSAart Bik 8313145427dSRiver Riddle LogicalResult 832e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 833cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 8342d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 835cd5dab8aSAart Bik auto insertEltOp = cast<vector::InsertElementOp>(op); 836cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 837*dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType); 838cd5dab8aSAart Bik 839cd5dab8aSAart Bik // Bail if result type cannot be lowered. 840cd5dab8aSAart Bik if (!llvmType) 8413145427dSRiver Riddle return failure(); 842cd5dab8aSAart Bik 843cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 844cd5dab8aSAart Bik op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 8453145427dSRiver Riddle return success(); 846cd5dab8aSAart Bik } 847cd5dab8aSAart Bik }; 848cd5dab8aSAart Bik 849870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern { 8509826fe5cSAart Bik public: 8519826fe5cSAart Bik explicit VectorInsertOpConversion(MLIRContext *context, 8529826fe5cSAart Bik LLVMTypeConverter &typeConverter) 853870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, 8549826fe5cSAart Bik typeConverter) {} 8559826fe5cSAart Bik 8563145427dSRiver Riddle LogicalResult 857e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 8589826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 8599826fe5cSAart Bik auto loc = op->getLoc(); 8602d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8619826fe5cSAart Bik auto insertOp = cast<vector::InsertOp>(op); 8629826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8639826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 864*dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(destVectorType); 8659826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8669826fe5cSAart Bik 8679826fe5cSAart Bik // Bail if result type cannot be lowered. 8689826fe5cSAart Bik if (!llvmResultType) 8693145427dSRiver Riddle return failure(); 8709826fe5cSAart Bik 8719826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8729826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 873e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8749826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8759826fe5cSAart Bik positionArrayAttr); 8769826fe5cSAart Bik rewriter.replaceOp(op, inserted); 8773145427dSRiver Riddle return success(); 8789826fe5cSAart Bik } 8799826fe5cSAart Bik 8809826fe5cSAart Bik // Potential extraction of 1-D vector from array. 8819826fe5cSAart Bik auto *context = op->getContext(); 882e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8839826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8849826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8859826fe5cSAart Bik auto oneDVectorType = destVectorType; 8869826fe5cSAart Bik if (positionAttrs.size() > 1) { 8879826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8889826fe5cSAart Bik auto nMinusOnePositionAttrs = 8899826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8909826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 891*dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8929826fe5cSAart Bik nMinusOnePositionAttrs); 8939826fe5cSAart Bik } 8949826fe5cSAart Bik 8959826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 8965446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 8971d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 898e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 899*dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 9000f04384dSAlex Zinenko adaptor.source(), constant); 9019826fe5cSAart Bik 9029826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 9039826fe5cSAart Bik if (positionAttrs.size() > 1) { 9049826fe5cSAart Bik auto nMinusOnePositionAttrs = 9059826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 9069826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 9079826fe5cSAart Bik adaptor.dest(), inserted, 9089826fe5cSAart Bik nMinusOnePositionAttrs); 9099826fe5cSAart Bik } 9109826fe5cSAart Bik 9119826fe5cSAart Bik rewriter.replaceOp(op, inserted); 9123145427dSRiver Riddle return success(); 9139826fe5cSAart Bik } 9149826fe5cSAart Bik }; 9159826fe5cSAart Bik 916681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 917681f929fSNicolas Vasilache /// 918681f929fSNicolas Vasilache /// Example: 919681f929fSNicolas Vasilache /// ``` 920681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 921681f929fSNicolas Vasilache /// ``` 922681f929fSNicolas Vasilache /// is rewritten into: 923681f929fSNicolas Vasilache /// ``` 924681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 925681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 926681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 927681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 928681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 929681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 930681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 931681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 932681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 933681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 934681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 935681f929fSNicolas Vasilache /// // %r3 holds the final value. 936681f929fSNicolas Vasilache /// ``` 937681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 938681f929fSNicolas Vasilache public: 939681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 940681f929fSNicolas Vasilache 9413145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 942681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 943681f929fSNicolas Vasilache auto vType = op.getVectorType(); 944681f929fSNicolas Vasilache if (vType.getRank() < 2) 9453145427dSRiver Riddle return failure(); 946681f929fSNicolas Vasilache 947681f929fSNicolas Vasilache auto loc = op.getLoc(); 948681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 949681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 950681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 951681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 952681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 953681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 954681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 955681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 956681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 957681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 958681f929fSNicolas Vasilache } 959681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9603145427dSRiver Riddle return success(); 961681f929fSNicolas Vasilache } 962681f929fSNicolas Vasilache }; 963681f929fSNicolas Vasilache 9642d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9652d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9662d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9672d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9682d515e49SNicolas Vasilache // rank. 9692d515e49SNicolas Vasilache // 9702d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9712d515e49SNicolas Vasilache // have different ranks. In this case: 9722d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9732d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9742d515e49SNicolas Vasilache // destination subvector 9752d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9762d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9772d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9782d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9792d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9802d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9812d515e49SNicolas Vasilache public: 9822d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9832d515e49SNicolas Vasilache 9843145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9852d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9862d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9872d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9882d515e49SNicolas Vasilache 9892d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9903145427dSRiver Riddle return failure(); 9912d515e49SNicolas Vasilache 9922d515e49SNicolas Vasilache auto loc = op.getLoc(); 9932d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9942d515e49SNicolas Vasilache assert(rankDiff >= 0); 9952d515e49SNicolas Vasilache if (rankDiff == 0) 9963145427dSRiver Riddle return failure(); 9972d515e49SNicolas Vasilache 9982d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9992d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 10002d515e49SNicolas Vasilache // on it. 10012d515e49SNicolas Vasilache Value extracted = 10022d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 10032d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 1004*dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 10052d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 10062d515e49SNicolas Vasilache // ranks. 10072d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 10082d515e49SNicolas Vasilache loc, op.source(), extracted, 10092d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 1010c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 10112d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 10122d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 10132d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 1014*dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 10153145427dSRiver Riddle return success(); 10162d515e49SNicolas Vasilache } 10172d515e49SNicolas Vasilache }; 10182d515e49SNicolas Vasilache 10192d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 10202d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 10212d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 10222d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 10232d515e49SNicolas Vasilache // destination subvector 10242d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 10252d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 10262d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 10272d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 10282d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 10292d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 10302d515e49SNicolas Vasilache public: 1031b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 1032b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 1033b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 1034b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 1035b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1036b99bd771SRiver Riddle } 10372d515e49SNicolas Vasilache 10383145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 10392d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10402d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10412d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10422d515e49SNicolas Vasilache 10432d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10443145427dSRiver Riddle return failure(); 10452d515e49SNicolas Vasilache 10462d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10472d515e49SNicolas Vasilache assert(rankDiff >= 0); 10482d515e49SNicolas Vasilache if (rankDiff != 0) 10493145427dSRiver Riddle return failure(); 10502d515e49SNicolas Vasilache 10512d515e49SNicolas Vasilache if (srcType == dstType) { 10522d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10533145427dSRiver Riddle return success(); 10542d515e49SNicolas Vasilache } 10552d515e49SNicolas Vasilache 10562d515e49SNicolas Vasilache int64_t offset = 10572d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10582d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10592d515e49SNicolas Vasilache int64_t stride = 10602d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10612d515e49SNicolas Vasilache 10622d515e49SNicolas Vasilache auto loc = op.getLoc(); 10632d515e49SNicolas Vasilache Value res = op.dest(); 10642d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10652d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10662d515e49SNicolas Vasilache off += stride, ++idx) { 10672d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10682d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10692d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10702d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10712d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10722d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10732d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10742d515e49SNicolas Vasilache // smaller rank. 1075bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10762d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10772d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10782d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10792d515e49SNicolas Vasilache } 10802d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10812d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10822d515e49SNicolas Vasilache } 10832d515e49SNicolas Vasilache 10842d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10853145427dSRiver Riddle return success(); 10862d515e49SNicolas Vasilache } 10872d515e49SNicolas Vasilache }; 10882d515e49SNicolas Vasilache 108930e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 109030e6033bSNicolas Vasilache /// static layout. 109130e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 109230e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10932bf491c7SBenjamin Kramer int64_t offset; 109430e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 109530e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 109630e6033bSNicolas Vasilache return None; 109730e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 109830e6033bSNicolas Vasilache return None; 109930e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 110030e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 110130e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 110230e6033bSNicolas Vasilache return strides; 110330e6033bSNicolas Vasilache 110430e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 110530e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 110630e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 110730e6033bSNicolas Vasilache // layout. 11082bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 11092bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 111030e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 111130e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 111230e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 111330e6033bSNicolas Vasilache return None; 111430e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 111530e6033bSNicolas Vasilache return None; 11162bf491c7SBenjamin Kramer } 111730e6033bSNicolas Vasilache return strides; 11182bf491c7SBenjamin Kramer } 11192bf491c7SBenjamin Kramer 1120870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 11215c0c51a9SNicolas Vasilache public: 11225c0c51a9SNicolas Vasilache explicit VectorTypeCastOpConversion(MLIRContext *context, 11235c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 1124870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 11255c0c51a9SNicolas Vasilache typeConverter) {} 11265c0c51a9SNicolas Vasilache 11273145427dSRiver Riddle LogicalResult 1128e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 11295c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11305c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 11315c0c51a9SNicolas Vasilache vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 11325c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 11332bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 11345c0c51a9SNicolas Vasilache MemRefType targetMemRefType = 11352bdf33ccSRiver Riddle castOp.getResult().getType().cast<MemRefType>(); 11365c0c51a9SNicolas Vasilache 11375c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 11385c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 11395c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 11403145427dSRiver Riddle return failure(); 11415c0c51a9SNicolas Vasilache 11425c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 11432bdf33ccSRiver Riddle operands[0].getType().dyn_cast<LLVM::LLVMType>(); 11445c0c51a9SNicolas Vasilache if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 11453145427dSRiver Riddle return failure(); 11465c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11475c0c51a9SNicolas Vasilache 1148*dcec2ca5SChristian Sigg auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 11495c0c51a9SNicolas Vasilache .dyn_cast_or_null<LLVM::LLVMType>(); 11505c0c51a9SNicolas Vasilache if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 11513145427dSRiver Riddle return failure(); 11525c0c51a9SNicolas Vasilache 115330e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 115430e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 115530e6033bSNicolas Vasilache if (!sourceStrides) 115630e6033bSNicolas Vasilache return failure(); 115730e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 115830e6033bSNicolas Vasilache if (!targetStrides) 115930e6033bSNicolas Vasilache return failure(); 116030e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 116130e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 116230e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 116330e6033bSNicolas Vasilache })) 11643145427dSRiver Riddle return failure(); 11655c0c51a9SNicolas Vasilache 11665446ec85SAlex Zinenko auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 11675c0c51a9SNicolas Vasilache 11685c0c51a9SNicolas Vasilache // Create descriptor. 11695c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11703a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11715c0c51a9SNicolas Vasilache // Set allocated ptr. 1172e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11735c0c51a9SNicolas Vasilache allocated = 11745c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11755c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11765c0c51a9SNicolas Vasilache // Set aligned ptr. 1177e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11785c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11795c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11805c0c51a9SNicolas Vasilache // Fill offset 0. 11815c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11825c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11835c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11845c0c51a9SNicolas Vasilache 11855c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11865c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11875c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11885c0c51a9SNicolas Vasilache auto sizeAttr = 11895c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11905c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11915c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 119230e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 119330e6033bSNicolas Vasilache (*targetStrides)[index]); 11945c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11955c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11965c0c51a9SNicolas Vasilache } 11975c0c51a9SNicolas Vasilache 11985c0c51a9SNicolas Vasilache rewriter.replaceOp(op, {desc}); 11993145427dSRiver Riddle return success(); 12005c0c51a9SNicolas Vasilache } 12015c0c51a9SNicolas Vasilache }; 12025c0c51a9SNicolas Vasilache 12038345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 12048345b86dSNicolas Vasilache /// sequence of: 1205060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1206060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1207060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1208060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1209060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 12108345b86dSNicolas Vasilache template <typename ConcreteOp> 12118345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern { 12128345b86dSNicolas Vasilache public: 12138345b86dSNicolas Vasilache explicit VectorTransferConversion(MLIRContext *context, 1214060c9dd1Saartbik LLVMTypeConverter &typeConv, 1215060c9dd1Saartbik bool enableIndexOpt) 1216060c9dd1Saartbik : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv), 1217060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 12188345b86dSNicolas Vasilache 12198345b86dSNicolas Vasilache LogicalResult 12208345b86dSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 12218345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 12228345b86dSNicolas Vasilache auto xferOp = cast<ConcreteOp>(op); 12238345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1224b2c79c50SNicolas Vasilache 1225b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1226b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 12278345b86dSNicolas Vasilache return failure(); 12285f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 12295f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 12305f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 12315f9e0466SNicolas Vasilache op->getContext())) 12328345b86dSNicolas Vasilache return failure(); 12332bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 123430e6033bSNicolas Vasilache auto strides = computeContiguousStrides(xferOp.getMemRefType()); 123530e6033bSNicolas Vasilache if (!strides) 12362bf491c7SBenjamin Kramer return failure(); 12378345b86dSNicolas Vasilache 1238*dcec2ca5SChristian Sigg auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; 12398345b86dSNicolas Vasilache 12408345b86dSNicolas Vasilache Location loc = op->getLoc(); 12418345b86dSNicolas Vasilache MemRefType memRefType = xferOp.getMemRefType(); 12428345b86dSNicolas Vasilache 124368330ee0SThomas Raoux if (auto memrefVectorElementType = 124468330ee0SThomas Raoux memRefType.getElementType().dyn_cast<VectorType>()) { 124568330ee0SThomas Raoux // Memref has vector element type. 124668330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 124768330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 124868330ee0SThomas Raoux return failure(); 12490de60b55SThomas Raoux #ifndef NDEBUG 125068330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 125168330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 125268330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 125368330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 125468330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 125568330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 125668330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 125768330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 125868330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 125968330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 126068330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 126168330ee0SThomas Raoux "result shape."); 12620de60b55SThomas Raoux #endif // ifndef NDEBUG 126368330ee0SThomas Raoux } 126468330ee0SThomas Raoux 12658345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1266be16075bSWen-Heng (Jack) Chung // The vector pointer would always be on address space 0, therefore 1267be16075bSWen-Heng (Jack) Chung // addrspacecast shall be used when source/dst memrefs are not on 1268be16075bSWen-Heng (Jack) Chung // address space 0. 12698345b86dSNicolas Vasilache // TODO: support alignment when possible. 12708b97e17dSChristian Sigg Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), 1271d3a98076SAlex Zinenko adaptor.indices(), rewriter); 12728345b86dSNicolas Vasilache auto vecTy = 12738345b86dSNicolas Vasilache toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); 1274be16075bSWen-Heng (Jack) Chung Value vectorDataPtr; 1275be16075bSWen-Heng (Jack) Chung if (memRefType.getMemorySpace() == 0) 1276be16075bSWen-Heng (Jack) Chung vectorDataPtr = 12778345b86dSNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); 1278be16075bSWen-Heng (Jack) Chung else 1279be16075bSWen-Heng (Jack) Chung vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1280be16075bSWen-Heng (Jack) Chung loc, vecTy.getPointerTo(), dataPtr); 12818345b86dSNicolas Vasilache 12821870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 1283*dcec2ca5SChristian Sigg return replaceTransferOpWithLoadOrStore( 1284*dcec2ca5SChristian Sigg rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr); 12851870e787SNicolas Vasilache 12868345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12878345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12888345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12898345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1290060c9dd1Saartbik // 1291060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1292060c9dd1Saartbik // dimensions here. 1293060c9dd1Saartbik unsigned vecWidth = vecTy.getVectorNumElements(); 1294060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12950c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 1296b2c79c50SNicolas Vasilache Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex); 1297060c9dd1Saartbik Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations, 1298060c9dd1Saartbik vecWidth, dim, &off); 12998345b86dSNicolas Vasilache 13008345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 1301*dcec2ca5SChristian Sigg return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc, 1302*dcec2ca5SChristian Sigg xferOp, operands, vectorDataPtr, mask); 13038345b86dSNicolas Vasilache } 1304060c9dd1Saartbik 1305060c9dd1Saartbik private: 1306060c9dd1Saartbik const bool enableIndexOptimizations; 13078345b86dSNicolas Vasilache }; 13088345b86dSNicolas Vasilache 1309870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern { 1310d9b500d3SAart Bik public: 1311d9b500d3SAart Bik explicit VectorPrintOpConversion(MLIRContext *context, 1312d9b500d3SAart Bik LLVMTypeConverter &typeConverter) 1313870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 1314d9b500d3SAart Bik typeConverter) {} 1315d9b500d3SAart Bik 1316d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1317d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1318d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1319d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1320d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1321d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1322d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1323d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1324d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1325d9b500d3SAart Bik // 13269db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1327d9b500d3SAart Bik // 13283145427dSRiver Riddle LogicalResult 1329e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1330d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 1331d9b500d3SAart Bik auto printOp = cast<vector::PrintOp>(op); 13322d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1333d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1334d9b500d3SAart Bik 1335*dcec2ca5SChristian Sigg if (typeConverter->convertType(printType) == nullptr) 13363145427dSRiver Riddle return failure(); 1337d9b500d3SAart Bik 1338b8880f5fSAart Bik // Make sure element type has runtime support. 1339b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1340d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1341d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1342d9b500d3SAart Bik Operation *printer; 1343b8880f5fSAart Bik if (eltType.isF32()) { 1344d9b500d3SAart Bik printer = getPrintFloat(op); 1345b8880f5fSAart Bik } else if (eltType.isF64()) { 1346d9b500d3SAart Bik printer = getPrintDouble(op); 134754759cefSAart Bik } else if (eltType.isIndex()) { 134854759cefSAart Bik printer = getPrintU64(op); 1349b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1350b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1351b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1352b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1353b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1354b8880f5fSAart Bik if (intTy.isUnsigned()) { 135554759cefSAart Bik if (width <= 64) { 1356b8880f5fSAart Bik if (width < 64) 1357b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1358b8880f5fSAart Bik printer = getPrintU64(op); 1359b8880f5fSAart Bik } else { 13603145427dSRiver Riddle return failure(); 1361b8880f5fSAart Bik } 1362b8880f5fSAart Bik } else { 1363b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 136454759cefSAart Bik if (width <= 64) { 1365b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1366b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1367b8880f5fSAart Bik if (width == 1) 136854759cefSAart Bik conversion = PrintConversion::ZeroExt64; 136954759cefSAart Bik else if (width < 64) 1370b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1371b8880f5fSAart Bik printer = getPrintI64(op); 1372b8880f5fSAart Bik } else { 1373b8880f5fSAart Bik return failure(); 1374b8880f5fSAart Bik } 1375b8880f5fSAart Bik } 1376b8880f5fSAart Bik } else { 1377b8880f5fSAart Bik return failure(); 1378b8880f5fSAart Bik } 1379d9b500d3SAart Bik 1380d9b500d3SAart Bik // Unroll vector into elementary print calls. 1381b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1382b8880f5fSAart Bik emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, 1383b8880f5fSAart Bik conversion); 1384d9b500d3SAart Bik emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 1385d9b500d3SAart Bik rewriter.eraseOp(op); 13863145427dSRiver Riddle return success(); 1387d9b500d3SAart Bik } 1388d9b500d3SAart Bik 1389d9b500d3SAart Bik private: 1390b8880f5fSAart Bik enum class PrintConversion { 139130e6033bSNicolas Vasilache // clang-format off 1392b8880f5fSAart Bik None, 1393b8880f5fSAart Bik ZeroExt64, 1394b8880f5fSAart Bik SignExt64 139530e6033bSNicolas Vasilache // clang-format on 1396b8880f5fSAart Bik }; 1397b8880f5fSAart Bik 1398d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1399e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1400b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1401d9b500d3SAart Bik Location loc = op->getLoc(); 1402d9b500d3SAart Bik if (rank == 0) { 1403b8880f5fSAart Bik switch (conversion) { 1404b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1405b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 1406b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1407b8880f5fSAart Bik break; 1408b8880f5fSAart Bik case PrintConversion::SignExt64: 1409b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 1410b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1411b8880f5fSAart Bik break; 1412b8880f5fSAart Bik case PrintConversion::None: 1413b8880f5fSAart Bik break; 1414c9eeeb38Saartbik } 1415d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1416d9b500d3SAart Bik return; 1417d9b500d3SAart Bik } 1418d9b500d3SAart Bik 1419d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1420d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1421d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1422d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1423d9b500d3SAart Bik auto reducedType = 1424d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1425*dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType( 1426d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1427*dcec2ca5SChristian Sigg Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1428*dcec2ca5SChristian Sigg llvmType, rank, d); 1429b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1430b8880f5fSAart Bik conversion); 1431d9b500d3SAart Bik if (d != dim - 1) 1432d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1433d9b500d3SAart Bik } 1434d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1435d9b500d3SAart Bik } 1436d9b500d3SAart Bik 1437d9b500d3SAart Bik // Helper to emit a call. 1438d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1439d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 144008e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1441d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1442d9b500d3SAart Bik } 1443d9b500d3SAart Bik 1444d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 14455446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 14465446ec85SAlex Zinenko ArrayRef<LLVM::LLVMType> params) { 1447d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1448d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1449d9b500d3SAart Bik if (func) 1450d9b500d3SAart Bik return func; 1451d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1452d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1453d9b500d3SAart Bik op->getLoc(), name, 14545446ec85SAlex Zinenko LLVM::LLVMType::getFunctionTy( 14555446ec85SAlex Zinenko LLVM::LLVMType::getVoidTy(op->getContext()), params, 14565446ec85SAlex Zinenko /*isVarArg=*/false)); 1457d9b500d3SAart Bik } 1458d9b500d3SAart Bik 1459d9b500d3SAart Bik // Helpers for method names. 1460e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 146154759cefSAart Bik return getPrint(op, "printI64", 14625446ec85SAlex Zinenko LLVM::LLVMType::getInt64Ty(op->getContext())); 1463e52414b1Saartbik } 1464b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 1465b8880f5fSAart Bik return getPrint(op, "printU64", 1466b8880f5fSAart Bik LLVM::LLVMType::getInt64Ty(op->getContext())); 1467b8880f5fSAart Bik } 1468d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 146954759cefSAart Bik return getPrint(op, "printF32", 14705446ec85SAlex Zinenko LLVM::LLVMType::getFloatTy(op->getContext())); 1471d9b500d3SAart Bik } 1472d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 147354759cefSAart Bik return getPrint(op, "printF64", 14745446ec85SAlex Zinenko LLVM::LLVMType::getDoubleTy(op->getContext())); 1475d9b500d3SAart Bik } 1476d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 147754759cefSAart Bik return getPrint(op, "printOpen", {}); 1478d9b500d3SAart Bik } 1479d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 148054759cefSAart Bik return getPrint(op, "printClose", {}); 1481d9b500d3SAart Bik } 1482d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 148354759cefSAart Bik return getPrint(op, "printComma", {}); 1484d9b500d3SAart Bik } 1485d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 148654759cefSAart Bik return getPrint(op, "printNewline", {}); 1487d9b500d3SAart Bik } 1488d9b500d3SAart Bik }; 1489d9b500d3SAart Bik 1490334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1491c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1492c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1493c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1494334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 149565678d93SNicolas Vasilache public: 1496b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1497b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1498b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1499b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1500b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1501b99bd771SRiver Riddle } 150265678d93SNicolas Vasilache 1503334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 150465678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 150565678d93SNicolas Vasilache auto dstType = op.getResult().getType().cast<VectorType>(); 150665678d93SNicolas Vasilache 150765678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 150865678d93SNicolas Vasilache 150965678d93SNicolas Vasilache int64_t offset = 151065678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 151165678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 151265678d93SNicolas Vasilache int64_t stride = 151365678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 151465678d93SNicolas Vasilache 151565678d93SNicolas Vasilache auto loc = op.getLoc(); 151665678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 151735b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1518c3c95b9cSaartbik 1519c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1520c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1521c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1522c3c95b9cSaartbik offsets.reserve(size); 1523c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1524c3c95b9cSaartbik off += stride) 1525c3c95b9cSaartbik offsets.push_back(off); 1526c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1527c3c95b9cSaartbik op.vector(), 1528c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1529c3c95b9cSaartbik return success(); 1530c3c95b9cSaartbik } 1531c3c95b9cSaartbik 1532c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 153365678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 153465678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 153565678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 153665678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 153765678d93SNicolas Vasilache off += stride, ++idx) { 1538c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1539c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1540c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 154165678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 154265678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 154365678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 154465678d93SNicolas Vasilache } 1545c3c95b9cSaartbik rewriter.replaceOp(op, res); 15463145427dSRiver Riddle return success(); 154765678d93SNicolas Vasilache } 154865678d93SNicolas Vasilache }; 154965678d93SNicolas Vasilache 1550df186507SBenjamin Kramer } // namespace 1551df186507SBenjamin Kramer 15525c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 15535c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1554ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1555060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 155665678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 15578345b86dSNicolas Vasilache // clang-format off 1558681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1559681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 15602d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1561c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1562ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1563ceb1b327Saartbik ctx, converter, reassociateFPReductions); 1564060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1565060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1566060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1567060c9dd1Saartbik ctx, converter, enableIndexOptimizations); 15688345b86dSNicolas Vasilache patterns 1569ceb1b327Saartbik .insert<VectorShuffleOpConversion, 15708345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15718345b86dSNicolas Vasilache VectorExtractOpConversion, 15728345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15738345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15748345b86dSNicolas Vasilache VectorInsertOpConversion, 15758345b86dSNicolas Vasilache VectorPrintOpConversion, 157619dbb230Saartbik VectorTypeCastOpConversion, 157739379916Saartbik VectorMaskedLoadOpConversion, 157839379916Saartbik VectorMaskedStoreOpConversion, 157919dbb230Saartbik VectorGatherOpConversion, 1580e8dcf5f8Saartbik VectorScatterOpConversion, 1581e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1582e8dcf5f8Saartbik VectorCompressStoreOpConversion>(ctx, converter); 15838345b86dSNicolas Vasilache // clang-format on 15845c0c51a9SNicolas Vasilache } 15855c0c51a9SNicolas Vasilache 158663b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 158763b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 158863b683a8SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 158963b683a8SNicolas Vasilache patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1590c295a65dSaartbik patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter); 159163b683a8SNicolas Vasilache } 1592