15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 195c0c51a9SNicolas Vasilache 205c0c51a9SNicolas Vasilache using namespace mlir; 2165678d93SNicolas Vasilache using namespace mlir::vector; 225c0c51a9SNicolas Vasilache 239826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 259826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 269826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 279826fe5cSAart Bik } 289826fe5cSAart Bik 299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 319826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 329826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 339826fe5cSAart Bik } 349826fe5cSAart Bik 351c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 370f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 380f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 390f04384dSAlex Zinenko int64_t pos) { 401c81adf3SAart Bik if (rank == 1) { 411c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 421c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 430f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 441c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 451c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 461c81adf3SAart Bik constant); 471c81adf3SAart Bik } 481c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 491c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 501c81adf3SAart Bik } 511c81adf3SAart Bik 522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 542d515e49SNicolas Vasilache Value into, int64_t offset) { 552d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 562d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 572d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 582d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 592d515e49SNicolas Vasilache loc, vectorType, from, into, 602d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 612d515e49SNicolas Vasilache } 622d515e49SNicolas Vasilache 631c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 650f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 660f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 671c81adf3SAart Bik if (rank == 1) { 681c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 691c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 700f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 711c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 721c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 731c81adf3SAart Bik constant); 741c81adf3SAart Bik } 751c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 761c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 771c81adf3SAart Bik } 781c81adf3SAart Bik 792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 812d515e49SNicolas Vasilache int64_t offset) { 822d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 832d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 842d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 852d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 862d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 872d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 882d515e49SNicolas Vasilache } 892d515e49SNicolas Vasilache 902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 932d515e49SNicolas Vasilache unsigned dropFront = 0, 942d515e49SNicolas Vasilache unsigned dropBack = 0) { 952d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 962d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 972d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 982d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 992d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1002d515e49SNicolas Vasilache it != eit; ++it) 1012d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1022d515e49SNicolas Vasilache return res; 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 106060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107060c9dd1Saartbik // 108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 110060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 111060c9dd1Saartbik // is very conservative on the boundary conditions. 112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 114060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 115060c9dd1Saartbik auto loc = op->getLoc(); 116060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 117060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 118060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 119060c9dd1Saartbik Value indices; 120060c9dd1Saartbik Type idxType; 121060c9dd1Saartbik if (enableIndexOptimizations) { 1220c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1230c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1240c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125060c9dd1Saartbik idxType = rewriter.getI32Type(); 126060c9dd1Saartbik } else { 1270c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1280c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1290c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130060c9dd1Saartbik idxType = rewriter.getI64Type(); 131060c9dd1Saartbik } 132060c9dd1Saartbik // Add in an offset if requested. 133060c9dd1Saartbik if (off) { 134060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 137060c9dd1Saartbik } 138060c9dd1Saartbik // Construct the vector comparison. 139060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142060c9dd1Saartbik } 143060c9dd1Saartbik 14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref. 14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 14626c8f908SThomas Raoux MemRefType memrefType, unsigned &align) { 14726c8f908SThomas Raoux Type elementTy = typeConverter.convertType(memrefType.getElementType()); 1485f9e0466SNicolas Vasilache if (!elementTy) 1495f9e0466SNicolas Vasilache return failure(); 1505f9e0466SNicolas Vasilache 151b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 152b2ab375dSAlex Zinenko // stop depending on translation. 15387a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 15487a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 155c69c9e0fSAlex Zinenko .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 1565f9e0466SNicolas Vasilache return success(); 1575f9e0466SNicolas Vasilache } 1585f9e0466SNicolas Vasilache 159e8dcf5f8Saartbik // Helper that returns the base address of a memref. 160b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 161e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 16219dbb230Saartbik // Inspect stride and offset structure. 16319dbb230Saartbik // 16419dbb230Saartbik // TODO: flat memory only for now, generalize 16519dbb230Saartbik // 16619dbb230Saartbik int64_t offset; 16719dbb230Saartbik SmallVector<int64_t, 4> strides; 16819dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 16919dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 17019dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 17119dbb230Saartbik return failure(); 172e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 173e8dcf5f8Saartbik return success(); 174e8dcf5f8Saartbik } 17519dbb230Saartbik 176e8dcf5f8Saartbik // Helper that returns a pointer given a memref base. 177b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 178b98e25b6SBenjamin Kramer Location loc, Value memref, 179b98e25b6SBenjamin Kramer MemRefType memRefType, Value &ptr) { 180e8dcf5f8Saartbik Value base; 181e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 182e8dcf5f8Saartbik return failure(); 1833a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 184e8dcf5f8Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 185e8dcf5f8Saartbik return success(); 186e8dcf5f8Saartbik } 187e8dcf5f8Saartbik 18839379916Saartbik // Helper that returns a bit-casted pointer given a memref base. 189b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 190b98e25b6SBenjamin Kramer Location loc, Value memref, 191b98e25b6SBenjamin Kramer MemRefType memRefType, Type type, Value &ptr) { 19239379916Saartbik Value base; 19339379916Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 19439379916Saartbik return failure(); 195c69c9e0fSAlex Zinenko auto pType = LLVM::LLVMPointerType::get(type); 19639379916Saartbik base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 19739379916Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 19839379916Saartbik return success(); 19939379916Saartbik } 20039379916Saartbik 201e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index 202e8dcf5f8Saartbik // vector. 203b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 204b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 205b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 206b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 207e8dcf5f8Saartbik Value base; 208e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 209e8dcf5f8Saartbik return failure(); 2103a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 2117ed9cfc7SAlex Zinenko auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0)); 2121485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 21319dbb230Saartbik return success(); 21419dbb230Saartbik } 21519dbb230Saartbik 2165f9e0466SNicolas Vasilache static LogicalResult 2175f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2185f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2195f9e0466SNicolas Vasilache TransferReadOp xferOp, 2205f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 221affbc0cdSNicolas Vasilache unsigned align; 22226c8f908SThomas Raoux if (failed(getMemRefAlignment( 22326c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 224affbc0cdSNicolas Vasilache return failure(); 225affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2265f9e0466SNicolas Vasilache return success(); 2275f9e0466SNicolas Vasilache } 2285f9e0466SNicolas Vasilache 2295f9e0466SNicolas Vasilache static LogicalResult 2305f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2315f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2325f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2335f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2345f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2355f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2365f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2375f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2385f9e0466SNicolas Vasilache 2395f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2405f9e0466SNicolas Vasilache if (!vecTy) 2415f9e0466SNicolas Vasilache return failure(); 2425f9e0466SNicolas Vasilache 2435f9e0466SNicolas Vasilache unsigned align; 24426c8f908SThomas Raoux if (failed(getMemRefAlignment( 24526c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2465f9e0466SNicolas Vasilache return failure(); 2475f9e0466SNicolas Vasilache 2485f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2495f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2505f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2515f9e0466SNicolas Vasilache return success(); 2525f9e0466SNicolas Vasilache } 2535f9e0466SNicolas Vasilache 2545f9e0466SNicolas Vasilache static LogicalResult 2555f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2565f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2575f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2585f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 259affbc0cdSNicolas Vasilache unsigned align; 26026c8f908SThomas Raoux if (failed(getMemRefAlignment( 26126c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 262affbc0cdSNicolas Vasilache return failure(); 2632d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 264affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 265affbc0cdSNicolas Vasilache align); 2665f9e0466SNicolas Vasilache return success(); 2675f9e0466SNicolas Vasilache } 2685f9e0466SNicolas Vasilache 2695f9e0466SNicolas Vasilache static LogicalResult 2705f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2715f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2725f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2735f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2745f9e0466SNicolas Vasilache unsigned align; 27526c8f908SThomas Raoux if (failed(getMemRefAlignment( 27626c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2775f9e0466SNicolas Vasilache return failure(); 2785f9e0466SNicolas Vasilache 2792d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2805f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2815f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2825f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2835f9e0466SNicolas Vasilache return success(); 2845f9e0466SNicolas Vasilache } 2855f9e0466SNicolas Vasilache 2862d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2872d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2882d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2895f9e0466SNicolas Vasilache } 2905f9e0466SNicolas Vasilache 2912d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2922d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2932d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2945f9e0466SNicolas Vasilache } 2955f9e0466SNicolas Vasilache 29690c01357SBenjamin Kramer namespace { 297e83b7b99Saartbik 29863b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 29963b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 300563879b6SRahul Joshi class VectorMatmulOpConversion 301563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MatmulOp> { 30263b683a8SNicolas Vasilache public: 303563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 30463b683a8SNicolas Vasilache 3053145427dSRiver Riddle LogicalResult 306563879b6SRahul Joshi matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands, 30763b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 3082d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 30963b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 310563879b6SRahul Joshi matmulOp, typeConverter->convertType(matmulOp.res().getType()), 311563879b6SRahul Joshi adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 312563879b6SRahul Joshi matmulOp.lhs_columns(), matmulOp.rhs_columns()); 3133145427dSRiver Riddle return success(); 31463b683a8SNicolas Vasilache } 31563b683a8SNicolas Vasilache }; 31663b683a8SNicolas Vasilache 317c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 318c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 319563879b6SRahul Joshi class VectorFlatTransposeOpConversion 320563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 321c295a65dSaartbik public: 322563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 323c295a65dSaartbik 324c295a65dSaartbik LogicalResult 325563879b6SRahul Joshi matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands, 326c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 3272d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 328c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 329dcec2ca5SChristian Sigg transOp, typeConverter->convertType(transOp.res().getType()), 330c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 331c295a65dSaartbik return success(); 332c295a65dSaartbik } 333c295a65dSaartbik }; 334c295a65dSaartbik 33539379916Saartbik /// Conversion pattern for a vector.maskedload. 336563879b6SRahul Joshi class VectorMaskedLoadOpConversion 337563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> { 33839379916Saartbik public: 339563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern; 34039379916Saartbik 34139379916Saartbik LogicalResult 342563879b6SRahul Joshi matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands, 34339379916Saartbik ConversionPatternRewriter &rewriter) const override { 344563879b6SRahul Joshi auto loc = load->getLoc(); 34539379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 34639379916Saartbik 34739379916Saartbik // Resolve alignment. 34839379916Saartbik unsigned align; 34926c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(), 35026c8f908SThomas Raoux align))) 35139379916Saartbik return failure(); 35239379916Saartbik 353dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(load.getResultVectorType()); 35439379916Saartbik Value ptr; 35539379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 35639379916Saartbik vtype, ptr))) 35739379916Saartbik return failure(); 35839379916Saartbik 35939379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 36039379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 36139379916Saartbik rewriter.getI32IntegerAttr(align)); 36239379916Saartbik return success(); 36339379916Saartbik } 36439379916Saartbik }; 36539379916Saartbik 36639379916Saartbik /// Conversion pattern for a vector.maskedstore. 367563879b6SRahul Joshi class VectorMaskedStoreOpConversion 368563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> { 36939379916Saartbik public: 370563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern; 37139379916Saartbik 37239379916Saartbik LogicalResult 373563879b6SRahul Joshi matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands, 37439379916Saartbik ConversionPatternRewriter &rewriter) const override { 375563879b6SRahul Joshi auto loc = store->getLoc(); 37639379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 37739379916Saartbik 37839379916Saartbik // Resolve alignment. 37939379916Saartbik unsigned align; 38026c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(), 38126c8f908SThomas Raoux align))) 38239379916Saartbik return failure(); 38339379916Saartbik 384dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(store.getValueVectorType()); 38539379916Saartbik Value ptr; 38639379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 38739379916Saartbik vtype, ptr))) 38839379916Saartbik return failure(); 38939379916Saartbik 39039379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 39139379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 39239379916Saartbik rewriter.getI32IntegerAttr(align)); 39339379916Saartbik return success(); 39439379916Saartbik } 39539379916Saartbik }; 39639379916Saartbik 39719dbb230Saartbik /// Conversion pattern for a vector.gather. 398563879b6SRahul Joshi class VectorGatherOpConversion 399563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::GatherOp> { 40019dbb230Saartbik public: 401563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 40219dbb230Saartbik 40319dbb230Saartbik LogicalResult 404563879b6SRahul Joshi matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands, 40519dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 406563879b6SRahul Joshi auto loc = gather->getLoc(); 40719dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 40819dbb230Saartbik 40919dbb230Saartbik // Resolve alignment. 41019dbb230Saartbik unsigned align; 41126c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), 41226c8f908SThomas Raoux align))) 41319dbb230Saartbik return failure(); 41419dbb230Saartbik 41519dbb230Saartbik // Get index ptrs. 41619dbb230Saartbik VectorType vType = gather.getResultVectorType(); 41719dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 41819dbb230Saartbik Value ptrs; 419e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 420e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 42119dbb230Saartbik return failure(); 42219dbb230Saartbik 42319dbb230Saartbik // Replace with the gather intrinsic. 42419dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 425dcec2ca5SChristian Sigg gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 4260c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 42719dbb230Saartbik return success(); 42819dbb230Saartbik } 42919dbb230Saartbik }; 43019dbb230Saartbik 43119dbb230Saartbik /// Conversion pattern for a vector.scatter. 432563879b6SRahul Joshi class VectorScatterOpConversion 433563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ScatterOp> { 43419dbb230Saartbik public: 435563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 43619dbb230Saartbik 43719dbb230Saartbik LogicalResult 438563879b6SRahul Joshi matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands, 43919dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 440563879b6SRahul Joshi auto loc = scatter->getLoc(); 44119dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 44219dbb230Saartbik 44319dbb230Saartbik // Resolve alignment. 44419dbb230Saartbik unsigned align; 44526c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), 44626c8f908SThomas Raoux align))) 44719dbb230Saartbik return failure(); 44819dbb230Saartbik 44919dbb230Saartbik // Get index ptrs. 45019dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 45119dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 45219dbb230Saartbik Value ptrs; 453e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 454e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 45519dbb230Saartbik return failure(); 45619dbb230Saartbik 45719dbb230Saartbik // Replace with the scatter intrinsic. 45819dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 45919dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 46019dbb230Saartbik rewriter.getI32IntegerAttr(align)); 46119dbb230Saartbik return success(); 46219dbb230Saartbik } 46319dbb230Saartbik }; 46419dbb230Saartbik 465e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 466563879b6SRahul Joshi class VectorExpandLoadOpConversion 467563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 468e8dcf5f8Saartbik public: 469563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 470e8dcf5f8Saartbik 471e8dcf5f8Saartbik LogicalResult 472563879b6SRahul Joshi matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands, 473e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 474563879b6SRahul Joshi auto loc = expand->getLoc(); 475e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 476e8dcf5f8Saartbik 477e8dcf5f8Saartbik Value ptr; 478e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 479e8dcf5f8Saartbik ptr))) 480e8dcf5f8Saartbik return failure(); 481e8dcf5f8Saartbik 482e8dcf5f8Saartbik auto vType = expand.getResultVectorType(); 483e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 484563879b6SRahul Joshi expand, typeConverter->convertType(vType), ptr, adaptor.mask(), 485e8dcf5f8Saartbik adaptor.pass_thru()); 486e8dcf5f8Saartbik return success(); 487e8dcf5f8Saartbik } 488e8dcf5f8Saartbik }; 489e8dcf5f8Saartbik 490e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 491563879b6SRahul Joshi class VectorCompressStoreOpConversion 492563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 493e8dcf5f8Saartbik public: 494563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 495e8dcf5f8Saartbik 496e8dcf5f8Saartbik LogicalResult 497563879b6SRahul Joshi matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands, 498e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 499563879b6SRahul Joshi auto loc = compress->getLoc(); 500e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 501e8dcf5f8Saartbik 502e8dcf5f8Saartbik Value ptr; 503e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), 504e8dcf5f8Saartbik compress.getMemRefType(), ptr))) 505e8dcf5f8Saartbik return failure(); 506e8dcf5f8Saartbik 507e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 508563879b6SRahul Joshi compress, adaptor.value(), ptr, adaptor.mask()); 509e8dcf5f8Saartbik return success(); 510e8dcf5f8Saartbik } 511e8dcf5f8Saartbik }; 512e8dcf5f8Saartbik 51319dbb230Saartbik /// Conversion pattern for all vector reductions. 514563879b6SRahul Joshi class VectorReductionOpConversion 515563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ReductionOp> { 516e83b7b99Saartbik public: 517563879b6SRahul Joshi explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 518060c9dd1Saartbik bool reassociateFPRed) 519563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 520060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 521e83b7b99Saartbik 5223145427dSRiver Riddle LogicalResult 523563879b6SRahul Joshi matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands, 524e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 525e83b7b99Saartbik auto kind = reductionOp.kind(); 526e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 527dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(eltType); 528e9628955SAart Bik if (eltType.isIntOrIndex()) { 529e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 530e83b7b99Saartbik if (kind == "add") 531322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 532563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 533e83b7b99Saartbik else if (kind == "mul") 534322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 535563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 536e9628955SAart Bik else if (kind == "min" && 537e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 538322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 539563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 540e83b7b99Saartbik else if (kind == "min") 541322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 542563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 543e9628955SAart Bik else if (kind == "max" && 544e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 545322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 546563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 547e83b7b99Saartbik else if (kind == "max") 548322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 549563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 550e83b7b99Saartbik else if (kind == "and") 551322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 552563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 553e83b7b99Saartbik else if (kind == "or") 554322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 555563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 556e83b7b99Saartbik else if (kind == "xor") 557322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 558563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 559e83b7b99Saartbik else 5603145427dSRiver Riddle return failure(); 5613145427dSRiver Riddle return success(); 562dcec2ca5SChristian Sigg } 563e83b7b99Saartbik 564dcec2ca5SChristian Sigg if (!eltType.isa<FloatType>()) 565dcec2ca5SChristian Sigg return failure(); 566dcec2ca5SChristian Sigg 567e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 568e83b7b99Saartbik if (kind == "add") { 5690d924700Saartbik // Optional accumulator (or zero). 5700d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5710d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 572563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5730d924700Saartbik rewriter.getZeroAttr(eltType)); 574322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 575563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 576ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 577e83b7b99Saartbik } else if (kind == "mul") { 5780d924700Saartbik // Optional accumulator (or one). 5790d924700Saartbik Value acc = operands.size() > 1 5800d924700Saartbik ? operands[1] 5810d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 582563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5830d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 584322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 585563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 586ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 587e83b7b99Saartbik } else if (kind == "min") 588563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 589563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 590e83b7b99Saartbik else if (kind == "max") 591563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 592563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 593e83b7b99Saartbik else 5943145427dSRiver Riddle return failure(); 5953145427dSRiver Riddle return success(); 596e83b7b99Saartbik } 597ceb1b327Saartbik 598ceb1b327Saartbik private: 599ceb1b327Saartbik const bool reassociateFPReductions; 600e83b7b99Saartbik }; 601e83b7b99Saartbik 602060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 603563879b6SRahul Joshi class VectorCreateMaskOpConversion 604563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 605060c9dd1Saartbik public: 606563879b6SRahul Joshi explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, 607060c9dd1Saartbik bool enableIndexOpt) 608563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), 609060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 610060c9dd1Saartbik 611060c9dd1Saartbik LogicalResult 612563879b6SRahul Joshi matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, 613060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 6149eb3e564SChris Lattner auto dstType = op.getType(); 615060c9dd1Saartbik int64_t rank = dstType.getRank(); 616060c9dd1Saartbik if (rank == 1) { 617060c9dd1Saartbik rewriter.replaceOp( 618060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 619060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 620060c9dd1Saartbik return success(); 621060c9dd1Saartbik } 622060c9dd1Saartbik return failure(); 623060c9dd1Saartbik } 624060c9dd1Saartbik 625060c9dd1Saartbik private: 626060c9dd1Saartbik const bool enableIndexOptimizations; 627060c9dd1Saartbik }; 628060c9dd1Saartbik 629563879b6SRahul Joshi class VectorShuffleOpConversion 630563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 6311c81adf3SAart Bik public: 632563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 6331c81adf3SAart Bik 6343145427dSRiver Riddle LogicalResult 635563879b6SRahul Joshi matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands, 6361c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 637563879b6SRahul Joshi auto loc = shuffleOp->getLoc(); 6382d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6391c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6401c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6411c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 642dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(vectorType); 6431c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6441c81adf3SAart Bik 6451c81adf3SAart Bik // Bail if result type cannot be lowered. 6461c81adf3SAart Bik if (!llvmType) 6473145427dSRiver Riddle return failure(); 6481c81adf3SAart Bik 6491c81adf3SAart Bik // Get rank and dimension sizes. 6501c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6511c81adf3SAart Bik assert(v1Type.getRank() == rank); 6521c81adf3SAart Bik assert(v2Type.getRank() == rank); 6531c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6541c81adf3SAart Bik 6551c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6561c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6571c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 658563879b6SRahul Joshi Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 6591c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 660563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, llvmShuffleOp); 6613145427dSRiver Riddle return success(); 662b36aaeafSAart Bik } 663b36aaeafSAart Bik 6641c81adf3SAart Bik // For all other cases, insert the individual values individually. 665e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6661c81adf3SAart Bik int64_t insPos = 0; 6671c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6681c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 669e62a6956SRiver Riddle Value value = adaptor.v1(); 6701c81adf3SAart Bik if (extPos >= v1Dim) { 6711c81adf3SAart Bik extPos -= v1Dim; 6721c81adf3SAart Bik value = adaptor.v2(); 673b36aaeafSAart Bik } 674dcec2ca5SChristian Sigg Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 675dcec2ca5SChristian Sigg llvmType, rank, extPos); 676dcec2ca5SChristian Sigg insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 6770f04384dSAlex Zinenko llvmType, rank, insPos++); 6781c81adf3SAart Bik } 679563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, insert); 6803145427dSRiver Riddle return success(); 681b36aaeafSAart Bik } 682b36aaeafSAart Bik }; 683b36aaeafSAart Bik 684563879b6SRahul Joshi class VectorExtractElementOpConversion 685563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 686cd5dab8aSAart Bik public: 687563879b6SRahul Joshi using ConvertOpToLLVMPattern< 688563879b6SRahul Joshi vector::ExtractElementOp>::ConvertOpToLLVMPattern; 689cd5dab8aSAart Bik 6903145427dSRiver Riddle LogicalResult 691563879b6SRahul Joshi matchAndRewrite(vector::ExtractElementOp extractEltOp, 692563879b6SRahul Joshi ArrayRef<Value> operands, 693cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 6942d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 695cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 696dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType.getElementType()); 697cd5dab8aSAart Bik 698cd5dab8aSAart Bik // Bail if result type cannot be lowered. 699cd5dab8aSAart Bik if (!llvmType) 7003145427dSRiver Riddle return failure(); 701cd5dab8aSAart Bik 702cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 703563879b6SRahul Joshi extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 7043145427dSRiver Riddle return success(); 705cd5dab8aSAart Bik } 706cd5dab8aSAart Bik }; 707cd5dab8aSAart Bik 708563879b6SRahul Joshi class VectorExtractOpConversion 709563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractOp> { 7105c0c51a9SNicolas Vasilache public: 711563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 7125c0c51a9SNicolas Vasilache 7133145427dSRiver Riddle LogicalResult 714563879b6SRahul Joshi matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 7155c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 716563879b6SRahul Joshi auto loc = extractOp->getLoc(); 7172d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 7189826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7192bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 720dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(resultType); 7215c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7229826fe5cSAart Bik 7239826fe5cSAart Bik // Bail if result type cannot be lowered. 7249826fe5cSAart Bik if (!llvmResultType) 7253145427dSRiver Riddle return failure(); 7269826fe5cSAart Bik 7275c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7285c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 729e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7305c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 731563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7323145427dSRiver Riddle return success(); 7335c0c51a9SNicolas Vasilache } 7345c0c51a9SNicolas Vasilache 7359826fe5cSAart Bik // Potential extraction of 1-D vector from array. 736563879b6SRahul Joshi auto *context = extractOp->getContext(); 737e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7385c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7395c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7409826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7415c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7425c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7435c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 744dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 7455c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7465c0c51a9SNicolas Vasilache } 7475c0c51a9SNicolas Vasilache 7485c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7495c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 750*2230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 7511d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7525c0c51a9SNicolas Vasilache extracted = 7535c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 754563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7555c0c51a9SNicolas Vasilache 7563145427dSRiver Riddle return success(); 7575c0c51a9SNicolas Vasilache } 7585c0c51a9SNicolas Vasilache }; 7595c0c51a9SNicolas Vasilache 760681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 761681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 762681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 763681f929fSNicolas Vasilache /// 764681f929fSNicolas Vasilache /// Example: 765681f929fSNicolas Vasilache /// ``` 766681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 767681f929fSNicolas Vasilache /// ``` 768681f929fSNicolas Vasilache /// is converted to: 769681f929fSNicolas Vasilache /// ``` 7703bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 771681f929fSNicolas Vasilache /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 772681f929fSNicolas Vasilache /// -> !llvm<"<8 x float>"> 773681f929fSNicolas Vasilache /// ``` 774563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 775681f929fSNicolas Vasilache public: 776563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 777681f929fSNicolas Vasilache 7783145427dSRiver Riddle LogicalResult 779563879b6SRahul Joshi matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 780681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7812d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 782681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 783681f929fSNicolas Vasilache if (vType.getRank() != 1) 7843145427dSRiver Riddle return failure(); 785563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 7863bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 7873145427dSRiver Riddle return success(); 788681f929fSNicolas Vasilache } 789681f929fSNicolas Vasilache }; 790681f929fSNicolas Vasilache 791563879b6SRahul Joshi class VectorInsertElementOpConversion 792563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 793cd5dab8aSAart Bik public: 794563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 795cd5dab8aSAart Bik 7963145427dSRiver Riddle LogicalResult 797563879b6SRahul Joshi matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands, 798cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7992d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 800cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 801dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType); 802cd5dab8aSAart Bik 803cd5dab8aSAart Bik // Bail if result type cannot be lowered. 804cd5dab8aSAart Bik if (!llvmType) 8053145427dSRiver Riddle return failure(); 806cd5dab8aSAart Bik 807cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 808563879b6SRahul Joshi insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 809563879b6SRahul Joshi adaptor.position()); 8103145427dSRiver Riddle return success(); 811cd5dab8aSAart Bik } 812cd5dab8aSAart Bik }; 813cd5dab8aSAart Bik 814563879b6SRahul Joshi class VectorInsertOpConversion 815563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertOp> { 8169826fe5cSAart Bik public: 817563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 8189826fe5cSAart Bik 8193145427dSRiver Riddle LogicalResult 820563879b6SRahul Joshi matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 8219826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 822563879b6SRahul Joshi auto loc = insertOp->getLoc(); 8232d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8249826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8259826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 826dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(destVectorType); 8279826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8289826fe5cSAart Bik 8299826fe5cSAart Bik // Bail if result type cannot be lowered. 8309826fe5cSAart Bik if (!llvmResultType) 8313145427dSRiver Riddle return failure(); 8329826fe5cSAart Bik 8339826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8349826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 835e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8369826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8379826fe5cSAart Bik positionArrayAttr); 838563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8393145427dSRiver Riddle return success(); 8409826fe5cSAart Bik } 8419826fe5cSAart Bik 8429826fe5cSAart Bik // Potential extraction of 1-D vector from array. 843563879b6SRahul Joshi auto *context = insertOp->getContext(); 844e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8459826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8469826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8479826fe5cSAart Bik auto oneDVectorType = destVectorType; 8489826fe5cSAart Bik if (positionAttrs.size() > 1) { 8499826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8509826fe5cSAart Bik auto nMinusOnePositionAttrs = 8519826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8529826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 853dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8549826fe5cSAart Bik nMinusOnePositionAttrs); 8559826fe5cSAart Bik } 8569826fe5cSAart Bik 8579826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 858*2230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 8591d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 860e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 861dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8620f04384dSAlex Zinenko adaptor.source(), constant); 8639826fe5cSAart Bik 8649826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 8659826fe5cSAart Bik if (positionAttrs.size() > 1) { 8669826fe5cSAart Bik auto nMinusOnePositionAttrs = 8679826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8689826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 8699826fe5cSAart Bik adaptor.dest(), inserted, 8709826fe5cSAart Bik nMinusOnePositionAttrs); 8719826fe5cSAart Bik } 8729826fe5cSAart Bik 873563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8743145427dSRiver Riddle return success(); 8759826fe5cSAart Bik } 8769826fe5cSAart Bik }; 8779826fe5cSAart Bik 878681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 879681f929fSNicolas Vasilache /// 880681f929fSNicolas Vasilache /// Example: 881681f929fSNicolas Vasilache /// ``` 882681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 883681f929fSNicolas Vasilache /// ``` 884681f929fSNicolas Vasilache /// is rewritten into: 885681f929fSNicolas Vasilache /// ``` 886681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 887681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 888681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 889681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 890681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 891681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 892681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 893681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 894681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 895681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 896681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 897681f929fSNicolas Vasilache /// // %r3 holds the final value. 898681f929fSNicolas Vasilache /// ``` 899681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 900681f929fSNicolas Vasilache public: 901681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 902681f929fSNicolas Vasilache 9033145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 904681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 905681f929fSNicolas Vasilache auto vType = op.getVectorType(); 906681f929fSNicolas Vasilache if (vType.getRank() < 2) 9073145427dSRiver Riddle return failure(); 908681f929fSNicolas Vasilache 909681f929fSNicolas Vasilache auto loc = op.getLoc(); 910681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 911681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 912681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 913681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 914681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 915681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 916681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 917681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 918681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 919681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 920681f929fSNicolas Vasilache } 921681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9223145427dSRiver Riddle return success(); 923681f929fSNicolas Vasilache } 924681f929fSNicolas Vasilache }; 925681f929fSNicolas Vasilache 9262d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9272d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9282d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9292d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9302d515e49SNicolas Vasilache // rank. 9312d515e49SNicolas Vasilache // 9322d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9332d515e49SNicolas Vasilache // have different ranks. In this case: 9342d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9352d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9362d515e49SNicolas Vasilache // destination subvector 9372d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9382d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9392d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9402d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9412d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9422d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9432d515e49SNicolas Vasilache public: 9442d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9452d515e49SNicolas Vasilache 9463145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9472d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9482d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9492d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9502d515e49SNicolas Vasilache 9512d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9523145427dSRiver Riddle return failure(); 9532d515e49SNicolas Vasilache 9542d515e49SNicolas Vasilache auto loc = op.getLoc(); 9552d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9562d515e49SNicolas Vasilache assert(rankDiff >= 0); 9572d515e49SNicolas Vasilache if (rankDiff == 0) 9583145427dSRiver Riddle return failure(); 9592d515e49SNicolas Vasilache 9602d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9612d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 9622d515e49SNicolas Vasilache // on it. 9632d515e49SNicolas Vasilache Value extracted = 9642d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 9652d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 966dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9672d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 9682d515e49SNicolas Vasilache // ranks. 9692d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 9702d515e49SNicolas Vasilache loc, op.source(), extracted, 9712d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 972c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 9732d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 9742d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 9752d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 976dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9773145427dSRiver Riddle return success(); 9782d515e49SNicolas Vasilache } 9792d515e49SNicolas Vasilache }; 9802d515e49SNicolas Vasilache 9812d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9822d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 9832d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9842d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9852d515e49SNicolas Vasilache // destination subvector 9862d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9872d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9882d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9892d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9902d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 9912d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9922d515e49SNicolas Vasilache public: 993b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 994b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 995b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 996b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 997b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 998b99bd771SRiver Riddle } 9992d515e49SNicolas Vasilache 10003145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 10012d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10022d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10032d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10042d515e49SNicolas Vasilache 10052d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10063145427dSRiver Riddle return failure(); 10072d515e49SNicolas Vasilache 10082d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10092d515e49SNicolas Vasilache assert(rankDiff >= 0); 10102d515e49SNicolas Vasilache if (rankDiff != 0) 10113145427dSRiver Riddle return failure(); 10122d515e49SNicolas Vasilache 10132d515e49SNicolas Vasilache if (srcType == dstType) { 10142d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10153145427dSRiver Riddle return success(); 10162d515e49SNicolas Vasilache } 10172d515e49SNicolas Vasilache 10182d515e49SNicolas Vasilache int64_t offset = 10192d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10202d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10212d515e49SNicolas Vasilache int64_t stride = 10222d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10232d515e49SNicolas Vasilache 10242d515e49SNicolas Vasilache auto loc = op.getLoc(); 10252d515e49SNicolas Vasilache Value res = op.dest(); 10262d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10272d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10282d515e49SNicolas Vasilache off += stride, ++idx) { 10292d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10302d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10312d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10322d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10332d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10342d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10352d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10362d515e49SNicolas Vasilache // smaller rank. 1037bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10382d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10392d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10402d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10412d515e49SNicolas Vasilache } 10422d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10432d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10442d515e49SNicolas Vasilache } 10452d515e49SNicolas Vasilache 10462d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10473145427dSRiver Riddle return success(); 10482d515e49SNicolas Vasilache } 10492d515e49SNicolas Vasilache }; 10502d515e49SNicolas Vasilache 105130e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 105230e6033bSNicolas Vasilache /// static layout. 105330e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 105430e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10552bf491c7SBenjamin Kramer int64_t offset; 105630e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 105730e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 105830e6033bSNicolas Vasilache return None; 105930e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 106030e6033bSNicolas Vasilache return None; 106130e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 106230e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 106330e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 106430e6033bSNicolas Vasilache return strides; 106530e6033bSNicolas Vasilache 106630e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 106730e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 106830e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 106930e6033bSNicolas Vasilache // layout. 10702bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 10712bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 107230e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 107330e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 107430e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 107530e6033bSNicolas Vasilache return None; 107630e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 107730e6033bSNicolas Vasilache return None; 10782bf491c7SBenjamin Kramer } 107930e6033bSNicolas Vasilache return strides; 10802bf491c7SBenjamin Kramer } 10812bf491c7SBenjamin Kramer 1082563879b6SRahul Joshi class VectorTypeCastOpConversion 1083563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 10845c0c51a9SNicolas Vasilache public: 1085563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 10865c0c51a9SNicolas Vasilache 10873145427dSRiver Riddle LogicalResult 1088563879b6SRahul Joshi matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, 10895c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 1090563879b6SRahul Joshi auto loc = castOp->getLoc(); 10915c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 10922bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 10939eb3e564SChris Lattner MemRefType targetMemRefType = castOp.getType(); 10945c0c51a9SNicolas Vasilache 10955c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 10965c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 10975c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 10983145427dSRiver Riddle return failure(); 10995c0c51a9SNicolas Vasilache 11005c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 11018de43b92SAlex Zinenko operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); 11028de43b92SAlex Zinenko if (!llvmSourceDescriptorTy) 11033145427dSRiver Riddle return failure(); 11045c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11055c0c51a9SNicolas Vasilache 1106dcec2ca5SChristian Sigg auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 11078de43b92SAlex Zinenko .dyn_cast_or_null<LLVM::LLVMStructType>(); 11088de43b92SAlex Zinenko if (!llvmTargetDescriptorTy) 11093145427dSRiver Riddle return failure(); 11105c0c51a9SNicolas Vasilache 111130e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 111230e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 111330e6033bSNicolas Vasilache if (!sourceStrides) 111430e6033bSNicolas Vasilache return failure(); 111530e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 111630e6033bSNicolas Vasilache if (!targetStrides) 111730e6033bSNicolas Vasilache return failure(); 111830e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 111930e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 112030e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 112130e6033bSNicolas Vasilache })) 11223145427dSRiver Riddle return failure(); 11235c0c51a9SNicolas Vasilache 1124*2230bf99SAlex Zinenko auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 11255c0c51a9SNicolas Vasilache 11265c0c51a9SNicolas Vasilache // Create descriptor. 11275c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11283a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11295c0c51a9SNicolas Vasilache // Set allocated ptr. 1130e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11315c0c51a9SNicolas Vasilache allocated = 11325c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11335c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11345c0c51a9SNicolas Vasilache // Set aligned ptr. 1135e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11365c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11375c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11385c0c51a9SNicolas Vasilache // Fill offset 0. 11395c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11405c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11415c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11425c0c51a9SNicolas Vasilache 11435c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11445c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11455c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11465c0c51a9SNicolas Vasilache auto sizeAttr = 11475c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11485c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11495c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 115030e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 115130e6033bSNicolas Vasilache (*targetStrides)[index]); 11525c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11535c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11545c0c51a9SNicolas Vasilache } 11555c0c51a9SNicolas Vasilache 1156563879b6SRahul Joshi rewriter.replaceOp(castOp, {desc}); 11573145427dSRiver Riddle return success(); 11585c0c51a9SNicolas Vasilache } 11595c0c51a9SNicolas Vasilache }; 11605c0c51a9SNicolas Vasilache 11618345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 11628345b86dSNicolas Vasilache /// sequence of: 1163060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1164060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1165060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1166060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1167060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 11688345b86dSNicolas Vasilache template <typename ConcreteOp> 1169563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { 11708345b86dSNicolas Vasilache public: 1171563879b6SRahul Joshi explicit VectorTransferConversion(LLVMTypeConverter &typeConv, 1172060c9dd1Saartbik bool enableIndexOpt) 1173563879b6SRahul Joshi : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), 1174060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 11758345b86dSNicolas Vasilache 11768345b86dSNicolas Vasilache LogicalResult 1177563879b6SRahul Joshi matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, 11788345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11798345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1180b2c79c50SNicolas Vasilache 1181b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1182b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 11838345b86dSNicolas Vasilache return failure(); 11845f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 11855f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 11865f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 1187563879b6SRahul Joshi xferOp->getContext())) 11888345b86dSNicolas Vasilache return failure(); 118926c8f908SThomas Raoux auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 119026c8f908SThomas Raoux if (!memRefType) 119126c8f908SThomas Raoux return failure(); 11922bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 119326c8f908SThomas Raoux auto strides = computeContiguousStrides(memRefType); 119430e6033bSNicolas Vasilache if (!strides) 11952bf491c7SBenjamin Kramer return failure(); 11968345b86dSNicolas Vasilache 1197563879b6SRahul Joshi auto toLLVMTy = [&](Type t) { 1198563879b6SRahul Joshi return this->getTypeConverter()->convertType(t); 1199563879b6SRahul Joshi }; 12008345b86dSNicolas Vasilache 1201563879b6SRahul Joshi Location loc = xferOp->getLoc(); 12028345b86dSNicolas Vasilache 120368330ee0SThomas Raoux if (auto memrefVectorElementType = 120426c8f908SThomas Raoux memRefType.getElementType().template dyn_cast<VectorType>()) { 120568330ee0SThomas Raoux // Memref has vector element type. 120668330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 120768330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 120868330ee0SThomas Raoux return failure(); 12090de60b55SThomas Raoux #ifndef NDEBUG 121068330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 121168330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 121268330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 121368330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 121468330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 121568330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 121668330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 121768330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 121868330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 121968330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 122068330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 122168330ee0SThomas Raoux "result shape."); 12220de60b55SThomas Raoux #endif // ifndef NDEBUG 122368330ee0SThomas Raoux } 122468330ee0SThomas Raoux 12258345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1226be16075bSWen-Heng (Jack) Chung // The vector pointer would always be on address space 0, therefore 1227be16075bSWen-Heng (Jack) Chung // addrspacecast shall be used when source/dst memrefs are not on 1228be16075bSWen-Heng (Jack) Chung // address space 0. 12298345b86dSNicolas Vasilache // TODO: support alignment when possible. 1230563879b6SRahul Joshi Value dataPtr = this->getStridedElementPtr( 123126c8f908SThomas Raoux loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); 12328de43b92SAlex Zinenko auto vecTy = toLLVMTy(xferOp.getVectorType()) 12338de43b92SAlex Zinenko .template cast<LLVM::LLVMFixedVectorType>(); 1234be16075bSWen-Heng (Jack) Chung Value vectorDataPtr; 1235be16075bSWen-Heng (Jack) Chung if (memRefType.getMemorySpace() == 0) 12368de43b92SAlex Zinenko vectorDataPtr = rewriter.create<LLVM::BitcastOp>( 12378de43b92SAlex Zinenko loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); 1238be16075bSWen-Heng (Jack) Chung else 1239be16075bSWen-Heng (Jack) Chung vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 12408de43b92SAlex Zinenko loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); 12418345b86dSNicolas Vasilache 12421870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 1243563879b6SRahul Joshi return replaceTransferOpWithLoadOrStore(rewriter, 1244563879b6SRahul Joshi *this->getTypeConverter(), loc, 1245563879b6SRahul Joshi xferOp, operands, vectorDataPtr); 12461870e787SNicolas Vasilache 12478345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12488345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12498345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12508345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1251060c9dd1Saartbik // 1252060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1253060c9dd1Saartbik // dimensions here. 12548de43b92SAlex Zinenko unsigned vecWidth = vecTy.getNumElements(); 1255060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12560c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 125726c8f908SThomas Raoux Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex); 1258563879b6SRahul Joshi Value mask = buildVectorComparison( 1259563879b6SRahul Joshi rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); 12608345b86dSNicolas Vasilache 12618345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 1262563879b6SRahul Joshi return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, 1263dcec2ca5SChristian Sigg xferOp, operands, vectorDataPtr, mask); 12648345b86dSNicolas Vasilache } 1265060c9dd1Saartbik 1266060c9dd1Saartbik private: 1267060c9dd1Saartbik const bool enableIndexOptimizations; 12688345b86dSNicolas Vasilache }; 12698345b86dSNicolas Vasilache 1270563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1271d9b500d3SAart Bik public: 1272563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1273d9b500d3SAart Bik 1274d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1275d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1276d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1277d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1278d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1279d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1280d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1281d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1282d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1283d9b500d3SAart Bik // 12849db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1285d9b500d3SAart Bik // 12863145427dSRiver Riddle LogicalResult 1287563879b6SRahul Joshi matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands, 1288d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 12892d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1290d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1291d9b500d3SAart Bik 1292dcec2ca5SChristian Sigg if (typeConverter->convertType(printType) == nullptr) 12933145427dSRiver Riddle return failure(); 1294d9b500d3SAart Bik 1295b8880f5fSAart Bik // Make sure element type has runtime support. 1296b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1297d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1298d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1299d9b500d3SAart Bik Operation *printer; 1300b8880f5fSAart Bik if (eltType.isF32()) { 1301563879b6SRahul Joshi printer = getPrintFloat(printOp); 1302b8880f5fSAart Bik } else if (eltType.isF64()) { 1303563879b6SRahul Joshi printer = getPrintDouble(printOp); 130454759cefSAart Bik } else if (eltType.isIndex()) { 1305563879b6SRahul Joshi printer = getPrintU64(printOp); 1306b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1307b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1308b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1309b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1310b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1311b8880f5fSAart Bik if (intTy.isUnsigned()) { 131254759cefSAart Bik if (width <= 64) { 1313b8880f5fSAart Bik if (width < 64) 1314b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1315563879b6SRahul Joshi printer = getPrintU64(printOp); 1316b8880f5fSAart Bik } else { 13173145427dSRiver Riddle return failure(); 1318b8880f5fSAart Bik } 1319b8880f5fSAart Bik } else { 1320b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 132154759cefSAart Bik if (width <= 64) { 1322b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1323b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1324b8880f5fSAart Bik if (width == 1) 132554759cefSAart Bik conversion = PrintConversion::ZeroExt64; 132654759cefSAart Bik else if (width < 64) 1327b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1328563879b6SRahul Joshi printer = getPrintI64(printOp); 1329b8880f5fSAart Bik } else { 1330b8880f5fSAart Bik return failure(); 1331b8880f5fSAart Bik } 1332b8880f5fSAart Bik } 1333b8880f5fSAart Bik } else { 1334b8880f5fSAart Bik return failure(); 1335b8880f5fSAart Bik } 1336d9b500d3SAart Bik 1337d9b500d3SAart Bik // Unroll vector into elementary print calls. 1338b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1339563879b6SRahul Joshi emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, 1340b8880f5fSAart Bik conversion); 1341563879b6SRahul Joshi emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); 1342563879b6SRahul Joshi rewriter.eraseOp(printOp); 13433145427dSRiver Riddle return success(); 1344d9b500d3SAart Bik } 1345d9b500d3SAart Bik 1346d9b500d3SAart Bik private: 1347b8880f5fSAart Bik enum class PrintConversion { 134830e6033bSNicolas Vasilache // clang-format off 1349b8880f5fSAart Bik None, 1350b8880f5fSAart Bik ZeroExt64, 1351b8880f5fSAart Bik SignExt64 135230e6033bSNicolas Vasilache // clang-format on 1353b8880f5fSAart Bik }; 1354b8880f5fSAart Bik 1355d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1356e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1357b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1358d9b500d3SAart Bik Location loc = op->getLoc(); 1359d9b500d3SAart Bik if (rank == 0) { 1360b8880f5fSAart Bik switch (conversion) { 1361b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1362b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 1363*2230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1364b8880f5fSAart Bik break; 1365b8880f5fSAart Bik case PrintConversion::SignExt64: 1366b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 1367*2230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1368b8880f5fSAart Bik break; 1369b8880f5fSAart Bik case PrintConversion::None: 1370b8880f5fSAart Bik break; 1371c9eeeb38Saartbik } 1372d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1373d9b500d3SAart Bik return; 1374d9b500d3SAart Bik } 1375d9b500d3SAart Bik 1376d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1377d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1378d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1379d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1380d9b500d3SAart Bik auto reducedType = 1381d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1382dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType( 1383d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1384dcec2ca5SChristian Sigg Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1385dcec2ca5SChristian Sigg llvmType, rank, d); 1386b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1387b8880f5fSAart Bik conversion); 1388d9b500d3SAart Bik if (d != dim - 1) 1389d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1390d9b500d3SAart Bik } 1391d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1392d9b500d3SAart Bik } 1393d9b500d3SAart Bik 1394d9b500d3SAart Bik // Helper to emit a call. 1395d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1396d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 139708e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1398d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1399d9b500d3SAart Bik } 1400d9b500d3SAart Bik 1401d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 14025446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 1403c69c9e0fSAlex Zinenko ArrayRef<Type> params) { 1404d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1405d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1406d9b500d3SAart Bik if (func) 1407d9b500d3SAart Bik return func; 1408d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1409d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1410d9b500d3SAart Bik op->getLoc(), name, 14117ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()), 14127ed9cfc7SAlex Zinenko params)); 1413d9b500d3SAart Bik } 1414d9b500d3SAart Bik 1415d9b500d3SAart Bik // Helpers for method names. 1416e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 1417*2230bf99SAlex Zinenko return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64)); 1418e52414b1Saartbik } 1419b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 1420*2230bf99SAlex Zinenko return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64)); 1421b8880f5fSAart Bik } 1422d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 14237ed9cfc7SAlex Zinenko return getPrint(op, "printF32", LLVM::LLVMFloatType::get(op->getContext())); 1424d9b500d3SAart Bik } 1425d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 142654759cefSAart Bik return getPrint(op, "printF64", 14277ed9cfc7SAlex Zinenko LLVM::LLVMDoubleType::get(op->getContext())); 1428d9b500d3SAart Bik } 1429d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 143054759cefSAart Bik return getPrint(op, "printOpen", {}); 1431d9b500d3SAart Bik } 1432d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 143354759cefSAart Bik return getPrint(op, "printClose", {}); 1434d9b500d3SAart Bik } 1435d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 143654759cefSAart Bik return getPrint(op, "printComma", {}); 1437d9b500d3SAart Bik } 1438d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 143954759cefSAart Bik return getPrint(op, "printNewline", {}); 1440d9b500d3SAart Bik } 1441d9b500d3SAart Bik }; 1442d9b500d3SAart Bik 1443334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1444c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1445c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1446c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1447334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 144865678d93SNicolas Vasilache public: 1449b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1450b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1451b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1452b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1453b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1454b99bd771SRiver Riddle } 145565678d93SNicolas Vasilache 1456334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 145765678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 14589eb3e564SChris Lattner auto dstType = op.getType(); 145965678d93SNicolas Vasilache 146065678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 146165678d93SNicolas Vasilache 146265678d93SNicolas Vasilache int64_t offset = 146365678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 146465678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 146565678d93SNicolas Vasilache int64_t stride = 146665678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 146765678d93SNicolas Vasilache 146865678d93SNicolas Vasilache auto loc = op.getLoc(); 146965678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 147035b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1471c3c95b9cSaartbik 1472c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1473c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1474c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1475c3c95b9cSaartbik offsets.reserve(size); 1476c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1477c3c95b9cSaartbik off += stride) 1478c3c95b9cSaartbik offsets.push_back(off); 1479c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1480c3c95b9cSaartbik op.vector(), 1481c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1482c3c95b9cSaartbik return success(); 1483c3c95b9cSaartbik } 1484c3c95b9cSaartbik 1485c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 148665678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 148765678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 148865678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 148965678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 149065678d93SNicolas Vasilache off += stride, ++idx) { 1491c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1492c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1493c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 149465678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 149565678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 149665678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 149765678d93SNicolas Vasilache } 1498c3c95b9cSaartbik rewriter.replaceOp(op, res); 14993145427dSRiver Riddle return success(); 150065678d93SNicolas Vasilache } 150165678d93SNicolas Vasilache }; 150265678d93SNicolas Vasilache 1503df186507SBenjamin Kramer } // namespace 1504df186507SBenjamin Kramer 15055c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 15065c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1507ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1508060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 150965678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 15108345b86dSNicolas Vasilache // clang-format off 1511681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1512681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 15132d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1514c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1515ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1516563879b6SRahul Joshi converter, reassociateFPReductions); 1517060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1518060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1519060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1520563879b6SRahul Joshi converter, enableIndexOptimizations); 15218345b86dSNicolas Vasilache patterns 1522ceb1b327Saartbik .insert<VectorShuffleOpConversion, 15238345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15248345b86dSNicolas Vasilache VectorExtractOpConversion, 15258345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15268345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15278345b86dSNicolas Vasilache VectorInsertOpConversion, 15288345b86dSNicolas Vasilache VectorPrintOpConversion, 152919dbb230Saartbik VectorTypeCastOpConversion, 153039379916Saartbik VectorMaskedLoadOpConversion, 153139379916Saartbik VectorMaskedStoreOpConversion, 153219dbb230Saartbik VectorGatherOpConversion, 1533e8dcf5f8Saartbik VectorScatterOpConversion, 1534e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1535563879b6SRahul Joshi VectorCompressStoreOpConversion>(converter); 15368345b86dSNicolas Vasilache // clang-format on 15375c0c51a9SNicolas Vasilache } 15385c0c51a9SNicolas Vasilache 153963b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 154063b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1541563879b6SRahul Joshi patterns.insert<VectorMatmulOpConversion>(converter); 1542563879b6SRahul Joshi patterns.insert<VectorFlatTransposeOpConversion>(converter); 154363b683a8SNicolas Vasilache } 1544