15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 195c0c51a9SNicolas Vasilache 205c0c51a9SNicolas Vasilache using namespace mlir; 2165678d93SNicolas Vasilache using namespace mlir::vector; 225c0c51a9SNicolas Vasilache 239826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 259826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 269826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 279826fe5cSAart Bik } 289826fe5cSAart Bik 299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 319826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 329826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 339826fe5cSAart Bik } 349826fe5cSAart Bik 351c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 370f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 380f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 390f04384dSAlex Zinenko int64_t pos) { 401c81adf3SAart Bik if (rank == 1) { 411c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 421c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 430f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 441c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 451c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 461c81adf3SAart Bik constant); 471c81adf3SAart Bik } 481c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 491c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 501c81adf3SAart Bik } 511c81adf3SAart Bik 522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 542d515e49SNicolas Vasilache Value into, int64_t offset) { 552d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 562d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 572d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 582d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 592d515e49SNicolas Vasilache loc, vectorType, from, into, 602d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 612d515e49SNicolas Vasilache } 622d515e49SNicolas Vasilache 631c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 650f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 660f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 671c81adf3SAart Bik if (rank == 1) { 681c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 691c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 700f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 711c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 721c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 731c81adf3SAart Bik constant); 741c81adf3SAart Bik } 751c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 761c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 771c81adf3SAart Bik } 781c81adf3SAart Bik 792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 812d515e49SNicolas Vasilache int64_t offset) { 822d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 832d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 842d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 852d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 862d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 872d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 882d515e49SNicolas Vasilache } 892d515e49SNicolas Vasilache 902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 932d515e49SNicolas Vasilache unsigned dropFront = 0, 942d515e49SNicolas Vasilache unsigned dropBack = 0) { 952d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 962d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 972d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 982d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 992d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1002d515e49SNicolas Vasilache it != eit; ++it) 1012d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1022d515e49SNicolas Vasilache return res; 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 106060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107060c9dd1Saartbik // 108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 110060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 111060c9dd1Saartbik // is very conservative on the boundary conditions. 112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 114060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 115060c9dd1Saartbik auto loc = op->getLoc(); 116060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 117060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 118060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 119060c9dd1Saartbik Value indices; 120060c9dd1Saartbik Type idxType; 121060c9dd1Saartbik if (enableIndexOptimizations) { 1220c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1230c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1240c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125060c9dd1Saartbik idxType = rewriter.getI32Type(); 126060c9dd1Saartbik } else { 1270c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1280c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1290c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130060c9dd1Saartbik idxType = rewriter.getI64Type(); 131060c9dd1Saartbik } 132060c9dd1Saartbik // Add in an offset if requested. 133060c9dd1Saartbik if (off) { 134060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 137060c9dd1Saartbik } 138060c9dd1Saartbik // Construct the vector comparison. 139060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142060c9dd1Saartbik } 143060c9dd1Saartbik 14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref. 14519dbb230Saartbik template <typename T> 14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, 14719dbb230Saartbik unsigned &align) { 1485f9e0466SNicolas Vasilache Type elementTy = 14919dbb230Saartbik typeConverter.convertType(op.getMemRefType().getElementType()); 1505f9e0466SNicolas Vasilache if (!elementTy) 1515f9e0466SNicolas Vasilache return failure(); 1525f9e0466SNicolas Vasilache 153b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 154b2ab375dSAlex Zinenko // stop depending on translation. 15587a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 15687a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 157b2ab375dSAlex Zinenko .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(), 158168213f9SAlex Zinenko typeConverter.getDataLayout()); 1595f9e0466SNicolas Vasilache return success(); 1605f9e0466SNicolas Vasilache } 1615f9e0466SNicolas Vasilache 162e8dcf5f8Saartbik // Helper that returns the base address of a memref. 163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 164e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 16519dbb230Saartbik // Inspect stride and offset structure. 16619dbb230Saartbik // 16719dbb230Saartbik // TODO: flat memory only for now, generalize 16819dbb230Saartbik // 16919dbb230Saartbik int64_t offset; 17019dbb230Saartbik SmallVector<int64_t, 4> strides; 17119dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 17219dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 17319dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 17419dbb230Saartbik return failure(); 175e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 176e8dcf5f8Saartbik return success(); 177e8dcf5f8Saartbik } 17819dbb230Saartbik 179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base. 180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 181b98e25b6SBenjamin Kramer Location loc, Value memref, 182b98e25b6SBenjamin Kramer MemRefType memRefType, Value &ptr) { 183e8dcf5f8Saartbik Value base; 184e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 185e8dcf5f8Saartbik return failure(); 1863a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 187e8dcf5f8Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 188e8dcf5f8Saartbik return success(); 189e8dcf5f8Saartbik } 190e8dcf5f8Saartbik 19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base. 192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 193b98e25b6SBenjamin Kramer Location loc, Value memref, 194b98e25b6SBenjamin Kramer MemRefType memRefType, Type type, Value &ptr) { 19539379916Saartbik Value base; 19639379916Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 19739379916Saartbik return failure(); 19839379916Saartbik auto pType = type.template cast<LLVM::LLVMType>().getPointerTo(); 19939379916Saartbik base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 20039379916Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 20139379916Saartbik return success(); 20239379916Saartbik } 20339379916Saartbik 204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index 205e8dcf5f8Saartbik // vector. 206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 207b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 208b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 209b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 210e8dcf5f8Saartbik Value base; 211e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 212e8dcf5f8Saartbik return failure(); 2133a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 214e8dcf5f8Saartbik auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); 2151485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 21619dbb230Saartbik return success(); 21719dbb230Saartbik } 21819dbb230Saartbik 2195f9e0466SNicolas Vasilache static LogicalResult 2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2215f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2225f9e0466SNicolas Vasilache TransferReadOp xferOp, 2235f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 224affbc0cdSNicolas Vasilache unsigned align; 22519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 226affbc0cdSNicolas Vasilache return failure(); 227affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2285f9e0466SNicolas Vasilache return success(); 2295f9e0466SNicolas Vasilache } 2305f9e0466SNicolas Vasilache 2315f9e0466SNicolas Vasilache static LogicalResult 2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2335f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2345f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2355f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2365f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2375f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2385f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2395f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2405f9e0466SNicolas Vasilache 2415f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2425f9e0466SNicolas Vasilache if (!vecTy) 2435f9e0466SNicolas Vasilache return failure(); 2445f9e0466SNicolas Vasilache 2455f9e0466SNicolas Vasilache unsigned align; 24619dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2475f9e0466SNicolas Vasilache return failure(); 2485f9e0466SNicolas Vasilache 2495f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2505f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2515f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2525f9e0466SNicolas Vasilache return success(); 2535f9e0466SNicolas Vasilache } 2545f9e0466SNicolas Vasilache 2555f9e0466SNicolas Vasilache static LogicalResult 2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2575f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2585f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2595f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 260affbc0cdSNicolas Vasilache unsigned align; 26119dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 262affbc0cdSNicolas Vasilache return failure(); 2632d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 264affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 265affbc0cdSNicolas Vasilache align); 2665f9e0466SNicolas Vasilache return success(); 2675f9e0466SNicolas Vasilache } 2685f9e0466SNicolas Vasilache 2695f9e0466SNicolas Vasilache static LogicalResult 2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2715f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2725f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2735f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2745f9e0466SNicolas Vasilache unsigned align; 27519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2765f9e0466SNicolas Vasilache return failure(); 2775f9e0466SNicolas Vasilache 2782d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2795f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2805f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2815f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2825f9e0466SNicolas Vasilache return success(); 2835f9e0466SNicolas Vasilache } 2845f9e0466SNicolas Vasilache 2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2862d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2872d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2885f9e0466SNicolas Vasilache } 2895f9e0466SNicolas Vasilache 2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2912d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2922d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2935f9e0466SNicolas Vasilache } 2945f9e0466SNicolas Vasilache 29590c01357SBenjamin Kramer namespace { 296e83b7b99Saartbik 29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 299*563879b6SRahul Joshi class VectorMatmulOpConversion 300*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MatmulOp> { 30163b683a8SNicolas Vasilache public: 302*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 30363b683a8SNicolas Vasilache 3043145427dSRiver Riddle LogicalResult 305*563879b6SRahul Joshi matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands, 30663b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 3072d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 30863b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 309*563879b6SRahul Joshi matmulOp, typeConverter->convertType(matmulOp.res().getType()), 310*563879b6SRahul Joshi adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 311*563879b6SRahul Joshi matmulOp.lhs_columns(), matmulOp.rhs_columns()); 3123145427dSRiver Riddle return success(); 31363b683a8SNicolas Vasilache } 31463b683a8SNicolas Vasilache }; 31563b683a8SNicolas Vasilache 316c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 317c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 318*563879b6SRahul Joshi class VectorFlatTransposeOpConversion 319*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 320c295a65dSaartbik public: 321*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 322c295a65dSaartbik 323c295a65dSaartbik LogicalResult 324*563879b6SRahul Joshi matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands, 325c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 3262d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 327c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 328dcec2ca5SChristian Sigg transOp, typeConverter->convertType(transOp.res().getType()), 329c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 330c295a65dSaartbik return success(); 331c295a65dSaartbik } 332c295a65dSaartbik }; 333c295a65dSaartbik 33439379916Saartbik /// Conversion pattern for a vector.maskedload. 335*563879b6SRahul Joshi class VectorMaskedLoadOpConversion 336*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> { 33739379916Saartbik public: 338*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern; 33939379916Saartbik 34039379916Saartbik LogicalResult 341*563879b6SRahul Joshi matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands, 34239379916Saartbik ConversionPatternRewriter &rewriter) const override { 343*563879b6SRahul Joshi auto loc = load->getLoc(); 34439379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 34539379916Saartbik 34639379916Saartbik // Resolve alignment. 34739379916Saartbik unsigned align; 348dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), load, align))) 34939379916Saartbik return failure(); 35039379916Saartbik 351dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(load.getResultVectorType()); 35239379916Saartbik Value ptr; 35339379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 35439379916Saartbik vtype, ptr))) 35539379916Saartbik return failure(); 35639379916Saartbik 35739379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 35839379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 35939379916Saartbik rewriter.getI32IntegerAttr(align)); 36039379916Saartbik return success(); 36139379916Saartbik } 36239379916Saartbik }; 36339379916Saartbik 36439379916Saartbik /// Conversion pattern for a vector.maskedstore. 365*563879b6SRahul Joshi class VectorMaskedStoreOpConversion 366*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> { 36739379916Saartbik public: 368*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern; 36939379916Saartbik 37039379916Saartbik LogicalResult 371*563879b6SRahul Joshi matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands, 37239379916Saartbik ConversionPatternRewriter &rewriter) const override { 373*563879b6SRahul Joshi auto loc = store->getLoc(); 37439379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 37539379916Saartbik 37639379916Saartbik // Resolve alignment. 37739379916Saartbik unsigned align; 378dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), store, align))) 37939379916Saartbik return failure(); 38039379916Saartbik 381dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(store.getValueVectorType()); 38239379916Saartbik Value ptr; 38339379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 38439379916Saartbik vtype, ptr))) 38539379916Saartbik return failure(); 38639379916Saartbik 38739379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 38839379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 38939379916Saartbik rewriter.getI32IntegerAttr(align)); 39039379916Saartbik return success(); 39139379916Saartbik } 39239379916Saartbik }; 39339379916Saartbik 39419dbb230Saartbik /// Conversion pattern for a vector.gather. 395*563879b6SRahul Joshi class VectorGatherOpConversion 396*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::GatherOp> { 39719dbb230Saartbik public: 398*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 39919dbb230Saartbik 40019dbb230Saartbik LogicalResult 401*563879b6SRahul Joshi matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands, 40219dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 403*563879b6SRahul Joshi auto loc = gather->getLoc(); 40419dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 40519dbb230Saartbik 40619dbb230Saartbik // Resolve alignment. 40719dbb230Saartbik unsigned align; 408dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), gather, align))) 40919dbb230Saartbik return failure(); 41019dbb230Saartbik 41119dbb230Saartbik // Get index ptrs. 41219dbb230Saartbik VectorType vType = gather.getResultVectorType(); 41319dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 41419dbb230Saartbik Value ptrs; 415e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 416e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 41719dbb230Saartbik return failure(); 41819dbb230Saartbik 41919dbb230Saartbik // Replace with the gather intrinsic. 42019dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 421dcec2ca5SChristian Sigg gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 4220c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 42319dbb230Saartbik return success(); 42419dbb230Saartbik } 42519dbb230Saartbik }; 42619dbb230Saartbik 42719dbb230Saartbik /// Conversion pattern for a vector.scatter. 428*563879b6SRahul Joshi class VectorScatterOpConversion 429*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ScatterOp> { 43019dbb230Saartbik public: 431*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 43219dbb230Saartbik 43319dbb230Saartbik LogicalResult 434*563879b6SRahul Joshi matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands, 43519dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 436*563879b6SRahul Joshi auto loc = scatter->getLoc(); 43719dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 43819dbb230Saartbik 43919dbb230Saartbik // Resolve alignment. 44019dbb230Saartbik unsigned align; 441dcec2ca5SChristian Sigg if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align))) 44219dbb230Saartbik return failure(); 44319dbb230Saartbik 44419dbb230Saartbik // Get index ptrs. 44519dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 44619dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 44719dbb230Saartbik Value ptrs; 448e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 449e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 45019dbb230Saartbik return failure(); 45119dbb230Saartbik 45219dbb230Saartbik // Replace with the scatter intrinsic. 45319dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 45419dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 45519dbb230Saartbik rewriter.getI32IntegerAttr(align)); 45619dbb230Saartbik return success(); 45719dbb230Saartbik } 45819dbb230Saartbik }; 45919dbb230Saartbik 460e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 461*563879b6SRahul Joshi class VectorExpandLoadOpConversion 462*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 463e8dcf5f8Saartbik public: 464*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 465e8dcf5f8Saartbik 466e8dcf5f8Saartbik LogicalResult 467*563879b6SRahul Joshi matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands, 468e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 469*563879b6SRahul Joshi auto loc = expand->getLoc(); 470e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 471e8dcf5f8Saartbik 472e8dcf5f8Saartbik Value ptr; 473e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 474e8dcf5f8Saartbik ptr))) 475e8dcf5f8Saartbik return failure(); 476e8dcf5f8Saartbik 477e8dcf5f8Saartbik auto vType = expand.getResultVectorType(); 478e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 479*563879b6SRahul Joshi expand, typeConverter->convertType(vType), ptr, adaptor.mask(), 480e8dcf5f8Saartbik adaptor.pass_thru()); 481e8dcf5f8Saartbik return success(); 482e8dcf5f8Saartbik } 483e8dcf5f8Saartbik }; 484e8dcf5f8Saartbik 485e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 486*563879b6SRahul Joshi class VectorCompressStoreOpConversion 487*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 488e8dcf5f8Saartbik public: 489*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 490e8dcf5f8Saartbik 491e8dcf5f8Saartbik LogicalResult 492*563879b6SRahul Joshi matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands, 493e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 494*563879b6SRahul Joshi auto loc = compress->getLoc(); 495e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 496e8dcf5f8Saartbik 497e8dcf5f8Saartbik Value ptr; 498e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), 499e8dcf5f8Saartbik compress.getMemRefType(), ptr))) 500e8dcf5f8Saartbik return failure(); 501e8dcf5f8Saartbik 502e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 503*563879b6SRahul Joshi compress, adaptor.value(), ptr, adaptor.mask()); 504e8dcf5f8Saartbik return success(); 505e8dcf5f8Saartbik } 506e8dcf5f8Saartbik }; 507e8dcf5f8Saartbik 50819dbb230Saartbik /// Conversion pattern for all vector reductions. 509*563879b6SRahul Joshi class VectorReductionOpConversion 510*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ReductionOp> { 511e83b7b99Saartbik public: 512*563879b6SRahul Joshi explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 513060c9dd1Saartbik bool reassociateFPRed) 514*563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 515060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 516e83b7b99Saartbik 5173145427dSRiver Riddle LogicalResult 518*563879b6SRahul Joshi matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands, 519e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 520e83b7b99Saartbik auto kind = reductionOp.kind(); 521e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 522dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(eltType); 523e9628955SAart Bik if (eltType.isIntOrIndex()) { 524e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 525e83b7b99Saartbik if (kind == "add") 526322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 527*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 528e83b7b99Saartbik else if (kind == "mul") 529322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 530*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 531e9628955SAart Bik else if (kind == "min" && 532e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 533322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 534*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 535e83b7b99Saartbik else if (kind == "min") 536322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 537*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 538e9628955SAart Bik else if (kind == "max" && 539e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 540322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 541*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 542e83b7b99Saartbik else if (kind == "max") 543322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 544*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 545e83b7b99Saartbik else if (kind == "and") 546322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 547*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 548e83b7b99Saartbik else if (kind == "or") 549322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 550*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 551e83b7b99Saartbik else if (kind == "xor") 552322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 553*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 554e83b7b99Saartbik else 5553145427dSRiver Riddle return failure(); 5563145427dSRiver Riddle return success(); 557dcec2ca5SChristian Sigg } 558e83b7b99Saartbik 559dcec2ca5SChristian Sigg if (!eltType.isa<FloatType>()) 560dcec2ca5SChristian Sigg return failure(); 561dcec2ca5SChristian Sigg 562e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 563e83b7b99Saartbik if (kind == "add") { 5640d924700Saartbik // Optional accumulator (or zero). 5650d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5660d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 567*563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5680d924700Saartbik rewriter.getZeroAttr(eltType)); 569322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 570*563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 571ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 572e83b7b99Saartbik } else if (kind == "mul") { 5730d924700Saartbik // Optional accumulator (or one). 5740d924700Saartbik Value acc = operands.size() > 1 5750d924700Saartbik ? operands[1] 5760d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 577*563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5780d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 579322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 580*563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 581ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 582e83b7b99Saartbik } else if (kind == "min") 583*563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 584*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 585e83b7b99Saartbik else if (kind == "max") 586*563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 587*563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 588e83b7b99Saartbik else 5893145427dSRiver Riddle return failure(); 5903145427dSRiver Riddle return success(); 591e83b7b99Saartbik } 592ceb1b327Saartbik 593ceb1b327Saartbik private: 594ceb1b327Saartbik const bool reassociateFPReductions; 595e83b7b99Saartbik }; 596e83b7b99Saartbik 597060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 598*563879b6SRahul Joshi class VectorCreateMaskOpConversion 599*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 600060c9dd1Saartbik public: 601*563879b6SRahul Joshi explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, 602060c9dd1Saartbik bool enableIndexOpt) 603*563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), 604060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 605060c9dd1Saartbik 606060c9dd1Saartbik LogicalResult 607*563879b6SRahul Joshi matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, 608060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 609060c9dd1Saartbik auto dstType = op->getResult(0).getType().cast<VectorType>(); 610060c9dd1Saartbik int64_t rank = dstType.getRank(); 611060c9dd1Saartbik if (rank == 1) { 612060c9dd1Saartbik rewriter.replaceOp( 613060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 614060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 615060c9dd1Saartbik return success(); 616060c9dd1Saartbik } 617060c9dd1Saartbik return failure(); 618060c9dd1Saartbik } 619060c9dd1Saartbik 620060c9dd1Saartbik private: 621060c9dd1Saartbik const bool enableIndexOptimizations; 622060c9dd1Saartbik }; 623060c9dd1Saartbik 624*563879b6SRahul Joshi class VectorShuffleOpConversion 625*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 6261c81adf3SAart Bik public: 627*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 6281c81adf3SAart Bik 6293145427dSRiver Riddle LogicalResult 630*563879b6SRahul Joshi matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands, 6311c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 632*563879b6SRahul Joshi auto loc = shuffleOp->getLoc(); 6332d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6341c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6351c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6361c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 637dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(vectorType); 6381c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6391c81adf3SAart Bik 6401c81adf3SAart Bik // Bail if result type cannot be lowered. 6411c81adf3SAart Bik if (!llvmType) 6423145427dSRiver Riddle return failure(); 6431c81adf3SAart Bik 6441c81adf3SAart Bik // Get rank and dimension sizes. 6451c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6461c81adf3SAart Bik assert(v1Type.getRank() == rank); 6471c81adf3SAart Bik assert(v2Type.getRank() == rank); 6481c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6491c81adf3SAart Bik 6501c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6511c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6521c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 653*563879b6SRahul Joshi Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 6541c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 655*563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, llvmShuffleOp); 6563145427dSRiver Riddle return success(); 657b36aaeafSAart Bik } 658b36aaeafSAart Bik 6591c81adf3SAart Bik // For all other cases, insert the individual values individually. 660e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6611c81adf3SAart Bik int64_t insPos = 0; 6621c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6631c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 664e62a6956SRiver Riddle Value value = adaptor.v1(); 6651c81adf3SAart Bik if (extPos >= v1Dim) { 6661c81adf3SAart Bik extPos -= v1Dim; 6671c81adf3SAart Bik value = adaptor.v2(); 668b36aaeafSAart Bik } 669dcec2ca5SChristian Sigg Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 670dcec2ca5SChristian Sigg llvmType, rank, extPos); 671dcec2ca5SChristian Sigg insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 6720f04384dSAlex Zinenko llvmType, rank, insPos++); 6731c81adf3SAart Bik } 674*563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, insert); 6753145427dSRiver Riddle return success(); 676b36aaeafSAart Bik } 677b36aaeafSAart Bik }; 678b36aaeafSAart Bik 679*563879b6SRahul Joshi class VectorExtractElementOpConversion 680*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 681cd5dab8aSAart Bik public: 682*563879b6SRahul Joshi using ConvertOpToLLVMPattern< 683*563879b6SRahul Joshi vector::ExtractElementOp>::ConvertOpToLLVMPattern; 684cd5dab8aSAart Bik 6853145427dSRiver Riddle LogicalResult 686*563879b6SRahul Joshi matchAndRewrite(vector::ExtractElementOp extractEltOp, 687*563879b6SRahul Joshi ArrayRef<Value> operands, 688cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 6892d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 690cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 691dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType.getElementType()); 692cd5dab8aSAart Bik 693cd5dab8aSAart Bik // Bail if result type cannot be lowered. 694cd5dab8aSAart Bik if (!llvmType) 6953145427dSRiver Riddle return failure(); 696cd5dab8aSAart Bik 697cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 698*563879b6SRahul Joshi extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 6993145427dSRiver Riddle return success(); 700cd5dab8aSAart Bik } 701cd5dab8aSAart Bik }; 702cd5dab8aSAart Bik 703*563879b6SRahul Joshi class VectorExtractOpConversion 704*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractOp> { 7055c0c51a9SNicolas Vasilache public: 706*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 7075c0c51a9SNicolas Vasilache 7083145427dSRiver Riddle LogicalResult 709*563879b6SRahul Joshi matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 7105c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 711*563879b6SRahul Joshi auto loc = extractOp->getLoc(); 7122d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 7139826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7142bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 715dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(resultType); 7165c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7179826fe5cSAart Bik 7189826fe5cSAart Bik // Bail if result type cannot be lowered. 7199826fe5cSAart Bik if (!llvmResultType) 7203145427dSRiver Riddle return failure(); 7219826fe5cSAart Bik 7225c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7235c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 724e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7255c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 726*563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7273145427dSRiver Riddle return success(); 7285c0c51a9SNicolas Vasilache } 7295c0c51a9SNicolas Vasilache 7309826fe5cSAart Bik // Potential extraction of 1-D vector from array. 731*563879b6SRahul Joshi auto *context = extractOp->getContext(); 732e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7335c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7345c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7359826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7365c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7375c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7385c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 739dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 7405c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7415c0c51a9SNicolas Vasilache } 7425c0c51a9SNicolas Vasilache 7435c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7445c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7455446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 7461d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7475c0c51a9SNicolas Vasilache extracted = 7485c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 749*563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7505c0c51a9SNicolas Vasilache 7513145427dSRiver Riddle return success(); 7525c0c51a9SNicolas Vasilache } 7535c0c51a9SNicolas Vasilache }; 7545c0c51a9SNicolas Vasilache 755681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 756681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 757681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 758681f929fSNicolas Vasilache /// 759681f929fSNicolas Vasilache /// Example: 760681f929fSNicolas Vasilache /// ``` 761681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 762681f929fSNicolas Vasilache /// ``` 763681f929fSNicolas Vasilache /// is converted to: 764681f929fSNicolas Vasilache /// ``` 7653bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 766681f929fSNicolas Vasilache /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 767681f929fSNicolas Vasilache /// -> !llvm<"<8 x float>"> 768681f929fSNicolas Vasilache /// ``` 769*563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 770681f929fSNicolas Vasilache public: 771*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 772681f929fSNicolas Vasilache 7733145427dSRiver Riddle LogicalResult 774*563879b6SRahul Joshi matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 775681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7762d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 777681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 778681f929fSNicolas Vasilache if (vType.getRank() != 1) 7793145427dSRiver Riddle return failure(); 780*563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 7813bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 7823145427dSRiver Riddle return success(); 783681f929fSNicolas Vasilache } 784681f929fSNicolas Vasilache }; 785681f929fSNicolas Vasilache 786*563879b6SRahul Joshi class VectorInsertElementOpConversion 787*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 788cd5dab8aSAart Bik public: 789*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 790cd5dab8aSAart Bik 7913145427dSRiver Riddle LogicalResult 792*563879b6SRahul Joshi matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands, 793cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7942d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 795cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 796dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType); 797cd5dab8aSAart Bik 798cd5dab8aSAart Bik // Bail if result type cannot be lowered. 799cd5dab8aSAart Bik if (!llvmType) 8003145427dSRiver Riddle return failure(); 801cd5dab8aSAart Bik 802cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 803*563879b6SRahul Joshi insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 804*563879b6SRahul Joshi adaptor.position()); 8053145427dSRiver Riddle return success(); 806cd5dab8aSAart Bik } 807cd5dab8aSAart Bik }; 808cd5dab8aSAart Bik 809*563879b6SRahul Joshi class VectorInsertOpConversion 810*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertOp> { 8119826fe5cSAart Bik public: 812*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 8139826fe5cSAart Bik 8143145427dSRiver Riddle LogicalResult 815*563879b6SRahul Joshi matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 8169826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 817*563879b6SRahul Joshi auto loc = insertOp->getLoc(); 8182d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8199826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8209826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 821dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(destVectorType); 8229826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8239826fe5cSAart Bik 8249826fe5cSAart Bik // Bail if result type cannot be lowered. 8259826fe5cSAart Bik if (!llvmResultType) 8263145427dSRiver Riddle return failure(); 8279826fe5cSAart Bik 8289826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8299826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 830e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8319826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8329826fe5cSAart Bik positionArrayAttr); 833*563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8343145427dSRiver Riddle return success(); 8359826fe5cSAart Bik } 8369826fe5cSAart Bik 8379826fe5cSAart Bik // Potential extraction of 1-D vector from array. 838*563879b6SRahul Joshi auto *context = insertOp->getContext(); 839e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8409826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8419826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8429826fe5cSAart Bik auto oneDVectorType = destVectorType; 8439826fe5cSAart Bik if (positionAttrs.size() > 1) { 8449826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8459826fe5cSAart Bik auto nMinusOnePositionAttrs = 8469826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8479826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 848dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8499826fe5cSAart Bik nMinusOnePositionAttrs); 8509826fe5cSAart Bik } 8519826fe5cSAart Bik 8529826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 8535446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 8541d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 855e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 856dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8570f04384dSAlex Zinenko adaptor.source(), constant); 8589826fe5cSAart Bik 8599826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 8609826fe5cSAart Bik if (positionAttrs.size() > 1) { 8619826fe5cSAart Bik auto nMinusOnePositionAttrs = 8629826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8639826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 8649826fe5cSAart Bik adaptor.dest(), inserted, 8659826fe5cSAart Bik nMinusOnePositionAttrs); 8669826fe5cSAart Bik } 8679826fe5cSAart Bik 868*563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8693145427dSRiver Riddle return success(); 8709826fe5cSAart Bik } 8719826fe5cSAart Bik }; 8729826fe5cSAart Bik 873681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 874681f929fSNicolas Vasilache /// 875681f929fSNicolas Vasilache /// Example: 876681f929fSNicolas Vasilache /// ``` 877681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 878681f929fSNicolas Vasilache /// ``` 879681f929fSNicolas Vasilache /// is rewritten into: 880681f929fSNicolas Vasilache /// ``` 881681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 882681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 883681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 884681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 885681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 886681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 887681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 888681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 889681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 890681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 891681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 892681f929fSNicolas Vasilache /// // %r3 holds the final value. 893681f929fSNicolas Vasilache /// ``` 894681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 895681f929fSNicolas Vasilache public: 896681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 897681f929fSNicolas Vasilache 8983145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 899681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 900681f929fSNicolas Vasilache auto vType = op.getVectorType(); 901681f929fSNicolas Vasilache if (vType.getRank() < 2) 9023145427dSRiver Riddle return failure(); 903681f929fSNicolas Vasilache 904681f929fSNicolas Vasilache auto loc = op.getLoc(); 905681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 906681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 907681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 908681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 909681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 910681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 911681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 912681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 913681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 914681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 915681f929fSNicolas Vasilache } 916681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9173145427dSRiver Riddle return success(); 918681f929fSNicolas Vasilache } 919681f929fSNicolas Vasilache }; 920681f929fSNicolas Vasilache 9212d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9222d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9232d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9242d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9252d515e49SNicolas Vasilache // rank. 9262d515e49SNicolas Vasilache // 9272d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9282d515e49SNicolas Vasilache // have different ranks. In this case: 9292d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9302d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9312d515e49SNicolas Vasilache // destination subvector 9322d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9332d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9342d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9352d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9362d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9372d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9382d515e49SNicolas Vasilache public: 9392d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9402d515e49SNicolas Vasilache 9413145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9422d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9432d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9442d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9452d515e49SNicolas Vasilache 9462d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9473145427dSRiver Riddle return failure(); 9482d515e49SNicolas Vasilache 9492d515e49SNicolas Vasilache auto loc = op.getLoc(); 9502d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9512d515e49SNicolas Vasilache assert(rankDiff >= 0); 9522d515e49SNicolas Vasilache if (rankDiff == 0) 9533145427dSRiver Riddle return failure(); 9542d515e49SNicolas Vasilache 9552d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9562d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 9572d515e49SNicolas Vasilache // on it. 9582d515e49SNicolas Vasilache Value extracted = 9592d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 9602d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 961dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9622d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 9632d515e49SNicolas Vasilache // ranks. 9642d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 9652d515e49SNicolas Vasilache loc, op.source(), extracted, 9662d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 967c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 9682d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 9692d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 9702d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 971dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9723145427dSRiver Riddle return success(); 9732d515e49SNicolas Vasilache } 9742d515e49SNicolas Vasilache }; 9752d515e49SNicolas Vasilache 9762d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9772d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 9782d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9792d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9802d515e49SNicolas Vasilache // destination subvector 9812d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9822d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9832d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9842d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9852d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 9862d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9872d515e49SNicolas Vasilache public: 988b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 989b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 990b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 991b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 992b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 993b99bd771SRiver Riddle } 9942d515e49SNicolas Vasilache 9953145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9962d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9972d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9982d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9992d515e49SNicolas Vasilache 10002d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10013145427dSRiver Riddle return failure(); 10022d515e49SNicolas Vasilache 10032d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10042d515e49SNicolas Vasilache assert(rankDiff >= 0); 10052d515e49SNicolas Vasilache if (rankDiff != 0) 10063145427dSRiver Riddle return failure(); 10072d515e49SNicolas Vasilache 10082d515e49SNicolas Vasilache if (srcType == dstType) { 10092d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10103145427dSRiver Riddle return success(); 10112d515e49SNicolas Vasilache } 10122d515e49SNicolas Vasilache 10132d515e49SNicolas Vasilache int64_t offset = 10142d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10152d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10162d515e49SNicolas Vasilache int64_t stride = 10172d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10182d515e49SNicolas Vasilache 10192d515e49SNicolas Vasilache auto loc = op.getLoc(); 10202d515e49SNicolas Vasilache Value res = op.dest(); 10212d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10222d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10232d515e49SNicolas Vasilache off += stride, ++idx) { 10242d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10252d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10262d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10272d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10282d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10292d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10302d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10312d515e49SNicolas Vasilache // smaller rank. 1032bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10332d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10342d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10352d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10362d515e49SNicolas Vasilache } 10372d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10382d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10392d515e49SNicolas Vasilache } 10402d515e49SNicolas Vasilache 10412d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10423145427dSRiver Riddle return success(); 10432d515e49SNicolas Vasilache } 10442d515e49SNicolas Vasilache }; 10452d515e49SNicolas Vasilache 104630e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 104730e6033bSNicolas Vasilache /// static layout. 104830e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 104930e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10502bf491c7SBenjamin Kramer int64_t offset; 105130e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 105230e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 105330e6033bSNicolas Vasilache return None; 105430e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 105530e6033bSNicolas Vasilache return None; 105630e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 105730e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 105830e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 105930e6033bSNicolas Vasilache return strides; 106030e6033bSNicolas Vasilache 106130e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 106230e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 106330e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 106430e6033bSNicolas Vasilache // layout. 10652bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 10662bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 106730e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 106830e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 106930e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 107030e6033bSNicolas Vasilache return None; 107130e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 107230e6033bSNicolas Vasilache return None; 10732bf491c7SBenjamin Kramer } 107430e6033bSNicolas Vasilache return strides; 10752bf491c7SBenjamin Kramer } 10762bf491c7SBenjamin Kramer 1077*563879b6SRahul Joshi class VectorTypeCastOpConversion 1078*563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 10795c0c51a9SNicolas Vasilache public: 1080*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 10815c0c51a9SNicolas Vasilache 10823145427dSRiver Riddle LogicalResult 1083*563879b6SRahul Joshi matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, 10845c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 1085*563879b6SRahul Joshi auto loc = castOp->getLoc(); 10865c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 10872bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 10885c0c51a9SNicolas Vasilache MemRefType targetMemRefType = 10892bdf33ccSRiver Riddle castOp.getResult().getType().cast<MemRefType>(); 10905c0c51a9SNicolas Vasilache 10915c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 10925c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 10935c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 10943145427dSRiver Riddle return failure(); 10955c0c51a9SNicolas Vasilache 10965c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 10972bdf33ccSRiver Riddle operands[0].getType().dyn_cast<LLVM::LLVMType>(); 10985c0c51a9SNicolas Vasilache if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 10993145427dSRiver Riddle return failure(); 11005c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11015c0c51a9SNicolas Vasilache 1102dcec2ca5SChristian Sigg auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 11035c0c51a9SNicolas Vasilache .dyn_cast_or_null<LLVM::LLVMType>(); 11045c0c51a9SNicolas Vasilache if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 11053145427dSRiver Riddle return failure(); 11065c0c51a9SNicolas Vasilache 110730e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 110830e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 110930e6033bSNicolas Vasilache if (!sourceStrides) 111030e6033bSNicolas Vasilache return failure(); 111130e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 111230e6033bSNicolas Vasilache if (!targetStrides) 111330e6033bSNicolas Vasilache return failure(); 111430e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 111530e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 111630e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 111730e6033bSNicolas Vasilache })) 11183145427dSRiver Riddle return failure(); 11195c0c51a9SNicolas Vasilache 11205446ec85SAlex Zinenko auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 11215c0c51a9SNicolas Vasilache 11225c0c51a9SNicolas Vasilache // Create descriptor. 11235c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11243a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11255c0c51a9SNicolas Vasilache // Set allocated ptr. 1126e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11275c0c51a9SNicolas Vasilache allocated = 11285c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11295c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11305c0c51a9SNicolas Vasilache // Set aligned ptr. 1131e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11325c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11335c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11345c0c51a9SNicolas Vasilache // Fill offset 0. 11355c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11365c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11375c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11385c0c51a9SNicolas Vasilache 11395c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11405c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11415c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11425c0c51a9SNicolas Vasilache auto sizeAttr = 11435c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11445c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11455c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 114630e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 114730e6033bSNicolas Vasilache (*targetStrides)[index]); 11485c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11495c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11505c0c51a9SNicolas Vasilache } 11515c0c51a9SNicolas Vasilache 1152*563879b6SRahul Joshi rewriter.replaceOp(castOp, {desc}); 11533145427dSRiver Riddle return success(); 11545c0c51a9SNicolas Vasilache } 11555c0c51a9SNicolas Vasilache }; 11565c0c51a9SNicolas Vasilache 11578345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 11588345b86dSNicolas Vasilache /// sequence of: 1159060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1160060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1161060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1162060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1163060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 11648345b86dSNicolas Vasilache template <typename ConcreteOp> 1165*563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { 11668345b86dSNicolas Vasilache public: 1167*563879b6SRahul Joshi explicit VectorTransferConversion(LLVMTypeConverter &typeConv, 1168060c9dd1Saartbik bool enableIndexOpt) 1169*563879b6SRahul Joshi : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), 1170060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 11718345b86dSNicolas Vasilache 11728345b86dSNicolas Vasilache LogicalResult 1173*563879b6SRahul Joshi matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, 11748345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11758345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1176b2c79c50SNicolas Vasilache 1177b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1178b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 11798345b86dSNicolas Vasilache return failure(); 11805f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 11815f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 11825f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 1183*563879b6SRahul Joshi xferOp->getContext())) 11848345b86dSNicolas Vasilache return failure(); 11852bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 118630e6033bSNicolas Vasilache auto strides = computeContiguousStrides(xferOp.getMemRefType()); 118730e6033bSNicolas Vasilache if (!strides) 11882bf491c7SBenjamin Kramer return failure(); 11898345b86dSNicolas Vasilache 1190*563879b6SRahul Joshi auto toLLVMTy = [&](Type t) { 1191*563879b6SRahul Joshi return this->getTypeConverter()->convertType(t); 1192*563879b6SRahul Joshi }; 11938345b86dSNicolas Vasilache 1194*563879b6SRahul Joshi Location loc = xferOp->getLoc(); 11958345b86dSNicolas Vasilache MemRefType memRefType = xferOp.getMemRefType(); 11968345b86dSNicolas Vasilache 119768330ee0SThomas Raoux if (auto memrefVectorElementType = 119868330ee0SThomas Raoux memRefType.getElementType().dyn_cast<VectorType>()) { 119968330ee0SThomas Raoux // Memref has vector element type. 120068330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 120168330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 120268330ee0SThomas Raoux return failure(); 12030de60b55SThomas Raoux #ifndef NDEBUG 120468330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 120568330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 120668330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 120768330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 120868330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 120968330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 121068330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 121168330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 121268330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 121368330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 121468330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 121568330ee0SThomas Raoux "result shape."); 12160de60b55SThomas Raoux #endif // ifndef NDEBUG 121768330ee0SThomas Raoux } 121868330ee0SThomas Raoux 12198345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1220be16075bSWen-Heng (Jack) Chung // The vector pointer would always be on address space 0, therefore 1221be16075bSWen-Heng (Jack) Chung // addrspacecast shall be used when source/dst memrefs are not on 1222be16075bSWen-Heng (Jack) Chung // address space 0. 12238345b86dSNicolas Vasilache // TODO: support alignment when possible. 1224*563879b6SRahul Joshi Value dataPtr = this->getStridedElementPtr( 1225*563879b6SRahul Joshi loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); 12268345b86dSNicolas Vasilache auto vecTy = 12278345b86dSNicolas Vasilache toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); 1228be16075bSWen-Heng (Jack) Chung Value vectorDataPtr; 1229be16075bSWen-Heng (Jack) Chung if (memRefType.getMemorySpace() == 0) 1230be16075bSWen-Heng (Jack) Chung vectorDataPtr = 12318345b86dSNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); 1232be16075bSWen-Heng (Jack) Chung else 1233be16075bSWen-Heng (Jack) Chung vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1234be16075bSWen-Heng (Jack) Chung loc, vecTy.getPointerTo(), dataPtr); 12358345b86dSNicolas Vasilache 12361870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 1237*563879b6SRahul Joshi return replaceTransferOpWithLoadOrStore(rewriter, 1238*563879b6SRahul Joshi *this->getTypeConverter(), loc, 1239*563879b6SRahul Joshi xferOp, operands, vectorDataPtr); 12401870e787SNicolas Vasilache 12418345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12428345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12438345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12448345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1245060c9dd1Saartbik // 1246060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1247060c9dd1Saartbik // dimensions here. 1248060c9dd1Saartbik unsigned vecWidth = vecTy.getVectorNumElements(); 1249060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12500c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 1251b2c79c50SNicolas Vasilache Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex); 1252*563879b6SRahul Joshi Value mask = buildVectorComparison( 1253*563879b6SRahul Joshi rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); 12548345b86dSNicolas Vasilache 12558345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 1256*563879b6SRahul Joshi return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, 1257dcec2ca5SChristian Sigg xferOp, operands, vectorDataPtr, mask); 12588345b86dSNicolas Vasilache } 1259060c9dd1Saartbik 1260060c9dd1Saartbik private: 1261060c9dd1Saartbik const bool enableIndexOptimizations; 12628345b86dSNicolas Vasilache }; 12638345b86dSNicolas Vasilache 1264*563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1265d9b500d3SAart Bik public: 1266*563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1267d9b500d3SAart Bik 1268d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1269d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1270d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1271d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1272d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1273d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1274d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1275d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1276d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1277d9b500d3SAart Bik // 12789db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1279d9b500d3SAart Bik // 12803145427dSRiver Riddle LogicalResult 1281*563879b6SRahul Joshi matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands, 1282d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 12832d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1284d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1285d9b500d3SAart Bik 1286dcec2ca5SChristian Sigg if (typeConverter->convertType(printType) == nullptr) 12873145427dSRiver Riddle return failure(); 1288d9b500d3SAart Bik 1289b8880f5fSAart Bik // Make sure element type has runtime support. 1290b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1291d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1292d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1293d9b500d3SAart Bik Operation *printer; 1294b8880f5fSAart Bik if (eltType.isF32()) { 1295*563879b6SRahul Joshi printer = getPrintFloat(printOp); 1296b8880f5fSAart Bik } else if (eltType.isF64()) { 1297*563879b6SRahul Joshi printer = getPrintDouble(printOp); 129854759cefSAart Bik } else if (eltType.isIndex()) { 1299*563879b6SRahul Joshi printer = getPrintU64(printOp); 1300b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1301b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1302b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1303b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1304b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1305b8880f5fSAart Bik if (intTy.isUnsigned()) { 130654759cefSAart Bik if (width <= 64) { 1307b8880f5fSAart Bik if (width < 64) 1308b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1309*563879b6SRahul Joshi printer = getPrintU64(printOp); 1310b8880f5fSAart Bik } else { 13113145427dSRiver Riddle return failure(); 1312b8880f5fSAart Bik } 1313b8880f5fSAart Bik } else { 1314b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 131554759cefSAart Bik if (width <= 64) { 1316b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1317b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1318b8880f5fSAart Bik if (width == 1) 131954759cefSAart Bik conversion = PrintConversion::ZeroExt64; 132054759cefSAart Bik else if (width < 64) 1321b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1322*563879b6SRahul Joshi printer = getPrintI64(printOp); 1323b8880f5fSAart Bik } else { 1324b8880f5fSAart Bik return failure(); 1325b8880f5fSAart Bik } 1326b8880f5fSAart Bik } 1327b8880f5fSAart Bik } else { 1328b8880f5fSAart Bik return failure(); 1329b8880f5fSAart Bik } 1330d9b500d3SAart Bik 1331d9b500d3SAart Bik // Unroll vector into elementary print calls. 1332b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1333*563879b6SRahul Joshi emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, 1334b8880f5fSAart Bik conversion); 1335*563879b6SRahul Joshi emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); 1336*563879b6SRahul Joshi rewriter.eraseOp(printOp); 13373145427dSRiver Riddle return success(); 1338d9b500d3SAart Bik } 1339d9b500d3SAart Bik 1340d9b500d3SAart Bik private: 1341b8880f5fSAart Bik enum class PrintConversion { 134230e6033bSNicolas Vasilache // clang-format off 1343b8880f5fSAart Bik None, 1344b8880f5fSAart Bik ZeroExt64, 1345b8880f5fSAart Bik SignExt64 134630e6033bSNicolas Vasilache // clang-format on 1347b8880f5fSAart Bik }; 1348b8880f5fSAart Bik 1349d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1350e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1351b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1352d9b500d3SAart Bik Location loc = op->getLoc(); 1353d9b500d3SAart Bik if (rank == 0) { 1354b8880f5fSAart Bik switch (conversion) { 1355b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1356b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 1357b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1358b8880f5fSAart Bik break; 1359b8880f5fSAart Bik case PrintConversion::SignExt64: 1360b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 1361b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1362b8880f5fSAart Bik break; 1363b8880f5fSAart Bik case PrintConversion::None: 1364b8880f5fSAart Bik break; 1365c9eeeb38Saartbik } 1366d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1367d9b500d3SAart Bik return; 1368d9b500d3SAart Bik } 1369d9b500d3SAart Bik 1370d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1371d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1372d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1373d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1374d9b500d3SAart Bik auto reducedType = 1375d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1376dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType( 1377d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1378dcec2ca5SChristian Sigg Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1379dcec2ca5SChristian Sigg llvmType, rank, d); 1380b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1381b8880f5fSAart Bik conversion); 1382d9b500d3SAart Bik if (d != dim - 1) 1383d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1384d9b500d3SAart Bik } 1385d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1386d9b500d3SAart Bik } 1387d9b500d3SAart Bik 1388d9b500d3SAart Bik // Helper to emit a call. 1389d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1390d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 139108e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1392d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1393d9b500d3SAart Bik } 1394d9b500d3SAart Bik 1395d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 13965446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 13975446ec85SAlex Zinenko ArrayRef<LLVM::LLVMType> params) { 1398d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1399d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1400d9b500d3SAart Bik if (func) 1401d9b500d3SAart Bik return func; 1402d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1403d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1404d9b500d3SAart Bik op->getLoc(), name, 14055446ec85SAlex Zinenko LLVM::LLVMType::getFunctionTy( 14065446ec85SAlex Zinenko LLVM::LLVMType::getVoidTy(op->getContext()), params, 14075446ec85SAlex Zinenko /*isVarArg=*/false)); 1408d9b500d3SAart Bik } 1409d9b500d3SAart Bik 1410d9b500d3SAart Bik // Helpers for method names. 1411e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 141254759cefSAart Bik return getPrint(op, "printI64", 14135446ec85SAlex Zinenko LLVM::LLVMType::getInt64Ty(op->getContext())); 1414e52414b1Saartbik } 1415b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 1416b8880f5fSAart Bik return getPrint(op, "printU64", 1417b8880f5fSAart Bik LLVM::LLVMType::getInt64Ty(op->getContext())); 1418b8880f5fSAart Bik } 1419d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 142054759cefSAart Bik return getPrint(op, "printF32", 14215446ec85SAlex Zinenko LLVM::LLVMType::getFloatTy(op->getContext())); 1422d9b500d3SAart Bik } 1423d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 142454759cefSAart Bik return getPrint(op, "printF64", 14255446ec85SAlex Zinenko LLVM::LLVMType::getDoubleTy(op->getContext())); 1426d9b500d3SAart Bik } 1427d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 142854759cefSAart Bik return getPrint(op, "printOpen", {}); 1429d9b500d3SAart Bik } 1430d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 143154759cefSAart Bik return getPrint(op, "printClose", {}); 1432d9b500d3SAart Bik } 1433d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 143454759cefSAart Bik return getPrint(op, "printComma", {}); 1435d9b500d3SAart Bik } 1436d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 143754759cefSAart Bik return getPrint(op, "printNewline", {}); 1438d9b500d3SAart Bik } 1439d9b500d3SAart Bik }; 1440d9b500d3SAart Bik 1441334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1442c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1443c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1444c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1445334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 144665678d93SNicolas Vasilache public: 1447b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1448b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1449b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1450b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1451b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1452b99bd771SRiver Riddle } 145365678d93SNicolas Vasilache 1454334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 145565678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 145665678d93SNicolas Vasilache auto dstType = op.getResult().getType().cast<VectorType>(); 145765678d93SNicolas Vasilache 145865678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 145965678d93SNicolas Vasilache 146065678d93SNicolas Vasilache int64_t offset = 146165678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 146265678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 146365678d93SNicolas Vasilache int64_t stride = 146465678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 146565678d93SNicolas Vasilache 146665678d93SNicolas Vasilache auto loc = op.getLoc(); 146765678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 146835b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1469c3c95b9cSaartbik 1470c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1471c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1472c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1473c3c95b9cSaartbik offsets.reserve(size); 1474c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1475c3c95b9cSaartbik off += stride) 1476c3c95b9cSaartbik offsets.push_back(off); 1477c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1478c3c95b9cSaartbik op.vector(), 1479c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1480c3c95b9cSaartbik return success(); 1481c3c95b9cSaartbik } 1482c3c95b9cSaartbik 1483c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 148465678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 148565678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 148665678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 148765678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 148865678d93SNicolas Vasilache off += stride, ++idx) { 1489c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1490c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1491c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 149265678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 149365678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 149465678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 149565678d93SNicolas Vasilache } 1496c3c95b9cSaartbik rewriter.replaceOp(op, res); 14973145427dSRiver Riddle return success(); 149865678d93SNicolas Vasilache } 149965678d93SNicolas Vasilache }; 150065678d93SNicolas Vasilache 1501df186507SBenjamin Kramer } // namespace 1502df186507SBenjamin Kramer 15035c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 15045c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1505ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1506060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 150765678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 15088345b86dSNicolas Vasilache // clang-format off 1509681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1510681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 15112d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1512c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1513ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1514*563879b6SRahul Joshi converter, reassociateFPReductions); 1515060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1516060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1517060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1518*563879b6SRahul Joshi converter, enableIndexOptimizations); 15198345b86dSNicolas Vasilache patterns 1520ceb1b327Saartbik .insert<VectorShuffleOpConversion, 15218345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15228345b86dSNicolas Vasilache VectorExtractOpConversion, 15238345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15248345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15258345b86dSNicolas Vasilache VectorInsertOpConversion, 15268345b86dSNicolas Vasilache VectorPrintOpConversion, 152719dbb230Saartbik VectorTypeCastOpConversion, 152839379916Saartbik VectorMaskedLoadOpConversion, 152939379916Saartbik VectorMaskedStoreOpConversion, 153019dbb230Saartbik VectorGatherOpConversion, 1531e8dcf5f8Saartbik VectorScatterOpConversion, 1532e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1533*563879b6SRahul Joshi VectorCompressStoreOpConversion>(converter); 15348345b86dSNicolas Vasilache // clang-format on 15355c0c51a9SNicolas Vasilache } 15365c0c51a9SNicolas Vasilache 153763b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 153863b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1539*563879b6SRahul Joshi patterns.insert<VectorMatmulOpConversion>(converter); 1540*563879b6SRahul Joshi patterns.insert<VectorFlatTransposeOpConversion>(converter); 154163b683a8SNicolas Vasilache } 1542