15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 16*09f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 195c0c51a9SNicolas Vasilache 205c0c51a9SNicolas Vasilache using namespace mlir; 2165678d93SNicolas Vasilache using namespace mlir::vector; 225c0c51a9SNicolas Vasilache 239826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 259826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 269826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 279826fe5cSAart Bik } 289826fe5cSAart Bik 299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 319826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 329826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 339826fe5cSAart Bik } 349826fe5cSAart Bik 351c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 370f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 380f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 390f04384dSAlex Zinenko int64_t pos) { 401c81adf3SAart Bik if (rank == 1) { 411c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 421c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 430f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 441c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 451c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 461c81adf3SAart Bik constant); 471c81adf3SAart Bik } 481c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 491c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 501c81adf3SAart Bik } 511c81adf3SAart Bik 522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 542d515e49SNicolas Vasilache Value into, int64_t offset) { 552d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 562d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 572d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 582d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 592d515e49SNicolas Vasilache loc, vectorType, from, into, 602d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 612d515e49SNicolas Vasilache } 622d515e49SNicolas Vasilache 631c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 650f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 660f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 671c81adf3SAart Bik if (rank == 1) { 681c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 691c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 700f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 711c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 721c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 731c81adf3SAart Bik constant); 741c81adf3SAart Bik } 751c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 761c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 771c81adf3SAart Bik } 781c81adf3SAart Bik 792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 812d515e49SNicolas Vasilache int64_t offset) { 822d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 832d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 842d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 852d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 862d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 872d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 882d515e49SNicolas Vasilache } 892d515e49SNicolas Vasilache 902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 932d515e49SNicolas Vasilache unsigned dropFront = 0, 942d515e49SNicolas Vasilache unsigned dropBack = 0) { 952d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 962d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 972d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 982d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 992d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1002d515e49SNicolas Vasilache it != eit; ++it) 1012d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1022d515e49SNicolas Vasilache return res; 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 106060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107060c9dd1Saartbik // 108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 110060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 111060c9dd1Saartbik // is very conservative on the boundary conditions. 112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 114060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 115060c9dd1Saartbik auto loc = op->getLoc(); 116060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 117060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 118060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 119060c9dd1Saartbik Value indices; 120060c9dd1Saartbik Type idxType; 121060c9dd1Saartbik if (enableIndexOptimizations) { 1220c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1230c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1240c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125060c9dd1Saartbik idxType = rewriter.getI32Type(); 126060c9dd1Saartbik } else { 1270c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1280c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1290c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130060c9dd1Saartbik idxType = rewriter.getI64Type(); 131060c9dd1Saartbik } 132060c9dd1Saartbik // Add in an offset if requested. 133060c9dd1Saartbik if (off) { 134060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 137060c9dd1Saartbik } 138060c9dd1Saartbik // Construct the vector comparison. 139060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142060c9dd1Saartbik } 143060c9dd1Saartbik 14419dbb230Saartbik // Helper that returns data layout alignment of an operation with memref. 14519dbb230Saartbik template <typename T> 14619dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, 14719dbb230Saartbik unsigned &align) { 1485f9e0466SNicolas Vasilache Type elementTy = 14919dbb230Saartbik typeConverter.convertType(op.getMemRefType().getElementType()); 1505f9e0466SNicolas Vasilache if (!elementTy) 1515f9e0466SNicolas Vasilache return failure(); 1525f9e0466SNicolas Vasilache 153b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 154b2ab375dSAlex Zinenko // stop depending on translation. 15587a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 15687a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 157b2ab375dSAlex Zinenko .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(), 158168213f9SAlex Zinenko typeConverter.getDataLayout()); 1595f9e0466SNicolas Vasilache return success(); 1605f9e0466SNicolas Vasilache } 1615f9e0466SNicolas Vasilache 162e8dcf5f8Saartbik // Helper that returns the base address of a memref. 163b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 164e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 16519dbb230Saartbik // Inspect stride and offset structure. 16619dbb230Saartbik // 16719dbb230Saartbik // TODO: flat memory only for now, generalize 16819dbb230Saartbik // 16919dbb230Saartbik int64_t offset; 17019dbb230Saartbik SmallVector<int64_t, 4> strides; 17119dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 17219dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 17319dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 17419dbb230Saartbik return failure(); 175e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 176e8dcf5f8Saartbik return success(); 177e8dcf5f8Saartbik } 17819dbb230Saartbik 179e8dcf5f8Saartbik // Helper that returns a pointer given a memref base. 180b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 181b98e25b6SBenjamin Kramer Location loc, Value memref, 182b98e25b6SBenjamin Kramer MemRefType memRefType, Value &ptr) { 183e8dcf5f8Saartbik Value base; 184e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 185e8dcf5f8Saartbik return failure(); 1863a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 187e8dcf5f8Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 188e8dcf5f8Saartbik return success(); 189e8dcf5f8Saartbik } 190e8dcf5f8Saartbik 19139379916Saartbik // Helper that returns a bit-casted pointer given a memref base. 192b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 193b98e25b6SBenjamin Kramer Location loc, Value memref, 194b98e25b6SBenjamin Kramer MemRefType memRefType, Type type, Value &ptr) { 19539379916Saartbik Value base; 19639379916Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 19739379916Saartbik return failure(); 19839379916Saartbik auto pType = type.template cast<LLVM::LLVMType>().getPointerTo(); 19939379916Saartbik base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 20039379916Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 20139379916Saartbik return success(); 20239379916Saartbik } 20339379916Saartbik 204e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index 205e8dcf5f8Saartbik // vector. 206b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 207b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 208b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 209b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 210e8dcf5f8Saartbik Value base; 211e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 212e8dcf5f8Saartbik return failure(); 2133a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 214e8dcf5f8Saartbik auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); 2151485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 21619dbb230Saartbik return success(); 21719dbb230Saartbik } 21819dbb230Saartbik 2195f9e0466SNicolas Vasilache static LogicalResult 2205f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2215f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2225f9e0466SNicolas Vasilache TransferReadOp xferOp, 2235f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 224affbc0cdSNicolas Vasilache unsigned align; 22519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 226affbc0cdSNicolas Vasilache return failure(); 227affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2285f9e0466SNicolas Vasilache return success(); 2295f9e0466SNicolas Vasilache } 2305f9e0466SNicolas Vasilache 2315f9e0466SNicolas Vasilache static LogicalResult 2325f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2335f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2345f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2355f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2365f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2375f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2385f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2395f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2405f9e0466SNicolas Vasilache 2415f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2425f9e0466SNicolas Vasilache if (!vecTy) 2435f9e0466SNicolas Vasilache return failure(); 2445f9e0466SNicolas Vasilache 2455f9e0466SNicolas Vasilache unsigned align; 24619dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2475f9e0466SNicolas Vasilache return failure(); 2485f9e0466SNicolas Vasilache 2495f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2505f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2515f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2525f9e0466SNicolas Vasilache return success(); 2535f9e0466SNicolas Vasilache } 2545f9e0466SNicolas Vasilache 2555f9e0466SNicolas Vasilache static LogicalResult 2565f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2575f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2585f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2595f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 260affbc0cdSNicolas Vasilache unsigned align; 26119dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 262affbc0cdSNicolas Vasilache return failure(); 2632d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 264affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 265affbc0cdSNicolas Vasilache align); 2665f9e0466SNicolas Vasilache return success(); 2675f9e0466SNicolas Vasilache } 2685f9e0466SNicolas Vasilache 2695f9e0466SNicolas Vasilache static LogicalResult 2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2715f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2725f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2735f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2745f9e0466SNicolas Vasilache unsigned align; 27519dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2765f9e0466SNicolas Vasilache return failure(); 2775f9e0466SNicolas Vasilache 2782d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2795f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2805f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2815f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2825f9e0466SNicolas Vasilache return success(); 2835f9e0466SNicolas Vasilache } 2845f9e0466SNicolas Vasilache 2852d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2862d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2872d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2885f9e0466SNicolas Vasilache } 2895f9e0466SNicolas Vasilache 2902d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2912d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2922d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2935f9e0466SNicolas Vasilache } 2945f9e0466SNicolas Vasilache 29590c01357SBenjamin Kramer namespace { 296e83b7b99Saartbik 29763b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 29863b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 29963b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern { 30063b683a8SNicolas Vasilache public: 30163b683a8SNicolas Vasilache explicit VectorMatmulOpConversion(MLIRContext *context, 30263b683a8SNicolas Vasilache LLVMTypeConverter &typeConverter) 30363b683a8SNicolas Vasilache : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, 30463b683a8SNicolas Vasilache typeConverter) {} 30563b683a8SNicolas Vasilache 3063145427dSRiver Riddle LogicalResult 30763b683a8SNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 30863b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 30963b683a8SNicolas Vasilache auto matmulOp = cast<vector::MatmulOp>(op); 3102d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 31163b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 31263b683a8SNicolas Vasilache op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), 31363b683a8SNicolas Vasilache adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), 31463b683a8SNicolas Vasilache matmulOp.rhs_columns()); 3153145427dSRiver Riddle return success(); 31663b683a8SNicolas Vasilache } 31763b683a8SNicolas Vasilache }; 31863b683a8SNicolas Vasilache 319c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 320c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 321c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { 322c295a65dSaartbik public: 323c295a65dSaartbik explicit VectorFlatTransposeOpConversion(MLIRContext *context, 324c295a65dSaartbik LLVMTypeConverter &typeConverter) 325c295a65dSaartbik : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), 326c295a65dSaartbik context, typeConverter) {} 327c295a65dSaartbik 328c295a65dSaartbik LogicalResult 329c295a65dSaartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 330c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 331c295a65dSaartbik auto transOp = cast<vector::FlatTransposeOp>(op); 3322d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 333c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 334c295a65dSaartbik transOp, typeConverter.convertType(transOp.res().getType()), 335c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 336c295a65dSaartbik return success(); 337c295a65dSaartbik } 338c295a65dSaartbik }; 339c295a65dSaartbik 34039379916Saartbik /// Conversion pattern for a vector.maskedload. 34139379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { 34239379916Saartbik public: 34339379916Saartbik explicit VectorMaskedLoadOpConversion(MLIRContext *context, 34439379916Saartbik LLVMTypeConverter &typeConverter) 34539379916Saartbik : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, 34639379916Saartbik typeConverter) {} 34739379916Saartbik 34839379916Saartbik LogicalResult 34939379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 35039379916Saartbik ConversionPatternRewriter &rewriter) const override { 35139379916Saartbik auto loc = op->getLoc(); 35239379916Saartbik auto load = cast<vector::MaskedLoadOp>(op); 35339379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 35439379916Saartbik 35539379916Saartbik // Resolve alignment. 35639379916Saartbik unsigned align; 35739379916Saartbik if (failed(getMemRefAlignment(typeConverter, load, align))) 35839379916Saartbik return failure(); 35939379916Saartbik 36039379916Saartbik auto vtype = typeConverter.convertType(load.getResultVectorType()); 36139379916Saartbik Value ptr; 36239379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 36339379916Saartbik vtype, ptr))) 36439379916Saartbik return failure(); 36539379916Saartbik 36639379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 36739379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 36839379916Saartbik rewriter.getI32IntegerAttr(align)); 36939379916Saartbik return success(); 37039379916Saartbik } 37139379916Saartbik }; 37239379916Saartbik 37339379916Saartbik /// Conversion pattern for a vector.maskedstore. 37439379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { 37539379916Saartbik public: 37639379916Saartbik explicit VectorMaskedStoreOpConversion(MLIRContext *context, 37739379916Saartbik LLVMTypeConverter &typeConverter) 37839379916Saartbik : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, 37939379916Saartbik typeConverter) {} 38039379916Saartbik 38139379916Saartbik LogicalResult 38239379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 38339379916Saartbik ConversionPatternRewriter &rewriter) const override { 38439379916Saartbik auto loc = op->getLoc(); 38539379916Saartbik auto store = cast<vector::MaskedStoreOp>(op); 38639379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 38739379916Saartbik 38839379916Saartbik // Resolve alignment. 38939379916Saartbik unsigned align; 39039379916Saartbik if (failed(getMemRefAlignment(typeConverter, store, align))) 39139379916Saartbik return failure(); 39239379916Saartbik 39339379916Saartbik auto vtype = typeConverter.convertType(store.getValueVectorType()); 39439379916Saartbik Value ptr; 39539379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 39639379916Saartbik vtype, ptr))) 39739379916Saartbik return failure(); 39839379916Saartbik 39939379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 40039379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 40139379916Saartbik rewriter.getI32IntegerAttr(align)); 40239379916Saartbik return success(); 40339379916Saartbik } 40439379916Saartbik }; 40539379916Saartbik 40619dbb230Saartbik /// Conversion pattern for a vector.gather. 40719dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern { 40819dbb230Saartbik public: 40919dbb230Saartbik explicit VectorGatherOpConversion(MLIRContext *context, 41019dbb230Saartbik LLVMTypeConverter &typeConverter) 41119dbb230Saartbik : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, 41219dbb230Saartbik typeConverter) {} 41319dbb230Saartbik 41419dbb230Saartbik LogicalResult 41519dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 41619dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 41719dbb230Saartbik auto loc = op->getLoc(); 41819dbb230Saartbik auto gather = cast<vector::GatherOp>(op); 41919dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 42019dbb230Saartbik 42119dbb230Saartbik // Resolve alignment. 42219dbb230Saartbik unsigned align; 42319dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, gather, align))) 42419dbb230Saartbik return failure(); 42519dbb230Saartbik 42619dbb230Saartbik // Get index ptrs. 42719dbb230Saartbik VectorType vType = gather.getResultVectorType(); 42819dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 42919dbb230Saartbik Value ptrs; 430e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 431e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 43219dbb230Saartbik return failure(); 43319dbb230Saartbik 43419dbb230Saartbik // Replace with the gather intrinsic. 43519dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 4360c2a4d3cSBenjamin Kramer gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), 4370c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 43819dbb230Saartbik return success(); 43919dbb230Saartbik } 44019dbb230Saartbik }; 44119dbb230Saartbik 44219dbb230Saartbik /// Conversion pattern for a vector.scatter. 44319dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern { 44419dbb230Saartbik public: 44519dbb230Saartbik explicit VectorScatterOpConversion(MLIRContext *context, 44619dbb230Saartbik LLVMTypeConverter &typeConverter) 44719dbb230Saartbik : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, 44819dbb230Saartbik typeConverter) {} 44919dbb230Saartbik 45019dbb230Saartbik LogicalResult 45119dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 45219dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 45319dbb230Saartbik auto loc = op->getLoc(); 45419dbb230Saartbik auto scatter = cast<vector::ScatterOp>(op); 45519dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 45619dbb230Saartbik 45719dbb230Saartbik // Resolve alignment. 45819dbb230Saartbik unsigned align; 45919dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, scatter, align))) 46019dbb230Saartbik return failure(); 46119dbb230Saartbik 46219dbb230Saartbik // Get index ptrs. 46319dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 46419dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 46519dbb230Saartbik Value ptrs; 466e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 467e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 46819dbb230Saartbik return failure(); 46919dbb230Saartbik 47019dbb230Saartbik // Replace with the scatter intrinsic. 47119dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 47219dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 47319dbb230Saartbik rewriter.getI32IntegerAttr(align)); 47419dbb230Saartbik return success(); 47519dbb230Saartbik } 47619dbb230Saartbik }; 47719dbb230Saartbik 478e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 479e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { 480e8dcf5f8Saartbik public: 481e8dcf5f8Saartbik explicit VectorExpandLoadOpConversion(MLIRContext *context, 482e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 483e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context, 484e8dcf5f8Saartbik typeConverter) {} 485e8dcf5f8Saartbik 486e8dcf5f8Saartbik LogicalResult 487e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 488e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 489e8dcf5f8Saartbik auto loc = op->getLoc(); 490e8dcf5f8Saartbik auto expand = cast<vector::ExpandLoadOp>(op); 491e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 492e8dcf5f8Saartbik 493e8dcf5f8Saartbik Value ptr; 494e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 495e8dcf5f8Saartbik ptr))) 496e8dcf5f8Saartbik return failure(); 497e8dcf5f8Saartbik 498e8dcf5f8Saartbik auto vType = expand.getResultVectorType(); 499e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 500e8dcf5f8Saartbik op, typeConverter.convertType(vType), ptr, adaptor.mask(), 501e8dcf5f8Saartbik adaptor.pass_thru()); 502e8dcf5f8Saartbik return success(); 503e8dcf5f8Saartbik } 504e8dcf5f8Saartbik }; 505e8dcf5f8Saartbik 506e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 507e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern { 508e8dcf5f8Saartbik public: 509e8dcf5f8Saartbik explicit VectorCompressStoreOpConversion(MLIRContext *context, 510e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 511e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(), 512e8dcf5f8Saartbik context, typeConverter) {} 513e8dcf5f8Saartbik 514e8dcf5f8Saartbik LogicalResult 515e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 516e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 517e8dcf5f8Saartbik auto loc = op->getLoc(); 518e8dcf5f8Saartbik auto compress = cast<vector::CompressStoreOp>(op); 519e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 520e8dcf5f8Saartbik 521e8dcf5f8Saartbik Value ptr; 522e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), 523e8dcf5f8Saartbik compress.getMemRefType(), ptr))) 524e8dcf5f8Saartbik return failure(); 525e8dcf5f8Saartbik 526e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 527e8dcf5f8Saartbik op, adaptor.value(), ptr, adaptor.mask()); 528e8dcf5f8Saartbik return success(); 529e8dcf5f8Saartbik } 530e8dcf5f8Saartbik }; 531e8dcf5f8Saartbik 53219dbb230Saartbik /// Conversion pattern for all vector reductions. 533870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern { 534e83b7b99Saartbik public: 535e83b7b99Saartbik explicit VectorReductionOpConversion(MLIRContext *context, 536ceb1b327Saartbik LLVMTypeConverter &typeConverter, 537060c9dd1Saartbik bool reassociateFPRed) 538870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, 539ceb1b327Saartbik typeConverter), 540060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 541e83b7b99Saartbik 5423145427dSRiver Riddle LogicalResult 543e83b7b99Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 544e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 545e83b7b99Saartbik auto reductionOp = cast<vector::ReductionOp>(op); 546e83b7b99Saartbik auto kind = reductionOp.kind(); 547e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 5480f04384dSAlex Zinenko Type llvmType = typeConverter.convertType(eltType); 549e9628955SAart Bik if (eltType.isIntOrIndex()) { 550e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 551e83b7b99Saartbik if (kind == "add") 552322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 553e83b7b99Saartbik op, llvmType, operands[0]); 554e83b7b99Saartbik else if (kind == "mul") 555322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 556e83b7b99Saartbik op, llvmType, operands[0]); 557e9628955SAart Bik else if (kind == "min" && 558e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 559322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 560e9628955SAart Bik op, llvmType, operands[0]); 561e83b7b99Saartbik else if (kind == "min") 562322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 563e83b7b99Saartbik op, llvmType, operands[0]); 564e9628955SAart Bik else if (kind == "max" && 565e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 566322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 567e9628955SAart Bik op, llvmType, operands[0]); 568e83b7b99Saartbik else if (kind == "max") 569322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 570e83b7b99Saartbik op, llvmType, operands[0]); 571e83b7b99Saartbik else if (kind == "and") 572322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 573e83b7b99Saartbik op, llvmType, operands[0]); 574e83b7b99Saartbik else if (kind == "or") 575322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 576e83b7b99Saartbik op, llvmType, operands[0]); 577e83b7b99Saartbik else if (kind == "xor") 578322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 579e83b7b99Saartbik op, llvmType, operands[0]); 580e83b7b99Saartbik else 5813145427dSRiver Riddle return failure(); 5823145427dSRiver Riddle return success(); 583e83b7b99Saartbik 5842d76274bSBenjamin Kramer } else if (eltType.isa<FloatType>()) { 585e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 586e83b7b99Saartbik if (kind == "add") { 5870d924700Saartbik // Optional accumulator (or zero). 5880d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5890d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 5900d924700Saartbik op->getLoc(), llvmType, 5910d924700Saartbik rewriter.getZeroAttr(eltType)); 592322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 593ceb1b327Saartbik op, llvmType, acc, operands[0], 594ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 595e83b7b99Saartbik } else if (kind == "mul") { 5960d924700Saartbik // Optional accumulator (or one). 5970d924700Saartbik Value acc = operands.size() > 1 5980d924700Saartbik ? operands[1] 5990d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 6000d924700Saartbik op->getLoc(), llvmType, 6010d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 602322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 603ceb1b327Saartbik op, llvmType, acc, operands[0], 604ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 605e83b7b99Saartbik } else if (kind == "min") 606322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 607e83b7b99Saartbik op, llvmType, operands[0]); 608e83b7b99Saartbik else if (kind == "max") 609322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 610e83b7b99Saartbik op, llvmType, operands[0]); 611e83b7b99Saartbik else 6123145427dSRiver Riddle return failure(); 6133145427dSRiver Riddle return success(); 614e83b7b99Saartbik } 6153145427dSRiver Riddle return failure(); 616e83b7b99Saartbik } 617ceb1b327Saartbik 618ceb1b327Saartbik private: 619ceb1b327Saartbik const bool reassociateFPReductions; 620e83b7b99Saartbik }; 621e83b7b99Saartbik 622060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 623060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern { 624060c9dd1Saartbik public: 625060c9dd1Saartbik explicit VectorCreateMaskOpConversion(MLIRContext *context, 626060c9dd1Saartbik LLVMTypeConverter &typeConverter, 627060c9dd1Saartbik bool enableIndexOpt) 628060c9dd1Saartbik : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context, 629060c9dd1Saartbik typeConverter), 630060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 631060c9dd1Saartbik 632060c9dd1Saartbik LogicalResult 633060c9dd1Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 634060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 635060c9dd1Saartbik auto dstType = op->getResult(0).getType().cast<VectorType>(); 636060c9dd1Saartbik int64_t rank = dstType.getRank(); 637060c9dd1Saartbik if (rank == 1) { 638060c9dd1Saartbik rewriter.replaceOp( 639060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 640060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 641060c9dd1Saartbik return success(); 642060c9dd1Saartbik } 643060c9dd1Saartbik return failure(); 644060c9dd1Saartbik } 645060c9dd1Saartbik 646060c9dd1Saartbik private: 647060c9dd1Saartbik const bool enableIndexOptimizations; 648060c9dd1Saartbik }; 649060c9dd1Saartbik 650870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern { 6511c81adf3SAart Bik public: 6521c81adf3SAart Bik explicit VectorShuffleOpConversion(MLIRContext *context, 6531c81adf3SAart Bik LLVMTypeConverter &typeConverter) 654870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, 6551c81adf3SAart Bik typeConverter) {} 6561c81adf3SAart Bik 6573145427dSRiver Riddle LogicalResult 658e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 6591c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 6601c81adf3SAart Bik auto loc = op->getLoc(); 6612d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6621c81adf3SAart Bik auto shuffleOp = cast<vector::ShuffleOp>(op); 6631c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6641c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6651c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 6660f04384dSAlex Zinenko Type llvmType = typeConverter.convertType(vectorType); 6671c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6681c81adf3SAart Bik 6691c81adf3SAart Bik // Bail if result type cannot be lowered. 6701c81adf3SAart Bik if (!llvmType) 6713145427dSRiver Riddle return failure(); 6721c81adf3SAart Bik 6731c81adf3SAart Bik // Get rank and dimension sizes. 6741c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6751c81adf3SAart Bik assert(v1Type.getRank() == rank); 6761c81adf3SAart Bik assert(v2Type.getRank() == rank); 6771c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6781c81adf3SAart Bik 6791c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6801c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6811c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 682e62a6956SRiver Riddle Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 6831c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 6841c81adf3SAart Bik rewriter.replaceOp(op, shuffle); 6853145427dSRiver Riddle return success(); 686b36aaeafSAart Bik } 687b36aaeafSAart Bik 6881c81adf3SAart Bik // For all other cases, insert the individual values individually. 689e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6901c81adf3SAart Bik int64_t insPos = 0; 6911c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6921c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 693e62a6956SRiver Riddle Value value = adaptor.v1(); 6941c81adf3SAart Bik if (extPos >= v1Dim) { 6951c81adf3SAart Bik extPos -= v1Dim; 6961c81adf3SAart Bik value = adaptor.v2(); 697b36aaeafSAart Bik } 6980f04384dSAlex Zinenko Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, 6990f04384dSAlex Zinenko rank, extPos); 7000f04384dSAlex Zinenko insert = insertOne(rewriter, typeConverter, loc, insert, extract, 7010f04384dSAlex Zinenko llvmType, rank, insPos++); 7021c81adf3SAart Bik } 7031c81adf3SAart Bik rewriter.replaceOp(op, insert); 7043145427dSRiver Riddle return success(); 705b36aaeafSAart Bik } 706b36aaeafSAart Bik }; 707b36aaeafSAart Bik 708870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern { 709cd5dab8aSAart Bik public: 710cd5dab8aSAart Bik explicit VectorExtractElementOpConversion(MLIRContext *context, 711cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 712870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), 713870c1fd4SAlex Zinenko context, typeConverter) {} 714cd5dab8aSAart Bik 7153145427dSRiver Riddle LogicalResult 716e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 717cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7182d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 719cd5dab8aSAart Bik auto extractEltOp = cast<vector::ExtractElementOp>(op); 720cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 7210f04384dSAlex Zinenko auto llvmType = typeConverter.convertType(vectorType.getElementType()); 722cd5dab8aSAart Bik 723cd5dab8aSAart Bik // Bail if result type cannot be lowered. 724cd5dab8aSAart Bik if (!llvmType) 7253145427dSRiver Riddle return failure(); 726cd5dab8aSAart Bik 727cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 728cd5dab8aSAart Bik op, llvmType, adaptor.vector(), adaptor.position()); 7293145427dSRiver Riddle return success(); 730cd5dab8aSAart Bik } 731cd5dab8aSAart Bik }; 732cd5dab8aSAart Bik 733870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern { 7345c0c51a9SNicolas Vasilache public: 7359826fe5cSAart Bik explicit VectorExtractOpConversion(MLIRContext *context, 7365c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 737870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, 7385c0c51a9SNicolas Vasilache typeConverter) {} 7395c0c51a9SNicolas Vasilache 7403145427dSRiver Riddle LogicalResult 741e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 7425c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7435c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 7442d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 745d37f2725SAart Bik auto extractOp = cast<vector::ExtractOp>(op); 7469826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7472bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 7480f04384dSAlex Zinenko auto llvmResultType = typeConverter.convertType(resultType); 7495c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7509826fe5cSAart Bik 7519826fe5cSAart Bik // Bail if result type cannot be lowered. 7529826fe5cSAart Bik if (!llvmResultType) 7533145427dSRiver Riddle return failure(); 7549826fe5cSAart Bik 7555c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7565c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 757e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7585c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 7595c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7603145427dSRiver Riddle return success(); 7615c0c51a9SNicolas Vasilache } 7625c0c51a9SNicolas Vasilache 7639826fe5cSAart Bik // Potential extraction of 1-D vector from array. 7645c0c51a9SNicolas Vasilache auto *context = op->getContext(); 765e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7665c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7675c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7689826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7695c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7705c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7715c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 7720f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 7735c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7745c0c51a9SNicolas Vasilache } 7755c0c51a9SNicolas Vasilache 7765c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7775c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7785446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 7791d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7805c0c51a9SNicolas Vasilache extracted = 7815c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 7825c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7835c0c51a9SNicolas Vasilache 7843145427dSRiver Riddle return success(); 7855c0c51a9SNicolas Vasilache } 7865c0c51a9SNicolas Vasilache }; 7875c0c51a9SNicolas Vasilache 788681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 789681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 790681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 791681f929fSNicolas Vasilache /// 792681f929fSNicolas Vasilache /// Example: 793681f929fSNicolas Vasilache /// ``` 794681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 795681f929fSNicolas Vasilache /// ``` 796681f929fSNicolas Vasilache /// is converted to: 797681f929fSNicolas Vasilache /// ``` 7983bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 799681f929fSNicolas Vasilache /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 800681f929fSNicolas Vasilache /// -> !llvm<"<8 x float>"> 801681f929fSNicolas Vasilache /// ``` 802870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern { 803681f929fSNicolas Vasilache public: 804681f929fSNicolas Vasilache explicit VectorFMAOp1DConversion(MLIRContext *context, 805681f929fSNicolas Vasilache LLVMTypeConverter &typeConverter) 806870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, 807681f929fSNicolas Vasilache typeConverter) {} 808681f929fSNicolas Vasilache 8093145427dSRiver Riddle LogicalResult 810681f929fSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 8122d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 813681f929fSNicolas Vasilache vector::FMAOp fmaOp = cast<vector::FMAOp>(op); 814681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 815681f929fSNicolas Vasilache if (vType.getRank() != 1) 8163145427dSRiver Riddle return failure(); 8173bffe602SBenjamin Kramer rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(), 8183bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 8193145427dSRiver Riddle return success(); 820681f929fSNicolas Vasilache } 821681f929fSNicolas Vasilache }; 822681f929fSNicolas Vasilache 823870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern { 824cd5dab8aSAart Bik public: 825cd5dab8aSAart Bik explicit VectorInsertElementOpConversion(MLIRContext *context, 826cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 827870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), 828870c1fd4SAlex Zinenko context, typeConverter) {} 829cd5dab8aSAart Bik 8303145427dSRiver Riddle LogicalResult 831e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 832cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 8332d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 834cd5dab8aSAart Bik auto insertEltOp = cast<vector::InsertElementOp>(op); 835cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 8360f04384dSAlex Zinenko auto llvmType = typeConverter.convertType(vectorType); 837cd5dab8aSAart Bik 838cd5dab8aSAart Bik // Bail if result type cannot be lowered. 839cd5dab8aSAart Bik if (!llvmType) 8403145427dSRiver Riddle return failure(); 841cd5dab8aSAart Bik 842cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 843cd5dab8aSAart Bik op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 8443145427dSRiver Riddle return success(); 845cd5dab8aSAart Bik } 846cd5dab8aSAart Bik }; 847cd5dab8aSAart Bik 848870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern { 8499826fe5cSAart Bik public: 8509826fe5cSAart Bik explicit VectorInsertOpConversion(MLIRContext *context, 8519826fe5cSAart Bik LLVMTypeConverter &typeConverter) 852870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, 8539826fe5cSAart Bik typeConverter) {} 8549826fe5cSAart Bik 8553145427dSRiver Riddle LogicalResult 856e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 8579826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 8589826fe5cSAart Bik auto loc = op->getLoc(); 8592d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8609826fe5cSAart Bik auto insertOp = cast<vector::InsertOp>(op); 8619826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8629826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 8630f04384dSAlex Zinenko auto llvmResultType = typeConverter.convertType(destVectorType); 8649826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8659826fe5cSAart Bik 8669826fe5cSAart Bik // Bail if result type cannot be lowered. 8679826fe5cSAart Bik if (!llvmResultType) 8683145427dSRiver Riddle return failure(); 8699826fe5cSAart Bik 8709826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8719826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 872e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8739826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8749826fe5cSAart Bik positionArrayAttr); 8759826fe5cSAart Bik rewriter.replaceOp(op, inserted); 8763145427dSRiver Riddle return success(); 8779826fe5cSAart Bik } 8789826fe5cSAart Bik 8799826fe5cSAart Bik // Potential extraction of 1-D vector from array. 8809826fe5cSAart Bik auto *context = op->getContext(); 881e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8829826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8839826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8849826fe5cSAart Bik auto oneDVectorType = destVectorType; 8859826fe5cSAart Bik if (positionAttrs.size() > 1) { 8869826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8879826fe5cSAart Bik auto nMinusOnePositionAttrs = 8889826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8899826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 8900f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 8919826fe5cSAart Bik nMinusOnePositionAttrs); 8929826fe5cSAart Bik } 8939826fe5cSAart Bik 8949826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 8955446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 8961d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 897e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 8980f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 8990f04384dSAlex Zinenko adaptor.source(), constant); 9009826fe5cSAart Bik 9019826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 9029826fe5cSAart Bik if (positionAttrs.size() > 1) { 9039826fe5cSAart Bik auto nMinusOnePositionAttrs = 9049826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 9059826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 9069826fe5cSAart Bik adaptor.dest(), inserted, 9079826fe5cSAart Bik nMinusOnePositionAttrs); 9089826fe5cSAart Bik } 9099826fe5cSAart Bik 9109826fe5cSAart Bik rewriter.replaceOp(op, inserted); 9113145427dSRiver Riddle return success(); 9129826fe5cSAart Bik } 9139826fe5cSAart Bik }; 9149826fe5cSAart Bik 915681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 916681f929fSNicolas Vasilache /// 917681f929fSNicolas Vasilache /// Example: 918681f929fSNicolas Vasilache /// ``` 919681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 920681f929fSNicolas Vasilache /// ``` 921681f929fSNicolas Vasilache /// is rewritten into: 922681f929fSNicolas Vasilache /// ``` 923681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 924681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 925681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 926681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 927681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 928681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 929681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 930681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 931681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 932681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 933681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 934681f929fSNicolas Vasilache /// // %r3 holds the final value. 935681f929fSNicolas Vasilache /// ``` 936681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 937681f929fSNicolas Vasilache public: 938681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 939681f929fSNicolas Vasilache 9403145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 941681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 942681f929fSNicolas Vasilache auto vType = op.getVectorType(); 943681f929fSNicolas Vasilache if (vType.getRank() < 2) 9443145427dSRiver Riddle return failure(); 945681f929fSNicolas Vasilache 946681f929fSNicolas Vasilache auto loc = op.getLoc(); 947681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 948681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 949681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 950681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 951681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 952681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 953681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 954681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 955681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 956681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 957681f929fSNicolas Vasilache } 958681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9593145427dSRiver Riddle return success(); 960681f929fSNicolas Vasilache } 961681f929fSNicolas Vasilache }; 962681f929fSNicolas Vasilache 9632d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9642d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9652d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9662d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9672d515e49SNicolas Vasilache // rank. 9682d515e49SNicolas Vasilache // 9692d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9702d515e49SNicolas Vasilache // have different ranks. In this case: 9712d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9722d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9732d515e49SNicolas Vasilache // destination subvector 9742d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9752d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9762d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9772d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9782d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9792d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9802d515e49SNicolas Vasilache public: 9812d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9822d515e49SNicolas Vasilache 9833145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9842d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9852d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9862d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9872d515e49SNicolas Vasilache 9882d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9893145427dSRiver Riddle return failure(); 9902d515e49SNicolas Vasilache 9912d515e49SNicolas Vasilache auto loc = op.getLoc(); 9922d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9932d515e49SNicolas Vasilache assert(rankDiff >= 0); 9942d515e49SNicolas Vasilache if (rankDiff == 0) 9953145427dSRiver Riddle return failure(); 9962d515e49SNicolas Vasilache 9972d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9982d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 9992d515e49SNicolas Vasilache // on it. 10002d515e49SNicolas Vasilache Value extracted = 10012d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 10022d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 10032d515e49SNicolas Vasilache /*dropFront=*/rankRest)); 10042d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 10052d515e49SNicolas Vasilache // ranks. 10062d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 10072d515e49SNicolas Vasilache loc, op.source(), extracted, 10082d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 1009c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 10102d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 10112d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 10122d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 10132d515e49SNicolas Vasilache /*dropFront=*/rankRest)); 10143145427dSRiver Riddle return success(); 10152d515e49SNicolas Vasilache } 10162d515e49SNicolas Vasilache }; 10172d515e49SNicolas Vasilache 10182d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 10192d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 10202d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 10212d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 10222d515e49SNicolas Vasilache // destination subvector 10232d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 10242d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 10252d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 10262d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 10272d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 10282d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 10292d515e49SNicolas Vasilache public: 1030b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 1031b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 1032b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 1033b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 1034b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1035b99bd771SRiver Riddle } 10362d515e49SNicolas Vasilache 10373145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 10382d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10392d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10402d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10412d515e49SNicolas Vasilache 10422d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10433145427dSRiver Riddle return failure(); 10442d515e49SNicolas Vasilache 10452d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10462d515e49SNicolas Vasilache assert(rankDiff >= 0); 10472d515e49SNicolas Vasilache if (rankDiff != 0) 10483145427dSRiver Riddle return failure(); 10492d515e49SNicolas Vasilache 10502d515e49SNicolas Vasilache if (srcType == dstType) { 10512d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10523145427dSRiver Riddle return success(); 10532d515e49SNicolas Vasilache } 10542d515e49SNicolas Vasilache 10552d515e49SNicolas Vasilache int64_t offset = 10562d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10572d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10582d515e49SNicolas Vasilache int64_t stride = 10592d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10602d515e49SNicolas Vasilache 10612d515e49SNicolas Vasilache auto loc = op.getLoc(); 10622d515e49SNicolas Vasilache Value res = op.dest(); 10632d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10642d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10652d515e49SNicolas Vasilache off += stride, ++idx) { 10662d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10672d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10682d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10692d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10702d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10712d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10722d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10732d515e49SNicolas Vasilache // smaller rank. 1074bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10752d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10762d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10772d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10782d515e49SNicolas Vasilache } 10792d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10802d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10812d515e49SNicolas Vasilache } 10822d515e49SNicolas Vasilache 10832d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10843145427dSRiver Riddle return success(); 10852d515e49SNicolas Vasilache } 10862d515e49SNicolas Vasilache }; 10872d515e49SNicolas Vasilache 108830e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 108930e6033bSNicolas Vasilache /// static layout. 109030e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 109130e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10922bf491c7SBenjamin Kramer int64_t offset; 109330e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 109430e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 109530e6033bSNicolas Vasilache return None; 109630e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 109730e6033bSNicolas Vasilache return None; 109830e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 109930e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 110030e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 110130e6033bSNicolas Vasilache return strides; 110230e6033bSNicolas Vasilache 110330e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 110430e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 110530e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 110630e6033bSNicolas Vasilache // layout. 11072bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 11082bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 110930e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 111030e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 111130e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 111230e6033bSNicolas Vasilache return None; 111330e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 111430e6033bSNicolas Vasilache return None; 11152bf491c7SBenjamin Kramer } 111630e6033bSNicolas Vasilache return strides; 11172bf491c7SBenjamin Kramer } 11182bf491c7SBenjamin Kramer 1119870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 11205c0c51a9SNicolas Vasilache public: 11215c0c51a9SNicolas Vasilache explicit VectorTypeCastOpConversion(MLIRContext *context, 11225c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 1123870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 11245c0c51a9SNicolas Vasilache typeConverter) {} 11255c0c51a9SNicolas Vasilache 11263145427dSRiver Riddle LogicalResult 1127e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 11285c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11295c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 11305c0c51a9SNicolas Vasilache vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 11315c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 11322bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 11335c0c51a9SNicolas Vasilache MemRefType targetMemRefType = 11342bdf33ccSRiver Riddle castOp.getResult().getType().cast<MemRefType>(); 11355c0c51a9SNicolas Vasilache 11365c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 11375c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 11385c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 11393145427dSRiver Riddle return failure(); 11405c0c51a9SNicolas Vasilache 11415c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 11422bdf33ccSRiver Riddle operands[0].getType().dyn_cast<LLVM::LLVMType>(); 11435c0c51a9SNicolas Vasilache if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 11443145427dSRiver Riddle return failure(); 11455c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11465c0c51a9SNicolas Vasilache 11470f04384dSAlex Zinenko auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 11485c0c51a9SNicolas Vasilache .dyn_cast_or_null<LLVM::LLVMType>(); 11495c0c51a9SNicolas Vasilache if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 11503145427dSRiver Riddle return failure(); 11515c0c51a9SNicolas Vasilache 115230e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 115330e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 115430e6033bSNicolas Vasilache if (!sourceStrides) 115530e6033bSNicolas Vasilache return failure(); 115630e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 115730e6033bSNicolas Vasilache if (!targetStrides) 115830e6033bSNicolas Vasilache return failure(); 115930e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 116030e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 116130e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 116230e6033bSNicolas Vasilache })) 11633145427dSRiver Riddle return failure(); 11645c0c51a9SNicolas Vasilache 11655446ec85SAlex Zinenko auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 11665c0c51a9SNicolas Vasilache 11675c0c51a9SNicolas Vasilache // Create descriptor. 11685c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11693a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11705c0c51a9SNicolas Vasilache // Set allocated ptr. 1171e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11725c0c51a9SNicolas Vasilache allocated = 11735c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11745c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11755c0c51a9SNicolas Vasilache // Set aligned ptr. 1176e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11775c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11785c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11795c0c51a9SNicolas Vasilache // Fill offset 0. 11805c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11815c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11825c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11835c0c51a9SNicolas Vasilache 11845c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11855c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11865c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11875c0c51a9SNicolas Vasilache auto sizeAttr = 11885c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11895c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11905c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 119130e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 119230e6033bSNicolas Vasilache (*targetStrides)[index]); 11935c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11945c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11955c0c51a9SNicolas Vasilache } 11965c0c51a9SNicolas Vasilache 11975c0c51a9SNicolas Vasilache rewriter.replaceOp(op, {desc}); 11983145427dSRiver Riddle return success(); 11995c0c51a9SNicolas Vasilache } 12005c0c51a9SNicolas Vasilache }; 12015c0c51a9SNicolas Vasilache 12028345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 12038345b86dSNicolas Vasilache /// sequence of: 1204060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1205060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1206060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1207060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1208060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 12098345b86dSNicolas Vasilache template <typename ConcreteOp> 12108345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern { 12118345b86dSNicolas Vasilache public: 12128345b86dSNicolas Vasilache explicit VectorTransferConversion(MLIRContext *context, 1213060c9dd1Saartbik LLVMTypeConverter &typeConv, 1214060c9dd1Saartbik bool enableIndexOpt) 1215060c9dd1Saartbik : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv), 1216060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 12178345b86dSNicolas Vasilache 12188345b86dSNicolas Vasilache LogicalResult 12198345b86dSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 12208345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 12218345b86dSNicolas Vasilache auto xferOp = cast<ConcreteOp>(op); 12228345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1223b2c79c50SNicolas Vasilache 1224b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1225b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 12268345b86dSNicolas Vasilache return failure(); 12275f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 12285f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 12295f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 12305f9e0466SNicolas Vasilache op->getContext())) 12318345b86dSNicolas Vasilache return failure(); 12322bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 123330e6033bSNicolas Vasilache auto strides = computeContiguousStrides(xferOp.getMemRefType()); 123430e6033bSNicolas Vasilache if (!strides) 12352bf491c7SBenjamin Kramer return failure(); 12368345b86dSNicolas Vasilache 12378345b86dSNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 12388345b86dSNicolas Vasilache 12398345b86dSNicolas Vasilache Location loc = op->getLoc(); 12408345b86dSNicolas Vasilache MemRefType memRefType = xferOp.getMemRefType(); 12418345b86dSNicolas Vasilache 124268330ee0SThomas Raoux if (auto memrefVectorElementType = 124368330ee0SThomas Raoux memRefType.getElementType().dyn_cast<VectorType>()) { 124468330ee0SThomas Raoux // Memref has vector element type. 124568330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 124668330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 124768330ee0SThomas Raoux return failure(); 12480de60b55SThomas Raoux #ifndef NDEBUG 124968330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 125068330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 125168330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 125268330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 125368330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 125468330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 125568330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 125668330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 125768330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 125868330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 125968330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 126068330ee0SThomas Raoux "result shape."); 12610de60b55SThomas Raoux #endif // ifndef NDEBUG 126268330ee0SThomas Raoux } 126368330ee0SThomas Raoux 12648345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1265be16075bSWen-Heng (Jack) Chung // The vector pointer would always be on address space 0, therefore 1266be16075bSWen-Heng (Jack) Chung // addrspacecast shall be used when source/dst memrefs are not on 1267be16075bSWen-Heng (Jack) Chung // address space 0. 12688345b86dSNicolas Vasilache // TODO: support alignment when possible. 12698b97e17dSChristian Sigg Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), 1270d3a98076SAlex Zinenko adaptor.indices(), rewriter); 12718345b86dSNicolas Vasilache auto vecTy = 12728345b86dSNicolas Vasilache toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); 1273be16075bSWen-Heng (Jack) Chung Value vectorDataPtr; 1274be16075bSWen-Heng (Jack) Chung if (memRefType.getMemorySpace() == 0) 1275be16075bSWen-Heng (Jack) Chung vectorDataPtr = 12768345b86dSNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); 1277be16075bSWen-Heng (Jack) Chung else 1278be16075bSWen-Heng (Jack) Chung vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1279be16075bSWen-Heng (Jack) Chung loc, vecTy.getPointerTo(), dataPtr); 12808345b86dSNicolas Vasilache 12811870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 12821870e787SNicolas Vasilache return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc, 12831870e787SNicolas Vasilache xferOp, operands, vectorDataPtr); 12841870e787SNicolas Vasilache 12858345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12868345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12878345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12888345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1289060c9dd1Saartbik // 1290060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1291060c9dd1Saartbik // dimensions here. 1292060c9dd1Saartbik unsigned vecWidth = vecTy.getVectorNumElements(); 1293060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12940c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 1295b2c79c50SNicolas Vasilache Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex); 1296060c9dd1Saartbik Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations, 1297060c9dd1Saartbik vecWidth, dim, &off); 12988345b86dSNicolas Vasilache 12998345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 13001870e787SNicolas Vasilache return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp, 1301a99f62c4SAlex Zinenko operands, vectorDataPtr, mask); 13028345b86dSNicolas Vasilache } 1303060c9dd1Saartbik 1304060c9dd1Saartbik private: 1305060c9dd1Saartbik const bool enableIndexOptimizations; 13068345b86dSNicolas Vasilache }; 13078345b86dSNicolas Vasilache 1308870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern { 1309d9b500d3SAart Bik public: 1310d9b500d3SAart Bik explicit VectorPrintOpConversion(MLIRContext *context, 1311d9b500d3SAart Bik LLVMTypeConverter &typeConverter) 1312870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 1313d9b500d3SAart Bik typeConverter) {} 1314d9b500d3SAart Bik 1315d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1316d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1317d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1318d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1319d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1320d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1321d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1322d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1323d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1324d9b500d3SAart Bik // 13259db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1326d9b500d3SAart Bik // 13273145427dSRiver Riddle LogicalResult 1328e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1329d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 1330d9b500d3SAart Bik auto printOp = cast<vector::PrintOp>(op); 13312d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1332d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1333d9b500d3SAart Bik 13340f04384dSAlex Zinenko if (typeConverter.convertType(printType) == nullptr) 13353145427dSRiver Riddle return failure(); 1336d9b500d3SAart Bik 1337b8880f5fSAart Bik // Make sure element type has runtime support. 1338b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1339d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1340d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1341d9b500d3SAart Bik Operation *printer; 1342b8880f5fSAart Bik if (eltType.isF32()) { 1343d9b500d3SAart Bik printer = getPrintFloat(op); 1344b8880f5fSAart Bik } else if (eltType.isF64()) { 1345d9b500d3SAart Bik printer = getPrintDouble(op); 134654759cefSAart Bik } else if (eltType.isIndex()) { 134754759cefSAart Bik printer = getPrintU64(op); 1348b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1349b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1350b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1351b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1352b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1353b8880f5fSAart Bik if (intTy.isUnsigned()) { 135454759cefSAart Bik if (width <= 64) { 1355b8880f5fSAart Bik if (width < 64) 1356b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1357b8880f5fSAart Bik printer = getPrintU64(op); 1358b8880f5fSAart Bik } else { 13593145427dSRiver Riddle return failure(); 1360b8880f5fSAart Bik } 1361b8880f5fSAart Bik } else { 1362b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 136354759cefSAart Bik if (width <= 64) { 1364b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1365b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1366b8880f5fSAart Bik if (width == 1) 136754759cefSAart Bik conversion = PrintConversion::ZeroExt64; 136854759cefSAart Bik else if (width < 64) 1369b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1370b8880f5fSAart Bik printer = getPrintI64(op); 1371b8880f5fSAart Bik } else { 1372b8880f5fSAart Bik return failure(); 1373b8880f5fSAart Bik } 1374b8880f5fSAart Bik } 1375b8880f5fSAart Bik } else { 1376b8880f5fSAart Bik return failure(); 1377b8880f5fSAart Bik } 1378d9b500d3SAart Bik 1379d9b500d3SAart Bik // Unroll vector into elementary print calls. 1380b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1381b8880f5fSAart Bik emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, 1382b8880f5fSAart Bik conversion); 1383d9b500d3SAart Bik emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 1384d9b500d3SAart Bik rewriter.eraseOp(op); 13853145427dSRiver Riddle return success(); 1386d9b500d3SAart Bik } 1387d9b500d3SAart Bik 1388d9b500d3SAart Bik private: 1389b8880f5fSAart Bik enum class PrintConversion { 139030e6033bSNicolas Vasilache // clang-format off 1391b8880f5fSAart Bik None, 1392b8880f5fSAart Bik ZeroExt64, 1393b8880f5fSAart Bik SignExt64 139430e6033bSNicolas Vasilache // clang-format on 1395b8880f5fSAart Bik }; 1396b8880f5fSAart Bik 1397d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1398e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1399b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1400d9b500d3SAart Bik Location loc = op->getLoc(); 1401d9b500d3SAart Bik if (rank == 0) { 1402b8880f5fSAart Bik switch (conversion) { 1403b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1404b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 1405b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1406b8880f5fSAart Bik break; 1407b8880f5fSAart Bik case PrintConversion::SignExt64: 1408b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 1409b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1410b8880f5fSAart Bik break; 1411b8880f5fSAart Bik case PrintConversion::None: 1412b8880f5fSAart Bik break; 1413c9eeeb38Saartbik } 1414d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1415d9b500d3SAart Bik return; 1416d9b500d3SAart Bik } 1417d9b500d3SAart Bik 1418d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1419d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1420d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1421d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1422d9b500d3SAart Bik auto reducedType = 1423d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 14240f04384dSAlex Zinenko auto llvmType = typeConverter.convertType( 1425d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1426e62a6956SRiver Riddle Value nestedVal = 14270f04384dSAlex Zinenko extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); 1428b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1429b8880f5fSAart Bik conversion); 1430d9b500d3SAart Bik if (d != dim - 1) 1431d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1432d9b500d3SAart Bik } 1433d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1434d9b500d3SAart Bik } 1435d9b500d3SAart Bik 1436d9b500d3SAart Bik // Helper to emit a call. 1437d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1438d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 143908e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1440d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1441d9b500d3SAart Bik } 1442d9b500d3SAart Bik 1443d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 14445446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 14455446ec85SAlex Zinenko ArrayRef<LLVM::LLVMType> params) { 1446d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1447d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1448d9b500d3SAart Bik if (func) 1449d9b500d3SAart Bik return func; 1450d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1451d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1452d9b500d3SAart Bik op->getLoc(), name, 14535446ec85SAlex Zinenko LLVM::LLVMType::getFunctionTy( 14545446ec85SAlex Zinenko LLVM::LLVMType::getVoidTy(op->getContext()), params, 14555446ec85SAlex Zinenko /*isVarArg=*/false)); 1456d9b500d3SAart Bik } 1457d9b500d3SAart Bik 1458d9b500d3SAart Bik // Helpers for method names. 1459e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 146054759cefSAart Bik return getPrint(op, "printI64", 14615446ec85SAlex Zinenko LLVM::LLVMType::getInt64Ty(op->getContext())); 1462e52414b1Saartbik } 1463b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 1464b8880f5fSAart Bik return getPrint(op, "printU64", 1465b8880f5fSAart Bik LLVM::LLVMType::getInt64Ty(op->getContext())); 1466b8880f5fSAart Bik } 1467d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 146854759cefSAart Bik return getPrint(op, "printF32", 14695446ec85SAlex Zinenko LLVM::LLVMType::getFloatTy(op->getContext())); 1470d9b500d3SAart Bik } 1471d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 147254759cefSAart Bik return getPrint(op, "printF64", 14735446ec85SAlex Zinenko LLVM::LLVMType::getDoubleTy(op->getContext())); 1474d9b500d3SAart Bik } 1475d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 147654759cefSAart Bik return getPrint(op, "printOpen", {}); 1477d9b500d3SAart Bik } 1478d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 147954759cefSAart Bik return getPrint(op, "printClose", {}); 1480d9b500d3SAart Bik } 1481d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 148254759cefSAart Bik return getPrint(op, "printComma", {}); 1483d9b500d3SAart Bik } 1484d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 148554759cefSAart Bik return getPrint(op, "printNewline", {}); 1486d9b500d3SAart Bik } 1487d9b500d3SAart Bik }; 1488d9b500d3SAart Bik 1489334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1490c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1491c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1492c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1493334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 149465678d93SNicolas Vasilache public: 1495b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1496b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1497b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1498b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1499b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1500b99bd771SRiver Riddle } 150165678d93SNicolas Vasilache 1502334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 150365678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 150465678d93SNicolas Vasilache auto dstType = op.getResult().getType().cast<VectorType>(); 150565678d93SNicolas Vasilache 150665678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 150765678d93SNicolas Vasilache 150865678d93SNicolas Vasilache int64_t offset = 150965678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 151065678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 151165678d93SNicolas Vasilache int64_t stride = 151265678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 151365678d93SNicolas Vasilache 151465678d93SNicolas Vasilache auto loc = op.getLoc(); 151565678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 151635b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1517c3c95b9cSaartbik 1518c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1519c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1520c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1521c3c95b9cSaartbik offsets.reserve(size); 1522c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1523c3c95b9cSaartbik off += stride) 1524c3c95b9cSaartbik offsets.push_back(off); 1525c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1526c3c95b9cSaartbik op.vector(), 1527c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1528c3c95b9cSaartbik return success(); 1529c3c95b9cSaartbik } 1530c3c95b9cSaartbik 1531c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 153265678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 153365678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 153465678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 153565678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 153665678d93SNicolas Vasilache off += stride, ++idx) { 1537c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1538c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1539c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 154065678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 154165678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 154265678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 154365678d93SNicolas Vasilache } 1544c3c95b9cSaartbik rewriter.replaceOp(op, res); 15453145427dSRiver Riddle return success(); 154665678d93SNicolas Vasilache } 154765678d93SNicolas Vasilache }; 154865678d93SNicolas Vasilache 1549df186507SBenjamin Kramer } // namespace 1550df186507SBenjamin Kramer 15515c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 15525c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1553ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1554060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 155565678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 15568345b86dSNicolas Vasilache // clang-format off 1557681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1558681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 15592d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1560c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1561ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1562ceb1b327Saartbik ctx, converter, reassociateFPReductions); 1563060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1564060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1565060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1566060c9dd1Saartbik ctx, converter, enableIndexOptimizations); 15678345b86dSNicolas Vasilache patterns 1568ceb1b327Saartbik .insert<VectorShuffleOpConversion, 15698345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15708345b86dSNicolas Vasilache VectorExtractOpConversion, 15718345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15728345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15738345b86dSNicolas Vasilache VectorInsertOpConversion, 15748345b86dSNicolas Vasilache VectorPrintOpConversion, 157519dbb230Saartbik VectorTypeCastOpConversion, 157639379916Saartbik VectorMaskedLoadOpConversion, 157739379916Saartbik VectorMaskedStoreOpConversion, 157819dbb230Saartbik VectorGatherOpConversion, 1579e8dcf5f8Saartbik VectorScatterOpConversion, 1580e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1581e8dcf5f8Saartbik VectorCompressStoreOpConversion>(ctx, converter); 15828345b86dSNicolas Vasilache // clang-format on 15835c0c51a9SNicolas Vasilache } 15845c0c51a9SNicolas Vasilache 158563b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 158663b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 158763b683a8SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 158863b683a8SNicolas Vasilache patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1589c295a65dSaartbik patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter); 159063b683a8SNicolas Vasilache } 1591