15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 13*e332c22cSNicolas Vasilache #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 1709f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 18ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 195c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 205c0c51a9SNicolas Vasilache 215c0c51a9SNicolas Vasilache using namespace mlir; 2265678d93SNicolas Vasilache using namespace mlir::vector; 235c0c51a9SNicolas Vasilache 249826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 259826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 269826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 279826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 289826fe5cSAart Bik } 299826fe5cSAart Bik 309826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 319826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 329826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 339826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 349826fe5cSAart Bik } 359826fe5cSAart Bik 361c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 37e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 380f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 390f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 400f04384dSAlex Zinenko int64_t pos) { 411c81adf3SAart Bik if (rank == 1) { 421c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 431c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 440f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 451c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 461c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 471c81adf3SAart Bik constant); 481c81adf3SAart Bik } 491c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 501c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 511c81adf3SAart Bik } 521c81adf3SAart Bik 532d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 542d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 552d515e49SNicolas Vasilache Value into, int64_t offset) { 562d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 572d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 582d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 592d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 602d515e49SNicolas Vasilache loc, vectorType, from, into, 612d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 622d515e49SNicolas Vasilache } 632d515e49SNicolas Vasilache 641c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 65e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 660f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 670f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 681c81adf3SAart Bik if (rank == 1) { 691c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 701c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 710f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 721c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 731c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 741c81adf3SAart Bik constant); 751c81adf3SAart Bik } 761c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 771c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 781c81adf3SAart Bik } 791c81adf3SAart Bik 802d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 812d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 822d515e49SNicolas Vasilache int64_t offset) { 832d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 842d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 852d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 862d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 872d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 882d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 892d515e49SNicolas Vasilache } 902d515e49SNicolas Vasilache 912d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 929db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 932d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 942d515e49SNicolas Vasilache unsigned dropFront = 0, 952d515e49SNicolas Vasilache unsigned dropBack = 0) { 962d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 972d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 982d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 992d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 1002d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1012d515e49SNicolas Vasilache it != eit; ++it) 1022d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1032d515e49SNicolas Vasilache return res; 1042d515e49SNicolas Vasilache } 1052d515e49SNicolas Vasilache 106ba87f991SAlex Zinenko static Value createCastToIndexLike(ConversionPatternRewriter &rewriter, 107ba87f991SAlex Zinenko Location loc, Type targetType, Value value) { 108ba87f991SAlex Zinenko if (targetType == value.getType()) 109ba87f991SAlex Zinenko return value; 110ba87f991SAlex Zinenko 111ba87f991SAlex Zinenko bool targetIsIndex = targetType.isIndex(); 112ba87f991SAlex Zinenko bool valueIsIndex = value.getType().isIndex(); 113ba87f991SAlex Zinenko if (targetIsIndex ^ valueIsIndex) 114ba87f991SAlex Zinenko return rewriter.create<IndexCastOp>(loc, targetType, value); 115ba87f991SAlex Zinenko 116ba87f991SAlex Zinenko auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 117ba87f991SAlex Zinenko auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 118ba87f991SAlex Zinenko assert(targetIntegerType && valueIntegerType && 119ba87f991SAlex Zinenko "unexpected cast between types other than integers and index"); 120ba87f991SAlex Zinenko assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 121ba87f991SAlex Zinenko 122ba87f991SAlex Zinenko if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 123ba87f991SAlex Zinenko return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value); 124ba87f991SAlex Zinenko return rewriter.create<TruncateIOp>(loc, targetIntegerType, value); 125ba87f991SAlex Zinenko } 126ba87f991SAlex Zinenko 127060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 128060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 129060c9dd1Saartbik // 130060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 131060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 132060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 133060c9dd1Saartbik // is very conservative on the boundary conditions. 134060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 135060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 136060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 137060c9dd1Saartbik auto loc = op->getLoc(); 138060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 139060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 140060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 141060c9dd1Saartbik Value indices; 142060c9dd1Saartbik Type idxType; 143060c9dd1Saartbik if (enableIndexOptimizations) { 1440c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1450c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1460c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 147060c9dd1Saartbik idxType = rewriter.getI32Type(); 148060c9dd1Saartbik } else { 1490c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1500c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1510c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 152060c9dd1Saartbik idxType = rewriter.getI64Type(); 153060c9dd1Saartbik } 154060c9dd1Saartbik // Add in an offset if requested. 155060c9dd1Saartbik if (off) { 156ba87f991SAlex Zinenko Value o = createCastToIndexLike(rewriter, loc, idxType, *off); 157060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 158060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 159060c9dd1Saartbik } 160060c9dd1Saartbik // Construct the vector comparison. 161ba87f991SAlex Zinenko Value bound = createCastToIndexLike(rewriter, loc, idxType, b); 162060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 163060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 164060c9dd1Saartbik } 165060c9dd1Saartbik 16626c8f908SThomas Raoux // Helper that returns data layout alignment of a memref. 16726c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 16826c8f908SThomas Raoux MemRefType memrefType, unsigned &align) { 16926c8f908SThomas Raoux Type elementTy = typeConverter.convertType(memrefType.getElementType()); 1705f9e0466SNicolas Vasilache if (!elementTy) 1715f9e0466SNicolas Vasilache return failure(); 1725f9e0466SNicolas Vasilache 173b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 174b2ab375dSAlex Zinenko // stop depending on translation. 17587a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 17687a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 177c69c9e0fSAlex Zinenko .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 1785f9e0466SNicolas Vasilache return success(); 1795f9e0466SNicolas Vasilache } 1805f9e0466SNicolas Vasilache 181e8dcf5f8Saartbik // Helper that returns the base address of a memref. 182b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 183e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 18419dbb230Saartbik // Inspect stride and offset structure. 18519dbb230Saartbik // 18619dbb230Saartbik // TODO: flat memory only for now, generalize 18719dbb230Saartbik // 18819dbb230Saartbik int64_t offset; 18919dbb230Saartbik SmallVector<int64_t, 4> strides; 19019dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 19119dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 19219dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 19319dbb230Saartbik return failure(); 194e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 195e8dcf5f8Saartbik return success(); 196e8dcf5f8Saartbik } 19719dbb230Saartbik 198a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector. 199b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 200b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 201b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 202b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 203e8dcf5f8Saartbik Value base; 204e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 205e8dcf5f8Saartbik return failure(); 2063a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 207bd30a796SAlex Zinenko auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 2081485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 20919dbb230Saartbik return success(); 21019dbb230Saartbik } 21119dbb230Saartbik 212a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer 213a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be 214a57def30SAart Bik // used when source/dst memrefs are not on address space 0. 215a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 216a57def30SAart Bik Value ptr, MemRefType memRefType, Type vt) { 217bd30a796SAlex Zinenko auto pType = LLVM::LLVMPointerType::get(vt); 218a57def30SAart Bik if (memRefType.getMemorySpace() == 0) 219a57def30SAart Bik return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 220a57def30SAart Bik return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr); 221a57def30SAart Bik } 222a57def30SAart Bik 2235f9e0466SNicolas Vasilache static LogicalResult 2245f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2255f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2265f9e0466SNicolas Vasilache TransferReadOp xferOp, 2275f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 228affbc0cdSNicolas Vasilache unsigned align; 22926c8f908SThomas Raoux if (failed(getMemRefAlignment( 23026c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 231affbc0cdSNicolas Vasilache return failure(); 232affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2335f9e0466SNicolas Vasilache return success(); 2345f9e0466SNicolas Vasilache } 2355f9e0466SNicolas Vasilache 2365f9e0466SNicolas Vasilache static LogicalResult 2375f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2385f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2395f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2405f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2415f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2425f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2435f9e0466SNicolas Vasilache 2445f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2455f9e0466SNicolas Vasilache if (!vecTy) 2465f9e0466SNicolas Vasilache return failure(); 2475f9e0466SNicolas Vasilache 2485f9e0466SNicolas Vasilache unsigned align; 24926c8f908SThomas Raoux if (failed(getMemRefAlignment( 25026c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2515f9e0466SNicolas Vasilache return failure(); 2525f9e0466SNicolas Vasilache 2535f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2545f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2555f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2565f9e0466SNicolas Vasilache return success(); 2575f9e0466SNicolas Vasilache } 2585f9e0466SNicolas Vasilache 2595f9e0466SNicolas Vasilache static LogicalResult 2605f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2615f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2625f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2635f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 264affbc0cdSNicolas Vasilache unsigned align; 26526c8f908SThomas Raoux if (failed(getMemRefAlignment( 26626c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 267affbc0cdSNicolas Vasilache return failure(); 2682d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 269affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 270affbc0cdSNicolas Vasilache align); 2715f9e0466SNicolas Vasilache return success(); 2725f9e0466SNicolas Vasilache } 2735f9e0466SNicolas Vasilache 2745f9e0466SNicolas Vasilache static LogicalResult 2755f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2765f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2775f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2785f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2795f9e0466SNicolas Vasilache unsigned align; 28026c8f908SThomas Raoux if (failed(getMemRefAlignment( 28126c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2825f9e0466SNicolas Vasilache return failure(); 2835f9e0466SNicolas Vasilache 2842d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2855f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2865f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2875f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2885f9e0466SNicolas Vasilache return success(); 2895f9e0466SNicolas Vasilache } 2905f9e0466SNicolas Vasilache 2912d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2922d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2932d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2945f9e0466SNicolas Vasilache } 2955f9e0466SNicolas Vasilache 2962d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2972d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2982d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2995f9e0466SNicolas Vasilache } 3005f9e0466SNicolas Vasilache 30190c01357SBenjamin Kramer namespace { 302e83b7b99Saartbik 303cf5c517cSDiego Caballero /// Conversion pattern for a vector.bitcast. 304cf5c517cSDiego Caballero class VectorBitCastOpConversion 305cf5c517cSDiego Caballero : public ConvertOpToLLVMPattern<vector::BitCastOp> { 306cf5c517cSDiego Caballero public: 307cf5c517cSDiego Caballero using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern; 308cf5c517cSDiego Caballero 309cf5c517cSDiego Caballero LogicalResult 310cf5c517cSDiego Caballero matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands, 311cf5c517cSDiego Caballero ConversionPatternRewriter &rewriter) const override { 312cf5c517cSDiego Caballero // Only 1-D vectors can be lowered to LLVM. 313cf5c517cSDiego Caballero VectorType resultTy = bitCastOp.getType(); 314cf5c517cSDiego Caballero if (resultTy.getRank() != 1) 315cf5c517cSDiego Caballero return failure(); 316cf5c517cSDiego Caballero Type newResultTy = typeConverter->convertType(resultTy); 317cf5c517cSDiego Caballero rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy, 318cf5c517cSDiego Caballero operands[0]); 319cf5c517cSDiego Caballero return success(); 320cf5c517cSDiego Caballero } 321cf5c517cSDiego Caballero }; 322cf5c517cSDiego Caballero 32363b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 32463b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 325563879b6SRahul Joshi class VectorMatmulOpConversion 326563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MatmulOp> { 32763b683a8SNicolas Vasilache public: 328563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 32963b683a8SNicolas Vasilache 3303145427dSRiver Riddle LogicalResult 331563879b6SRahul Joshi matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands, 33263b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 3332d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 33463b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 335563879b6SRahul Joshi matmulOp, typeConverter->convertType(matmulOp.res().getType()), 336563879b6SRahul Joshi adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 337563879b6SRahul Joshi matmulOp.lhs_columns(), matmulOp.rhs_columns()); 3383145427dSRiver Riddle return success(); 33963b683a8SNicolas Vasilache } 34063b683a8SNicolas Vasilache }; 34163b683a8SNicolas Vasilache 342c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 343c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 344563879b6SRahul Joshi class VectorFlatTransposeOpConversion 345563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 346c295a65dSaartbik public: 347563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 348c295a65dSaartbik 349c295a65dSaartbik LogicalResult 350563879b6SRahul Joshi matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands, 351c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 3522d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 353c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 354dcec2ca5SChristian Sigg transOp, typeConverter->convertType(transOp.res().getType()), 355c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 356c295a65dSaartbik return success(); 357c295a65dSaartbik } 358c295a65dSaartbik }; 359c295a65dSaartbik 36039379916Saartbik /// Conversion pattern for a vector.maskedload. 361563879b6SRahul Joshi class VectorMaskedLoadOpConversion 362563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> { 36339379916Saartbik public: 364563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern; 36539379916Saartbik 36639379916Saartbik LogicalResult 367563879b6SRahul Joshi matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands, 36839379916Saartbik ConversionPatternRewriter &rewriter) const override { 369563879b6SRahul Joshi auto loc = load->getLoc(); 37039379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 371a57def30SAart Bik MemRefType memRefType = load.getMemRefType(); 37239379916Saartbik 37339379916Saartbik // Resolve alignment. 37439379916Saartbik unsigned align; 375a57def30SAart Bik if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 37639379916Saartbik return failure(); 37739379916Saartbik 378a57def30SAart Bik // Resolve address. 379dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(load.getResultVectorType()); 380a57def30SAart Bik Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 381a57def30SAart Bik adaptor.indices(), rewriter); 382a57def30SAart Bik Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); 38339379916Saartbik 38439379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 38539379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 38639379916Saartbik rewriter.getI32IntegerAttr(align)); 38739379916Saartbik return success(); 38839379916Saartbik } 38939379916Saartbik }; 39039379916Saartbik 39139379916Saartbik /// Conversion pattern for a vector.maskedstore. 392563879b6SRahul Joshi class VectorMaskedStoreOpConversion 393563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> { 39439379916Saartbik public: 395563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern; 39639379916Saartbik 39739379916Saartbik LogicalResult 398563879b6SRahul Joshi matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands, 39939379916Saartbik ConversionPatternRewriter &rewriter) const override { 400563879b6SRahul Joshi auto loc = store->getLoc(); 40139379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 402a57def30SAart Bik MemRefType memRefType = store.getMemRefType(); 40339379916Saartbik 40439379916Saartbik // Resolve alignment. 40539379916Saartbik unsigned align; 406a57def30SAart Bik if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 40739379916Saartbik return failure(); 40839379916Saartbik 409a57def30SAart Bik // Resolve address. 410dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(store.getValueVectorType()); 411a57def30SAart Bik Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 412a57def30SAart Bik adaptor.indices(), rewriter); 413a57def30SAart Bik Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); 41439379916Saartbik 41539379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 41639379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 41739379916Saartbik rewriter.getI32IntegerAttr(align)); 41839379916Saartbik return success(); 41939379916Saartbik } 42039379916Saartbik }; 42139379916Saartbik 42219dbb230Saartbik /// Conversion pattern for a vector.gather. 423563879b6SRahul Joshi class VectorGatherOpConversion 424563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::GatherOp> { 42519dbb230Saartbik public: 426563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 42719dbb230Saartbik 42819dbb230Saartbik LogicalResult 429563879b6SRahul Joshi matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands, 43019dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 431563879b6SRahul Joshi auto loc = gather->getLoc(); 43219dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 43319dbb230Saartbik 43419dbb230Saartbik // Resolve alignment. 43519dbb230Saartbik unsigned align; 43626c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), 43726c8f908SThomas Raoux align))) 43819dbb230Saartbik return failure(); 43919dbb230Saartbik 44019dbb230Saartbik // Get index ptrs. 44119dbb230Saartbik VectorType vType = gather.getResultVectorType(); 44219dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 44319dbb230Saartbik Value ptrs; 444e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 445e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 44619dbb230Saartbik return failure(); 44719dbb230Saartbik 44819dbb230Saartbik // Replace with the gather intrinsic. 44919dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 450dcec2ca5SChristian Sigg gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 4510c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 45219dbb230Saartbik return success(); 45319dbb230Saartbik } 45419dbb230Saartbik }; 45519dbb230Saartbik 45619dbb230Saartbik /// Conversion pattern for a vector.scatter. 457563879b6SRahul Joshi class VectorScatterOpConversion 458563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ScatterOp> { 45919dbb230Saartbik public: 460563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 46119dbb230Saartbik 46219dbb230Saartbik LogicalResult 463563879b6SRahul Joshi matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands, 46419dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 465563879b6SRahul Joshi auto loc = scatter->getLoc(); 46619dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 46719dbb230Saartbik 46819dbb230Saartbik // Resolve alignment. 46919dbb230Saartbik unsigned align; 47026c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), 47126c8f908SThomas Raoux align))) 47219dbb230Saartbik return failure(); 47319dbb230Saartbik 47419dbb230Saartbik // Get index ptrs. 47519dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 47619dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 47719dbb230Saartbik Value ptrs; 478e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 479e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 48019dbb230Saartbik return failure(); 48119dbb230Saartbik 48219dbb230Saartbik // Replace with the scatter intrinsic. 48319dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 48419dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 48519dbb230Saartbik rewriter.getI32IntegerAttr(align)); 48619dbb230Saartbik return success(); 48719dbb230Saartbik } 48819dbb230Saartbik }; 48919dbb230Saartbik 490e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 491563879b6SRahul Joshi class VectorExpandLoadOpConversion 492563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 493e8dcf5f8Saartbik public: 494563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 495e8dcf5f8Saartbik 496e8dcf5f8Saartbik LogicalResult 497563879b6SRahul Joshi matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands, 498e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 499563879b6SRahul Joshi auto loc = expand->getLoc(); 500e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 501a57def30SAart Bik MemRefType memRefType = expand.getMemRefType(); 502e8dcf5f8Saartbik 503a57def30SAart Bik // Resolve address. 504a57def30SAart Bik auto vtype = typeConverter->convertType(expand.getResultVectorType()); 505a57def30SAart Bik Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 506a57def30SAart Bik adaptor.indices(), rewriter); 507e8dcf5f8Saartbik 508e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 509a57def30SAart Bik expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); 510e8dcf5f8Saartbik return success(); 511e8dcf5f8Saartbik } 512e8dcf5f8Saartbik }; 513e8dcf5f8Saartbik 514e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 515563879b6SRahul Joshi class VectorCompressStoreOpConversion 516563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 517e8dcf5f8Saartbik public: 518563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 519e8dcf5f8Saartbik 520e8dcf5f8Saartbik LogicalResult 521563879b6SRahul Joshi matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands, 522e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 523563879b6SRahul Joshi auto loc = compress->getLoc(); 524e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 525a57def30SAart Bik MemRefType memRefType = compress.getMemRefType(); 526e8dcf5f8Saartbik 527a57def30SAart Bik // Resolve address. 528a57def30SAart Bik Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 529a57def30SAart Bik adaptor.indices(), rewriter); 530e8dcf5f8Saartbik 531e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 532563879b6SRahul Joshi compress, adaptor.value(), ptr, adaptor.mask()); 533e8dcf5f8Saartbik return success(); 534e8dcf5f8Saartbik } 535e8dcf5f8Saartbik }; 536e8dcf5f8Saartbik 53719dbb230Saartbik /// Conversion pattern for all vector reductions. 538563879b6SRahul Joshi class VectorReductionOpConversion 539563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ReductionOp> { 540e83b7b99Saartbik public: 541563879b6SRahul Joshi explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 542060c9dd1Saartbik bool reassociateFPRed) 543563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 544060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 545e83b7b99Saartbik 5463145427dSRiver Riddle LogicalResult 547563879b6SRahul Joshi matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands, 548e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 549e83b7b99Saartbik auto kind = reductionOp.kind(); 550e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 551dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(eltType); 552e9628955SAart Bik if (eltType.isIntOrIndex()) { 553e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 554e83b7b99Saartbik if (kind == "add") 555322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 556563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 557e83b7b99Saartbik else if (kind == "mul") 558322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 559563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 560e9628955SAart Bik else if (kind == "min" && 561e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 562322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 563563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 564e83b7b99Saartbik else if (kind == "min") 565322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 566563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 567e9628955SAart Bik else if (kind == "max" && 568e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 569322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 570563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 571e83b7b99Saartbik else if (kind == "max") 572322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 573563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 574e83b7b99Saartbik else if (kind == "and") 575322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 576563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 577e83b7b99Saartbik else if (kind == "or") 578322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 579563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 580e83b7b99Saartbik else if (kind == "xor") 581322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 582563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 583e83b7b99Saartbik else 5843145427dSRiver Riddle return failure(); 5853145427dSRiver Riddle return success(); 586dcec2ca5SChristian Sigg } 587e83b7b99Saartbik 588dcec2ca5SChristian Sigg if (!eltType.isa<FloatType>()) 589dcec2ca5SChristian Sigg return failure(); 590dcec2ca5SChristian Sigg 591e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 592e83b7b99Saartbik if (kind == "add") { 5930d924700Saartbik // Optional accumulator (or zero). 5940d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5950d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 596563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5970d924700Saartbik rewriter.getZeroAttr(eltType)); 598322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 599563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 600ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 601e83b7b99Saartbik } else if (kind == "mul") { 6020d924700Saartbik // Optional accumulator (or one). 6030d924700Saartbik Value acc = operands.size() > 1 6040d924700Saartbik ? operands[1] 6050d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 606563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 6070d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 608322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 609563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 610ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 611e83b7b99Saartbik } else if (kind == "min") 612563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 613563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 614e83b7b99Saartbik else if (kind == "max") 615563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 616563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 617e83b7b99Saartbik else 6183145427dSRiver Riddle return failure(); 6193145427dSRiver Riddle return success(); 620e83b7b99Saartbik } 621ceb1b327Saartbik 622ceb1b327Saartbik private: 623ceb1b327Saartbik const bool reassociateFPReductions; 624e83b7b99Saartbik }; 625e83b7b99Saartbik 626060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 627563879b6SRahul Joshi class VectorCreateMaskOpConversion 628563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 629060c9dd1Saartbik public: 630563879b6SRahul Joshi explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, 631060c9dd1Saartbik bool enableIndexOpt) 632563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), 633060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 634060c9dd1Saartbik 635060c9dd1Saartbik LogicalResult 636563879b6SRahul Joshi matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, 637060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 6389eb3e564SChris Lattner auto dstType = op.getType(); 639060c9dd1Saartbik int64_t rank = dstType.getRank(); 640060c9dd1Saartbik if (rank == 1) { 641060c9dd1Saartbik rewriter.replaceOp( 642060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 643060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 644060c9dd1Saartbik return success(); 645060c9dd1Saartbik } 646060c9dd1Saartbik return failure(); 647060c9dd1Saartbik } 648060c9dd1Saartbik 649060c9dd1Saartbik private: 650060c9dd1Saartbik const bool enableIndexOptimizations; 651060c9dd1Saartbik }; 652060c9dd1Saartbik 653563879b6SRahul Joshi class VectorShuffleOpConversion 654563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 6551c81adf3SAart Bik public: 656563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 6571c81adf3SAart Bik 6583145427dSRiver Riddle LogicalResult 659563879b6SRahul Joshi matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands, 6601c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 661563879b6SRahul Joshi auto loc = shuffleOp->getLoc(); 6622d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6631c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6641c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6651c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 666dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(vectorType); 6671c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6681c81adf3SAart Bik 6691c81adf3SAart Bik // Bail if result type cannot be lowered. 6701c81adf3SAart Bik if (!llvmType) 6713145427dSRiver Riddle return failure(); 6721c81adf3SAart Bik 6731c81adf3SAart Bik // Get rank and dimension sizes. 6741c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6751c81adf3SAart Bik assert(v1Type.getRank() == rank); 6761c81adf3SAart Bik assert(v2Type.getRank() == rank); 6771c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6781c81adf3SAart Bik 6791c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6801c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6811c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 682563879b6SRahul Joshi Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 6831c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 684563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, llvmShuffleOp); 6853145427dSRiver Riddle return success(); 686b36aaeafSAart Bik } 687b36aaeafSAart Bik 6881c81adf3SAart Bik // For all other cases, insert the individual values individually. 689e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6901c81adf3SAart Bik int64_t insPos = 0; 6911c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6921c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 693e62a6956SRiver Riddle Value value = adaptor.v1(); 6941c81adf3SAart Bik if (extPos >= v1Dim) { 6951c81adf3SAart Bik extPos -= v1Dim; 6961c81adf3SAart Bik value = adaptor.v2(); 697b36aaeafSAart Bik } 698dcec2ca5SChristian Sigg Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 699dcec2ca5SChristian Sigg llvmType, rank, extPos); 700dcec2ca5SChristian Sigg insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 7010f04384dSAlex Zinenko llvmType, rank, insPos++); 7021c81adf3SAart Bik } 703563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, insert); 7043145427dSRiver Riddle return success(); 705b36aaeafSAart Bik } 706b36aaeafSAart Bik }; 707b36aaeafSAart Bik 708563879b6SRahul Joshi class VectorExtractElementOpConversion 709563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 710cd5dab8aSAart Bik public: 711563879b6SRahul Joshi using ConvertOpToLLVMPattern< 712563879b6SRahul Joshi vector::ExtractElementOp>::ConvertOpToLLVMPattern; 713cd5dab8aSAart Bik 7143145427dSRiver Riddle LogicalResult 715563879b6SRahul Joshi matchAndRewrite(vector::ExtractElementOp extractEltOp, 716563879b6SRahul Joshi ArrayRef<Value> operands, 717cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7182d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 719cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 720dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType.getElementType()); 721cd5dab8aSAart Bik 722cd5dab8aSAart Bik // Bail if result type cannot be lowered. 723cd5dab8aSAart Bik if (!llvmType) 7243145427dSRiver Riddle return failure(); 725cd5dab8aSAart Bik 726cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 727563879b6SRahul Joshi extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 7283145427dSRiver Riddle return success(); 729cd5dab8aSAart Bik } 730cd5dab8aSAart Bik }; 731cd5dab8aSAart Bik 732563879b6SRahul Joshi class VectorExtractOpConversion 733563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractOp> { 7345c0c51a9SNicolas Vasilache public: 735563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 7365c0c51a9SNicolas Vasilache 7373145427dSRiver Riddle LogicalResult 738563879b6SRahul Joshi matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 7395c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 740563879b6SRahul Joshi auto loc = extractOp->getLoc(); 7412d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 7429826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7432bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 744dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(resultType); 7455c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7469826fe5cSAart Bik 7479826fe5cSAart Bik // Bail if result type cannot be lowered. 7489826fe5cSAart Bik if (!llvmResultType) 7493145427dSRiver Riddle return failure(); 7509826fe5cSAart Bik 7515c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7525c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 753e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7545c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 755563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7563145427dSRiver Riddle return success(); 7575c0c51a9SNicolas Vasilache } 7585c0c51a9SNicolas Vasilache 7599826fe5cSAart Bik // Potential extraction of 1-D vector from array. 760563879b6SRahul Joshi auto *context = extractOp->getContext(); 761e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7625c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7635c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7649826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7655c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 766c2c83e97STres Popp ArrayAttr::get(context, positionAttrs.drop_back()); 7675c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 768dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 7695c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7705c0c51a9SNicolas Vasilache } 7715c0c51a9SNicolas Vasilache 7725c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7735c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7742230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 7751d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7765c0c51a9SNicolas Vasilache extracted = 7775c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 778563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7795c0c51a9SNicolas Vasilache 7803145427dSRiver Riddle return success(); 7815c0c51a9SNicolas Vasilache } 7825c0c51a9SNicolas Vasilache }; 7835c0c51a9SNicolas Vasilache 784681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 785681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 786681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 787681f929fSNicolas Vasilache /// 788681f929fSNicolas Vasilache /// Example: 789681f929fSNicolas Vasilache /// ``` 790681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 791681f929fSNicolas Vasilache /// ``` 792681f929fSNicolas Vasilache /// is converted to: 793681f929fSNicolas Vasilache /// ``` 7943bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 795dd5165a9SAlex Zinenko /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 796dd5165a9SAlex Zinenko /// -> !llvm."<8 x f32>"> 797681f929fSNicolas Vasilache /// ``` 798563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 799681f929fSNicolas Vasilache public: 800563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 801681f929fSNicolas Vasilache 8023145427dSRiver Riddle LogicalResult 803563879b6SRahul Joshi matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 804681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 8052d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 806681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 807681f929fSNicolas Vasilache if (vType.getRank() != 1) 8083145427dSRiver Riddle return failure(); 809563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 8103bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 8113145427dSRiver Riddle return success(); 812681f929fSNicolas Vasilache } 813681f929fSNicolas Vasilache }; 814681f929fSNicolas Vasilache 815563879b6SRahul Joshi class VectorInsertElementOpConversion 816563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 817cd5dab8aSAart Bik public: 818563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 819cd5dab8aSAart Bik 8203145427dSRiver Riddle LogicalResult 821563879b6SRahul Joshi matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands, 822cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 8232d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 824cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 825dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType); 826cd5dab8aSAart Bik 827cd5dab8aSAart Bik // Bail if result type cannot be lowered. 828cd5dab8aSAart Bik if (!llvmType) 8293145427dSRiver Riddle return failure(); 830cd5dab8aSAart Bik 831cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 832563879b6SRahul Joshi insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 833563879b6SRahul Joshi adaptor.position()); 8343145427dSRiver Riddle return success(); 835cd5dab8aSAart Bik } 836cd5dab8aSAart Bik }; 837cd5dab8aSAart Bik 838563879b6SRahul Joshi class VectorInsertOpConversion 839563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertOp> { 8409826fe5cSAart Bik public: 841563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 8429826fe5cSAart Bik 8433145427dSRiver Riddle LogicalResult 844563879b6SRahul Joshi matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 8459826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 846563879b6SRahul Joshi auto loc = insertOp->getLoc(); 8472d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8489826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8499826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 850dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(destVectorType); 8519826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8529826fe5cSAart Bik 8539826fe5cSAart Bik // Bail if result type cannot be lowered. 8549826fe5cSAart Bik if (!llvmResultType) 8553145427dSRiver Riddle return failure(); 8569826fe5cSAart Bik 8579826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8589826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 859e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8609826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8619826fe5cSAart Bik positionArrayAttr); 862563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8633145427dSRiver Riddle return success(); 8649826fe5cSAart Bik } 8659826fe5cSAart Bik 8669826fe5cSAart Bik // Potential extraction of 1-D vector from array. 867563879b6SRahul Joshi auto *context = insertOp->getContext(); 868e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8699826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8709826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8719826fe5cSAart Bik auto oneDVectorType = destVectorType; 8729826fe5cSAart Bik if (positionAttrs.size() > 1) { 8739826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8749826fe5cSAart Bik auto nMinusOnePositionAttrs = 875c2c83e97STres Popp ArrayAttr::get(context, positionAttrs.drop_back()); 8769826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 877dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8789826fe5cSAart Bik nMinusOnePositionAttrs); 8799826fe5cSAart Bik } 8809826fe5cSAart Bik 8819826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 8822230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 8831d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 884e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 885dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8860f04384dSAlex Zinenko adaptor.source(), constant); 8879826fe5cSAart Bik 8889826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 8899826fe5cSAart Bik if (positionAttrs.size() > 1) { 8909826fe5cSAart Bik auto nMinusOnePositionAttrs = 891c2c83e97STres Popp ArrayAttr::get(context, positionAttrs.drop_back()); 8929826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 8939826fe5cSAart Bik adaptor.dest(), inserted, 8949826fe5cSAart Bik nMinusOnePositionAttrs); 8959826fe5cSAart Bik } 8969826fe5cSAart Bik 897563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8983145427dSRiver Riddle return success(); 8999826fe5cSAart Bik } 9009826fe5cSAart Bik }; 9019826fe5cSAart Bik 902681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 903681f929fSNicolas Vasilache /// 904681f929fSNicolas Vasilache /// Example: 905681f929fSNicolas Vasilache /// ``` 906681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 907681f929fSNicolas Vasilache /// ``` 908681f929fSNicolas Vasilache /// is rewritten into: 909681f929fSNicolas Vasilache /// ``` 910681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 911681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 912681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 913681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 914681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 915681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 916681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 917681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 918681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 919681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 920681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 921681f929fSNicolas Vasilache /// // %r3 holds the final value. 922681f929fSNicolas Vasilache /// ``` 923681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 924681f929fSNicolas Vasilache public: 925681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 926681f929fSNicolas Vasilache 9273145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 928681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 929681f929fSNicolas Vasilache auto vType = op.getVectorType(); 930681f929fSNicolas Vasilache if (vType.getRank() < 2) 9313145427dSRiver Riddle return failure(); 932681f929fSNicolas Vasilache 933681f929fSNicolas Vasilache auto loc = op.getLoc(); 934681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 935681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 936681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 937681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 938681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 939681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 940681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 941681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 942681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 943681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 944681f929fSNicolas Vasilache } 945681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9463145427dSRiver Riddle return success(); 947681f929fSNicolas Vasilache } 948681f929fSNicolas Vasilache }; 949681f929fSNicolas Vasilache 9502d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9512d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9522d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9532d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9542d515e49SNicolas Vasilache // rank. 9552d515e49SNicolas Vasilache // 9562d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9572d515e49SNicolas Vasilache // have different ranks. In this case: 9582d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9592d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9602d515e49SNicolas Vasilache // destination subvector 9612d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9622d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9632d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9642d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9652d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9662d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9672d515e49SNicolas Vasilache public: 9682d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9692d515e49SNicolas Vasilache 9703145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9712d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9722d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9732d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9742d515e49SNicolas Vasilache 9752d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9763145427dSRiver Riddle return failure(); 9772d515e49SNicolas Vasilache 9782d515e49SNicolas Vasilache auto loc = op.getLoc(); 9792d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9802d515e49SNicolas Vasilache assert(rankDiff >= 0); 9812d515e49SNicolas Vasilache if (rankDiff == 0) 9823145427dSRiver Riddle return failure(); 9832d515e49SNicolas Vasilache 9842d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9852d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 9862d515e49SNicolas Vasilache // on it. 9872d515e49SNicolas Vasilache Value extracted = 9882d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 9892d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 990dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9912d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 9922d515e49SNicolas Vasilache // ranks. 9932d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 9942d515e49SNicolas Vasilache loc, op.source(), extracted, 9952d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 996c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 9972d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 9982d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 9992d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 1000dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 10013145427dSRiver Riddle return success(); 10022d515e49SNicolas Vasilache } 10032d515e49SNicolas Vasilache }; 10042d515e49SNicolas Vasilache 10052d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 10062d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 10072d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 10082d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 10092d515e49SNicolas Vasilache // destination subvector 10102d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 10112d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 10122d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 10132d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 10142d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 10152d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 10162d515e49SNicolas Vasilache public: 1017b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 1018b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 1019b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 1020b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 1021b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1022b99bd771SRiver Riddle } 10232d515e49SNicolas Vasilache 10243145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 10252d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10262d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10272d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10282d515e49SNicolas Vasilache 10292d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10303145427dSRiver Riddle return failure(); 10312d515e49SNicolas Vasilache 10322d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10332d515e49SNicolas Vasilache assert(rankDiff >= 0); 10342d515e49SNicolas Vasilache if (rankDiff != 0) 10353145427dSRiver Riddle return failure(); 10362d515e49SNicolas Vasilache 10372d515e49SNicolas Vasilache if (srcType == dstType) { 10382d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10393145427dSRiver Riddle return success(); 10402d515e49SNicolas Vasilache } 10412d515e49SNicolas Vasilache 10422d515e49SNicolas Vasilache int64_t offset = 10432d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10442d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10452d515e49SNicolas Vasilache int64_t stride = 10462d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10472d515e49SNicolas Vasilache 10482d515e49SNicolas Vasilache auto loc = op.getLoc(); 10492d515e49SNicolas Vasilache Value res = op.dest(); 10502d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10512d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10522d515e49SNicolas Vasilache off += stride, ++idx) { 10532d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10542d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10552d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10562d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10572d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10582d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10592d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10602d515e49SNicolas Vasilache // smaller rank. 1061bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10622d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10632d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10642d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10652d515e49SNicolas Vasilache } 10662d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10672d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10682d515e49SNicolas Vasilache } 10692d515e49SNicolas Vasilache 10702d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10713145427dSRiver Riddle return success(); 10722d515e49SNicolas Vasilache } 10732d515e49SNicolas Vasilache }; 10742d515e49SNicolas Vasilache 107530e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 107630e6033bSNicolas Vasilache /// static layout. 107730e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 107830e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10792bf491c7SBenjamin Kramer int64_t offset; 108030e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 108130e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 108230e6033bSNicolas Vasilache return None; 108330e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 108430e6033bSNicolas Vasilache return None; 108530e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 108630e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 108730e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 108830e6033bSNicolas Vasilache return strides; 108930e6033bSNicolas Vasilache 109030e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 109130e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 109230e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 109330e6033bSNicolas Vasilache // layout. 10942bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 10952bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 109630e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 109730e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 109830e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 109930e6033bSNicolas Vasilache return None; 110030e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 110130e6033bSNicolas Vasilache return None; 11022bf491c7SBenjamin Kramer } 110330e6033bSNicolas Vasilache return strides; 11042bf491c7SBenjamin Kramer } 11052bf491c7SBenjamin Kramer 1106563879b6SRahul Joshi class VectorTypeCastOpConversion 1107563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 11085c0c51a9SNicolas Vasilache public: 1109563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 11105c0c51a9SNicolas Vasilache 11113145427dSRiver Riddle LogicalResult 1112563879b6SRahul Joshi matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, 11135c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 1114563879b6SRahul Joshi auto loc = castOp->getLoc(); 11155c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 11162bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 11179eb3e564SChris Lattner MemRefType targetMemRefType = castOp.getType(); 11185c0c51a9SNicolas Vasilache 11195c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 11205c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 11215c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 11223145427dSRiver Riddle return failure(); 11235c0c51a9SNicolas Vasilache 11245c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 11258de43b92SAlex Zinenko operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); 11268de43b92SAlex Zinenko if (!llvmSourceDescriptorTy) 11273145427dSRiver Riddle return failure(); 11285c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11295c0c51a9SNicolas Vasilache 1130dcec2ca5SChristian Sigg auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 11318de43b92SAlex Zinenko .dyn_cast_or_null<LLVM::LLVMStructType>(); 11328de43b92SAlex Zinenko if (!llvmTargetDescriptorTy) 11333145427dSRiver Riddle return failure(); 11345c0c51a9SNicolas Vasilache 113530e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 113630e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 113730e6033bSNicolas Vasilache if (!sourceStrides) 113830e6033bSNicolas Vasilache return failure(); 113930e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 114030e6033bSNicolas Vasilache if (!targetStrides) 114130e6033bSNicolas Vasilache return failure(); 114230e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 114330e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 114430e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 114530e6033bSNicolas Vasilache })) 11463145427dSRiver Riddle return failure(); 11475c0c51a9SNicolas Vasilache 11482230bf99SAlex Zinenko auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 11495c0c51a9SNicolas Vasilache 11505c0c51a9SNicolas Vasilache // Create descriptor. 11515c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11523a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11535c0c51a9SNicolas Vasilache // Set allocated ptr. 1154e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11555c0c51a9SNicolas Vasilache allocated = 11565c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11575c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11585c0c51a9SNicolas Vasilache // Set aligned ptr. 1159e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11605c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11615c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11625c0c51a9SNicolas Vasilache // Fill offset 0. 11635c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11645c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11655c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11665c0c51a9SNicolas Vasilache 11675c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11685c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11695c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11705c0c51a9SNicolas Vasilache auto sizeAttr = 11715c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11725c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11735c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 117430e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 117530e6033bSNicolas Vasilache (*targetStrides)[index]); 11765c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11775c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11785c0c51a9SNicolas Vasilache } 11795c0c51a9SNicolas Vasilache 1180563879b6SRahul Joshi rewriter.replaceOp(castOp, {desc}); 11813145427dSRiver Riddle return success(); 11825c0c51a9SNicolas Vasilache } 11835c0c51a9SNicolas Vasilache }; 11845c0c51a9SNicolas Vasilache 11858345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 11868345b86dSNicolas Vasilache /// sequence of: 1187060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1188060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1189060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1190060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1191060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 11928345b86dSNicolas Vasilache template <typename ConcreteOp> 1193563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { 11948345b86dSNicolas Vasilache public: 1195563879b6SRahul Joshi explicit VectorTransferConversion(LLVMTypeConverter &typeConv, 1196060c9dd1Saartbik bool enableIndexOpt) 1197563879b6SRahul Joshi : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), 1198060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 11998345b86dSNicolas Vasilache 12008345b86dSNicolas Vasilache LogicalResult 1201563879b6SRahul Joshi matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, 12028345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 12038345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1204b2c79c50SNicolas Vasilache 1205b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1206b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 12078345b86dSNicolas Vasilache return failure(); 12085f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 12095f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 12105f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 1211563879b6SRahul Joshi xferOp->getContext())) 12128345b86dSNicolas Vasilache return failure(); 121326c8f908SThomas Raoux auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 121426c8f908SThomas Raoux if (!memRefType) 121526c8f908SThomas Raoux return failure(); 12162bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 121726c8f908SThomas Raoux auto strides = computeContiguousStrides(memRefType); 121830e6033bSNicolas Vasilache if (!strides) 12192bf491c7SBenjamin Kramer return failure(); 12208345b86dSNicolas Vasilache 1221563879b6SRahul Joshi auto toLLVMTy = [&](Type t) { 1222563879b6SRahul Joshi return this->getTypeConverter()->convertType(t); 1223563879b6SRahul Joshi }; 12248345b86dSNicolas Vasilache 1225563879b6SRahul Joshi Location loc = xferOp->getLoc(); 12268345b86dSNicolas Vasilache 122768330ee0SThomas Raoux if (auto memrefVectorElementType = 122826c8f908SThomas Raoux memRefType.getElementType().template dyn_cast<VectorType>()) { 122968330ee0SThomas Raoux // Memref has vector element type. 123068330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 123168330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 123268330ee0SThomas Raoux return failure(); 12330de60b55SThomas Raoux #ifndef NDEBUG 123468330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 123568330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 123668330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 123768330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 123868330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 123968330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 124068330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 124168330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 124268330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 124368330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 124468330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 124568330ee0SThomas Raoux "result shape."); 12460de60b55SThomas Raoux #endif // ifndef NDEBUG 124768330ee0SThomas Raoux } 124868330ee0SThomas Raoux 12498345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1250a57def30SAart Bik VectorType vtp = xferOp.getVectorType(); 1251563879b6SRahul Joshi Value dataPtr = this->getStridedElementPtr( 125226c8f908SThomas Raoux loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); 1253a57def30SAart Bik Value vectorDataPtr = 1254a57def30SAart Bik castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); 12558345b86dSNicolas Vasilache 12561870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 1257563879b6SRahul Joshi return replaceTransferOpWithLoadOrStore(rewriter, 1258563879b6SRahul Joshi *this->getTypeConverter(), loc, 1259563879b6SRahul Joshi xferOp, operands, vectorDataPtr); 12601870e787SNicolas Vasilache 12618345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12628345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12638345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12648345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1265060c9dd1Saartbik // 1266060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1267060c9dd1Saartbik // dimensions here. 1268bd30a796SAlex Zinenko unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue(); 1269060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12700c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 127126c8f908SThomas Raoux Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex); 1272563879b6SRahul Joshi Value mask = buildVectorComparison( 1273563879b6SRahul Joshi rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); 12748345b86dSNicolas Vasilache 12758345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 1276563879b6SRahul Joshi return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, 1277dcec2ca5SChristian Sigg xferOp, operands, vectorDataPtr, mask); 12788345b86dSNicolas Vasilache } 1279060c9dd1Saartbik 1280060c9dd1Saartbik private: 1281060c9dd1Saartbik const bool enableIndexOptimizations; 12828345b86dSNicolas Vasilache }; 12838345b86dSNicolas Vasilache 1284563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1285d9b500d3SAart Bik public: 1286563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1287d9b500d3SAart Bik 1288d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1289d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1290d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1291d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1292d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1293d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1294d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1295d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1296d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1297d9b500d3SAart Bik // 12989db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1299d9b500d3SAart Bik // 13003145427dSRiver Riddle LogicalResult 1301563879b6SRahul Joshi matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands, 1302d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 13032d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1304d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1305d9b500d3SAart Bik 1306dcec2ca5SChristian Sigg if (typeConverter->convertType(printType) == nullptr) 13073145427dSRiver Riddle return failure(); 1308d9b500d3SAart Bik 1309b8880f5fSAart Bik // Make sure element type has runtime support. 1310b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1311d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1312d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1313d9b500d3SAart Bik Operation *printer; 1314b8880f5fSAart Bik if (eltType.isF32()) { 1315*e332c22cSNicolas Vasilache printer = 1316*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>()); 1317b8880f5fSAart Bik } else if (eltType.isF64()) { 1318*e332c22cSNicolas Vasilache printer = 1319*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>()); 132054759cefSAart Bik } else if (eltType.isIndex()) { 1321*e332c22cSNicolas Vasilache printer = 1322*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>()); 1323b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1324b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1325b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1326b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1327b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1328b8880f5fSAart Bik if (intTy.isUnsigned()) { 132954759cefSAart Bik if (width <= 64) { 1330b8880f5fSAart Bik if (width < 64) 1331b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1332*e332c22cSNicolas Vasilache printer = LLVM::lookupOrCreatePrintU64Fn( 1333*e332c22cSNicolas Vasilache printOp->getParentOfType<ModuleOp>()); 1334b8880f5fSAart Bik } else { 13353145427dSRiver Riddle return failure(); 1336b8880f5fSAart Bik } 1337b8880f5fSAart Bik } else { 1338b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 133954759cefSAart Bik if (width <= 64) { 1340b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1341b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1342b8880f5fSAart Bik if (width == 1) 134354759cefSAart Bik conversion = PrintConversion::ZeroExt64; 134454759cefSAart Bik else if (width < 64) 1345b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1346*e332c22cSNicolas Vasilache printer = LLVM::lookupOrCreatePrintI64Fn( 1347*e332c22cSNicolas Vasilache printOp->getParentOfType<ModuleOp>()); 1348b8880f5fSAart Bik } else { 1349b8880f5fSAart Bik return failure(); 1350b8880f5fSAart Bik } 1351b8880f5fSAart Bik } 1352b8880f5fSAart Bik } else { 1353b8880f5fSAart Bik return failure(); 1354b8880f5fSAart Bik } 1355d9b500d3SAart Bik 1356d9b500d3SAart Bik // Unroll vector into elementary print calls. 1357b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1358563879b6SRahul Joshi emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, 1359b8880f5fSAart Bik conversion); 1360*e332c22cSNicolas Vasilache emitCall(rewriter, printOp->getLoc(), 1361*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintNewlineFn( 1362*e332c22cSNicolas Vasilache printOp->getParentOfType<ModuleOp>())); 1363563879b6SRahul Joshi rewriter.eraseOp(printOp); 13643145427dSRiver Riddle return success(); 1365d9b500d3SAart Bik } 1366d9b500d3SAart Bik 1367d9b500d3SAart Bik private: 1368b8880f5fSAart Bik enum class PrintConversion { 136930e6033bSNicolas Vasilache // clang-format off 1370b8880f5fSAart Bik None, 1371b8880f5fSAart Bik ZeroExt64, 1372b8880f5fSAart Bik SignExt64 137330e6033bSNicolas Vasilache // clang-format on 1374b8880f5fSAart Bik }; 1375b8880f5fSAart Bik 1376d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1377e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1378b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1379d9b500d3SAart Bik Location loc = op->getLoc(); 1380d9b500d3SAart Bik if (rank == 0) { 1381b8880f5fSAart Bik switch (conversion) { 1382b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1383b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 13842230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1385b8880f5fSAart Bik break; 1386b8880f5fSAart Bik case PrintConversion::SignExt64: 1387b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 13882230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1389b8880f5fSAart Bik break; 1390b8880f5fSAart Bik case PrintConversion::None: 1391b8880f5fSAart Bik break; 1392c9eeeb38Saartbik } 1393d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1394d9b500d3SAart Bik return; 1395d9b500d3SAart Bik } 1396d9b500d3SAart Bik 1397*e332c22cSNicolas Vasilache emitCall(rewriter, loc, 1398*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>())); 1399*e332c22cSNicolas Vasilache Operation *printComma = 1400*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>()); 1401d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1402d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1403d9b500d3SAart Bik auto reducedType = 1404d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1405dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType( 1406d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1407dcec2ca5SChristian Sigg Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1408dcec2ca5SChristian Sigg llvmType, rank, d); 1409b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1410b8880f5fSAart Bik conversion); 1411d9b500d3SAart Bik if (d != dim - 1) 1412d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1413d9b500d3SAart Bik } 1414*e332c22cSNicolas Vasilache emitCall(rewriter, loc, 1415*e332c22cSNicolas Vasilache LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>())); 1416d9b500d3SAart Bik } 1417d9b500d3SAart Bik 1418d9b500d3SAart Bik // Helper to emit a call. 1419d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1420d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 142108e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1422d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1423d9b500d3SAart Bik } 1424d9b500d3SAart Bik }; 1425d9b500d3SAart Bik 1426334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1427c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1428c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1429c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1430334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 143165678d93SNicolas Vasilache public: 1432b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1433b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1434b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1435b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1436b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1437b99bd771SRiver Riddle } 143865678d93SNicolas Vasilache 1439334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 144065678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 14419eb3e564SChris Lattner auto dstType = op.getType(); 144265678d93SNicolas Vasilache 144365678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 144465678d93SNicolas Vasilache 144565678d93SNicolas Vasilache int64_t offset = 144665678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 144765678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 144865678d93SNicolas Vasilache int64_t stride = 144965678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 145065678d93SNicolas Vasilache 145165678d93SNicolas Vasilache auto loc = op.getLoc(); 145265678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 145335b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1454c3c95b9cSaartbik 1455c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1456c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1457c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1458c3c95b9cSaartbik offsets.reserve(size); 1459c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1460c3c95b9cSaartbik off += stride) 1461c3c95b9cSaartbik offsets.push_back(off); 1462c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1463c3c95b9cSaartbik op.vector(), 1464c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1465c3c95b9cSaartbik return success(); 1466c3c95b9cSaartbik } 1467c3c95b9cSaartbik 1468c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 146965678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 147065678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 147165678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 147265678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 147365678d93SNicolas Vasilache off += stride, ++idx) { 1474c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1475c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1476c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 147765678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 147865678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 147965678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 148065678d93SNicolas Vasilache } 1481c3c95b9cSaartbik rewriter.replaceOp(op, res); 14823145427dSRiver Riddle return success(); 148365678d93SNicolas Vasilache } 148465678d93SNicolas Vasilache }; 148565678d93SNicolas Vasilache 1486df186507SBenjamin Kramer } // namespace 1487df186507SBenjamin Kramer 14885c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 14895c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1490ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1491060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 149265678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 14938345b86dSNicolas Vasilache // clang-format off 1494681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1495681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 14962d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1497c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1498ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1499563879b6SRahul Joshi converter, reassociateFPReductions); 1500060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1501060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1502060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1503563879b6SRahul Joshi converter, enableIndexOptimizations); 15048345b86dSNicolas Vasilache patterns 1505cf5c517cSDiego Caballero .insert<VectorBitCastOpConversion, 1506cf5c517cSDiego Caballero VectorShuffleOpConversion, 15078345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15088345b86dSNicolas Vasilache VectorExtractOpConversion, 15098345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15108345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15118345b86dSNicolas Vasilache VectorInsertOpConversion, 15128345b86dSNicolas Vasilache VectorPrintOpConversion, 151319dbb230Saartbik VectorTypeCastOpConversion, 151439379916Saartbik VectorMaskedLoadOpConversion, 151539379916Saartbik VectorMaskedStoreOpConversion, 151619dbb230Saartbik VectorGatherOpConversion, 1517e8dcf5f8Saartbik VectorScatterOpConversion, 1518e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1519563879b6SRahul Joshi VectorCompressStoreOpConversion>(converter); 15208345b86dSNicolas Vasilache // clang-format on 15215c0c51a9SNicolas Vasilache } 15225c0c51a9SNicolas Vasilache 152363b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 152463b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1525563879b6SRahul Joshi patterns.insert<VectorMatmulOpConversion>(converter); 1526563879b6SRahul Joshi patterns.insert<VectorFlatTransposeOpConversion>(converter); 152763b683a8SNicolas Vasilache } 1528