15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 115c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 135c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1469d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 154d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 1609f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 17ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 185c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 195c0c51a9SNicolas Vasilache 205c0c51a9SNicolas Vasilache using namespace mlir; 2165678d93SNicolas Vasilache using namespace mlir::vector; 225c0c51a9SNicolas Vasilache 239826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 249826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 259826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 269826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 279826fe5cSAart Bik } 289826fe5cSAart Bik 299826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 309826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 319826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 329826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 339826fe5cSAart Bik } 349826fe5cSAart Bik 351c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 36e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 370f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 380f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 390f04384dSAlex Zinenko int64_t pos) { 401c81adf3SAart Bik if (rank == 1) { 411c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 421c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 430f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 441c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 451c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 461c81adf3SAart Bik constant); 471c81adf3SAart Bik } 481c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 491c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 501c81adf3SAart Bik } 511c81adf3SAart Bik 522d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 532d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 542d515e49SNicolas Vasilache Value into, int64_t offset) { 552d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 562d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 572d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 582d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 592d515e49SNicolas Vasilache loc, vectorType, from, into, 602d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 612d515e49SNicolas Vasilache } 622d515e49SNicolas Vasilache 631c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 64e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 650f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 660f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 671c81adf3SAart Bik if (rank == 1) { 681c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 691c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 700f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 711c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 721c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 731c81adf3SAart Bik constant); 741c81adf3SAart Bik } 751c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 761c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 771c81adf3SAart Bik } 781c81adf3SAart Bik 792d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 802d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 812d515e49SNicolas Vasilache int64_t offset) { 822d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 832d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 842d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 852d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 862d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 872d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 882d515e49SNicolas Vasilache } 892d515e49SNicolas Vasilache 902d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 919db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 922d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 932d515e49SNicolas Vasilache unsigned dropFront = 0, 942d515e49SNicolas Vasilache unsigned dropBack = 0) { 952d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 962d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 972d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 982d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 992d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1002d515e49SNicolas Vasilache it != eit; ++it) 1012d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1022d515e49SNicolas Vasilache return res; 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 105060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 106060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 107060c9dd1Saartbik // 108060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 109060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 110060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 111060c9dd1Saartbik // is very conservative on the boundary conditions. 112060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 113060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 114060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 115060c9dd1Saartbik auto loc = op->getLoc(); 116060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 117060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 118060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 119060c9dd1Saartbik Value indices; 120060c9dd1Saartbik Type idxType; 121060c9dd1Saartbik if (enableIndexOptimizations) { 1220c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1230c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1240c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 125060c9dd1Saartbik idxType = rewriter.getI32Type(); 126060c9dd1Saartbik } else { 1270c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1280c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1290c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 130060c9dd1Saartbik idxType = rewriter.getI64Type(); 131060c9dd1Saartbik } 132060c9dd1Saartbik // Add in an offset if requested. 133060c9dd1Saartbik if (off) { 134060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 135060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 136060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 137060c9dd1Saartbik } 138060c9dd1Saartbik // Construct the vector comparison. 139060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 140060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 141060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 142060c9dd1Saartbik } 143060c9dd1Saartbik 14426c8f908SThomas Raoux // Helper that returns data layout alignment of a memref. 14526c8f908SThomas Raoux LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, 14626c8f908SThomas Raoux MemRefType memrefType, unsigned &align) { 14726c8f908SThomas Raoux Type elementTy = typeConverter.convertType(memrefType.getElementType()); 1485f9e0466SNicolas Vasilache if (!elementTy) 1495f9e0466SNicolas Vasilache return failure(); 1505f9e0466SNicolas Vasilache 151b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 152b2ab375dSAlex Zinenko // stop depending on translation. 15387a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 15487a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 155c69c9e0fSAlex Zinenko .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); 1565f9e0466SNicolas Vasilache return success(); 1575f9e0466SNicolas Vasilache } 1585f9e0466SNicolas Vasilache 159e8dcf5f8Saartbik // Helper that returns the base address of a memref. 160b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 161e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 16219dbb230Saartbik // Inspect stride and offset structure. 16319dbb230Saartbik // 16419dbb230Saartbik // TODO: flat memory only for now, generalize 16519dbb230Saartbik // 16619dbb230Saartbik int64_t offset; 16719dbb230Saartbik SmallVector<int64_t, 4> strides; 16819dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 16919dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 17019dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 17119dbb230Saartbik return failure(); 172e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 173e8dcf5f8Saartbik return success(); 174e8dcf5f8Saartbik } 17519dbb230Saartbik 176a57def30SAart Bik // Helper that returns vector of pointers given a memref base with index vector. 177b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 178b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 179b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 180b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 181e8dcf5f8Saartbik Value base; 182e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 183e8dcf5f8Saartbik return failure(); 1843a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 185*bd30a796SAlex Zinenko auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); 1861485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 18719dbb230Saartbik return success(); 18819dbb230Saartbik } 18919dbb230Saartbik 190a57def30SAart Bik // Casts a strided element pointer to a vector pointer. The vector pointer 191a57def30SAart Bik // would always be on address space 0, therefore addrspacecast shall be 192a57def30SAart Bik // used when source/dst memrefs are not on address space 0. 193a57def30SAart Bik static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, 194a57def30SAart Bik Value ptr, MemRefType memRefType, Type vt) { 195*bd30a796SAlex Zinenko auto pType = LLVM::LLVMPointerType::get(vt); 196a57def30SAart Bik if (memRefType.getMemorySpace() == 0) 197a57def30SAart Bik return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr); 198a57def30SAart Bik return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr); 199a57def30SAart Bik } 200a57def30SAart Bik 2015f9e0466SNicolas Vasilache static LogicalResult 2025f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2035f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2045f9e0466SNicolas Vasilache TransferReadOp xferOp, 2055f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 206affbc0cdSNicolas Vasilache unsigned align; 20726c8f908SThomas Raoux if (failed(getMemRefAlignment( 20826c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 209affbc0cdSNicolas Vasilache return failure(); 210affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2115f9e0466SNicolas Vasilache return success(); 2125f9e0466SNicolas Vasilache } 2135f9e0466SNicolas Vasilache 2145f9e0466SNicolas Vasilache static LogicalResult 2155f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2165f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2175f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2185f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2195f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2205f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2215f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2225f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2235f9e0466SNicolas Vasilache 2245f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2255f9e0466SNicolas Vasilache if (!vecTy) 2265f9e0466SNicolas Vasilache return failure(); 2275f9e0466SNicolas Vasilache 2285f9e0466SNicolas Vasilache unsigned align; 22926c8f908SThomas Raoux if (failed(getMemRefAlignment( 23026c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2315f9e0466SNicolas Vasilache return failure(); 2325f9e0466SNicolas Vasilache 2335f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2345f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2355f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2365f9e0466SNicolas Vasilache return success(); 2375f9e0466SNicolas Vasilache } 2385f9e0466SNicolas Vasilache 2395f9e0466SNicolas Vasilache static LogicalResult 2405f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2415f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2425f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2435f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 244affbc0cdSNicolas Vasilache unsigned align; 24526c8f908SThomas Raoux if (failed(getMemRefAlignment( 24626c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 247affbc0cdSNicolas Vasilache return failure(); 2482d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 249affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 250affbc0cdSNicolas Vasilache align); 2515f9e0466SNicolas Vasilache return success(); 2525f9e0466SNicolas Vasilache } 2535f9e0466SNicolas Vasilache 2545f9e0466SNicolas Vasilache static LogicalResult 2555f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2565f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2575f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2585f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2595f9e0466SNicolas Vasilache unsigned align; 26026c8f908SThomas Raoux if (failed(getMemRefAlignment( 26126c8f908SThomas Raoux typeConverter, xferOp.getShapedType().cast<MemRefType>(), align))) 2625f9e0466SNicolas Vasilache return failure(); 2635f9e0466SNicolas Vasilache 2642d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2655f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2665f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2675f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2685f9e0466SNicolas Vasilache return success(); 2695f9e0466SNicolas Vasilache } 2705f9e0466SNicolas Vasilache 2712d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 2722d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2732d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 2745f9e0466SNicolas Vasilache } 2755f9e0466SNicolas Vasilache 2762d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 2772d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 2782d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 2795f9e0466SNicolas Vasilache } 2805f9e0466SNicolas Vasilache 28190c01357SBenjamin Kramer namespace { 282e83b7b99Saartbik 28363b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 28463b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 285563879b6SRahul Joshi class VectorMatmulOpConversion 286563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MatmulOp> { 28763b683a8SNicolas Vasilache public: 288563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; 28963b683a8SNicolas Vasilache 2903145427dSRiver Riddle LogicalResult 291563879b6SRahul Joshi matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands, 29263b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 2932d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 29463b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 295563879b6SRahul Joshi matmulOp, typeConverter->convertType(matmulOp.res().getType()), 296563879b6SRahul Joshi adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), 297563879b6SRahul Joshi matmulOp.lhs_columns(), matmulOp.rhs_columns()); 2983145427dSRiver Riddle return success(); 29963b683a8SNicolas Vasilache } 30063b683a8SNicolas Vasilache }; 30163b683a8SNicolas Vasilache 302c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 303c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 304563879b6SRahul Joshi class VectorFlatTransposeOpConversion 305563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { 306c295a65dSaartbik public: 307563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; 308c295a65dSaartbik 309c295a65dSaartbik LogicalResult 310563879b6SRahul Joshi matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands, 311c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 3122d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 313c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 314dcec2ca5SChristian Sigg transOp, typeConverter->convertType(transOp.res().getType()), 315c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 316c295a65dSaartbik return success(); 317c295a65dSaartbik } 318c295a65dSaartbik }; 319c295a65dSaartbik 32039379916Saartbik /// Conversion pattern for a vector.maskedload. 321563879b6SRahul Joshi class VectorMaskedLoadOpConversion 322563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> { 32339379916Saartbik public: 324563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern; 32539379916Saartbik 32639379916Saartbik LogicalResult 327563879b6SRahul Joshi matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands, 32839379916Saartbik ConversionPatternRewriter &rewriter) const override { 329563879b6SRahul Joshi auto loc = load->getLoc(); 33039379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 331a57def30SAart Bik MemRefType memRefType = load.getMemRefType(); 33239379916Saartbik 33339379916Saartbik // Resolve alignment. 33439379916Saartbik unsigned align; 335a57def30SAart Bik if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 33639379916Saartbik return failure(); 33739379916Saartbik 338a57def30SAart Bik // Resolve address. 339dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(load.getResultVectorType()); 340a57def30SAart Bik Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 341a57def30SAart Bik adaptor.indices(), rewriter); 342a57def30SAart Bik Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); 34339379916Saartbik 34439379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 34539379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 34639379916Saartbik rewriter.getI32IntegerAttr(align)); 34739379916Saartbik return success(); 34839379916Saartbik } 34939379916Saartbik }; 35039379916Saartbik 35139379916Saartbik /// Conversion pattern for a vector.maskedstore. 352563879b6SRahul Joshi class VectorMaskedStoreOpConversion 353563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> { 35439379916Saartbik public: 355563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern; 35639379916Saartbik 35739379916Saartbik LogicalResult 358563879b6SRahul Joshi matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands, 35939379916Saartbik ConversionPatternRewriter &rewriter) const override { 360563879b6SRahul Joshi auto loc = store->getLoc(); 36139379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 362a57def30SAart Bik MemRefType memRefType = store.getMemRefType(); 36339379916Saartbik 36439379916Saartbik // Resolve alignment. 36539379916Saartbik unsigned align; 366a57def30SAart Bik if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) 36739379916Saartbik return failure(); 36839379916Saartbik 369a57def30SAart Bik // Resolve address. 370dcec2ca5SChristian Sigg auto vtype = typeConverter->convertType(store.getValueVectorType()); 371a57def30SAart Bik Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 372a57def30SAart Bik adaptor.indices(), rewriter); 373a57def30SAart Bik Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); 37439379916Saartbik 37539379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 37639379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 37739379916Saartbik rewriter.getI32IntegerAttr(align)); 37839379916Saartbik return success(); 37939379916Saartbik } 38039379916Saartbik }; 38139379916Saartbik 38219dbb230Saartbik /// Conversion pattern for a vector.gather. 383563879b6SRahul Joshi class VectorGatherOpConversion 384563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::GatherOp> { 38519dbb230Saartbik public: 386563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; 38719dbb230Saartbik 38819dbb230Saartbik LogicalResult 389563879b6SRahul Joshi matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands, 39019dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 391563879b6SRahul Joshi auto loc = gather->getLoc(); 39219dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 39319dbb230Saartbik 39419dbb230Saartbik // Resolve alignment. 39519dbb230Saartbik unsigned align; 39626c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), 39726c8f908SThomas Raoux align))) 39819dbb230Saartbik return failure(); 39919dbb230Saartbik 40019dbb230Saartbik // Get index ptrs. 40119dbb230Saartbik VectorType vType = gather.getResultVectorType(); 40219dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 40319dbb230Saartbik Value ptrs; 404e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 405e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 40619dbb230Saartbik return failure(); 40719dbb230Saartbik 40819dbb230Saartbik // Replace with the gather intrinsic. 40919dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 410dcec2ca5SChristian Sigg gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), 4110c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 41219dbb230Saartbik return success(); 41319dbb230Saartbik } 41419dbb230Saartbik }; 41519dbb230Saartbik 41619dbb230Saartbik /// Conversion pattern for a vector.scatter. 417563879b6SRahul Joshi class VectorScatterOpConversion 418563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ScatterOp> { 41919dbb230Saartbik public: 420563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; 42119dbb230Saartbik 42219dbb230Saartbik LogicalResult 423563879b6SRahul Joshi matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands, 42419dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 425563879b6SRahul Joshi auto loc = scatter->getLoc(); 42619dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 42719dbb230Saartbik 42819dbb230Saartbik // Resolve alignment. 42919dbb230Saartbik unsigned align; 43026c8f908SThomas Raoux if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), 43126c8f908SThomas Raoux align))) 43219dbb230Saartbik return failure(); 43319dbb230Saartbik 43419dbb230Saartbik // Get index ptrs. 43519dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 43619dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 43719dbb230Saartbik Value ptrs; 438e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 439e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 44019dbb230Saartbik return failure(); 44119dbb230Saartbik 44219dbb230Saartbik // Replace with the scatter intrinsic. 44319dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 44419dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 44519dbb230Saartbik rewriter.getI32IntegerAttr(align)); 44619dbb230Saartbik return success(); 44719dbb230Saartbik } 44819dbb230Saartbik }; 44919dbb230Saartbik 450e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 451563879b6SRahul Joshi class VectorExpandLoadOpConversion 452563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> { 453e8dcf5f8Saartbik public: 454563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; 455e8dcf5f8Saartbik 456e8dcf5f8Saartbik LogicalResult 457563879b6SRahul Joshi matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands, 458e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 459563879b6SRahul Joshi auto loc = expand->getLoc(); 460e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 461a57def30SAart Bik MemRefType memRefType = expand.getMemRefType(); 462e8dcf5f8Saartbik 463a57def30SAart Bik // Resolve address. 464a57def30SAart Bik auto vtype = typeConverter->convertType(expand.getResultVectorType()); 465a57def30SAart Bik Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 466a57def30SAart Bik adaptor.indices(), rewriter); 467e8dcf5f8Saartbik 468e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 469a57def30SAart Bik expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); 470e8dcf5f8Saartbik return success(); 471e8dcf5f8Saartbik } 472e8dcf5f8Saartbik }; 473e8dcf5f8Saartbik 474e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 475563879b6SRahul Joshi class VectorCompressStoreOpConversion 476563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CompressStoreOp> { 477e8dcf5f8Saartbik public: 478563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; 479e8dcf5f8Saartbik 480e8dcf5f8Saartbik LogicalResult 481563879b6SRahul Joshi matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands, 482e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 483563879b6SRahul Joshi auto loc = compress->getLoc(); 484e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 485a57def30SAart Bik MemRefType memRefType = compress.getMemRefType(); 486e8dcf5f8Saartbik 487a57def30SAart Bik // Resolve address. 488a57def30SAart Bik Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), 489a57def30SAart Bik adaptor.indices(), rewriter); 490e8dcf5f8Saartbik 491e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 492563879b6SRahul Joshi compress, adaptor.value(), ptr, adaptor.mask()); 493e8dcf5f8Saartbik return success(); 494e8dcf5f8Saartbik } 495e8dcf5f8Saartbik }; 496e8dcf5f8Saartbik 49719dbb230Saartbik /// Conversion pattern for all vector reductions. 498563879b6SRahul Joshi class VectorReductionOpConversion 499563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ReductionOp> { 500e83b7b99Saartbik public: 501563879b6SRahul Joshi explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, 502060c9dd1Saartbik bool reassociateFPRed) 503563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv), 504060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 505e83b7b99Saartbik 5063145427dSRiver Riddle LogicalResult 507563879b6SRahul Joshi matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands, 508e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 509e83b7b99Saartbik auto kind = reductionOp.kind(); 510e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 511dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(eltType); 512e9628955SAart Bik if (eltType.isIntOrIndex()) { 513e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 514e83b7b99Saartbik if (kind == "add") 515322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 516563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 517e83b7b99Saartbik else if (kind == "mul") 518322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 519563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 520e9628955SAart Bik else if (kind == "min" && 521e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 522322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 523563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 524e83b7b99Saartbik else if (kind == "min") 525322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 526563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 527e9628955SAart Bik else if (kind == "max" && 528e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 529322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 530563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 531e83b7b99Saartbik else if (kind == "max") 532322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 533563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 534e83b7b99Saartbik else if (kind == "and") 535322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 536563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 537e83b7b99Saartbik else if (kind == "or") 538322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 539563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 540e83b7b99Saartbik else if (kind == "xor") 541322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 542563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 543e83b7b99Saartbik else 5443145427dSRiver Riddle return failure(); 5453145427dSRiver Riddle return success(); 546dcec2ca5SChristian Sigg } 547e83b7b99Saartbik 548dcec2ca5SChristian Sigg if (!eltType.isa<FloatType>()) 549dcec2ca5SChristian Sigg return failure(); 550dcec2ca5SChristian Sigg 551e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 552e83b7b99Saartbik if (kind == "add") { 5530d924700Saartbik // Optional accumulator (or zero). 5540d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 5550d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 556563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5570d924700Saartbik rewriter.getZeroAttr(eltType)); 558322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 559563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 560ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 561e83b7b99Saartbik } else if (kind == "mul") { 5620d924700Saartbik // Optional accumulator (or one). 5630d924700Saartbik Value acc = operands.size() > 1 5640d924700Saartbik ? operands[1] 5650d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 566563879b6SRahul Joshi reductionOp->getLoc(), llvmType, 5670d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 568322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 569563879b6SRahul Joshi reductionOp, llvmType, acc, operands[0], 570ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 571e83b7b99Saartbik } else if (kind == "min") 572563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 573563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 574e83b7b99Saartbik else if (kind == "max") 575563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 576563879b6SRahul Joshi reductionOp, llvmType, operands[0]); 577e83b7b99Saartbik else 5783145427dSRiver Riddle return failure(); 5793145427dSRiver Riddle return success(); 580e83b7b99Saartbik } 581ceb1b327Saartbik 582ceb1b327Saartbik private: 583ceb1b327Saartbik const bool reassociateFPReductions; 584e83b7b99Saartbik }; 585e83b7b99Saartbik 586060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 587563879b6SRahul Joshi class VectorCreateMaskOpConversion 588563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { 589060c9dd1Saartbik public: 590563879b6SRahul Joshi explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, 591060c9dd1Saartbik bool enableIndexOpt) 592563879b6SRahul Joshi : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), 593060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 594060c9dd1Saartbik 595060c9dd1Saartbik LogicalResult 596563879b6SRahul Joshi matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, 597060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 5989eb3e564SChris Lattner auto dstType = op.getType(); 599060c9dd1Saartbik int64_t rank = dstType.getRank(); 600060c9dd1Saartbik if (rank == 1) { 601060c9dd1Saartbik rewriter.replaceOp( 602060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 603060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 604060c9dd1Saartbik return success(); 605060c9dd1Saartbik } 606060c9dd1Saartbik return failure(); 607060c9dd1Saartbik } 608060c9dd1Saartbik 609060c9dd1Saartbik private: 610060c9dd1Saartbik const bool enableIndexOptimizations; 611060c9dd1Saartbik }; 612060c9dd1Saartbik 613563879b6SRahul Joshi class VectorShuffleOpConversion 614563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ShuffleOp> { 6151c81adf3SAart Bik public: 616563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern; 6171c81adf3SAart Bik 6183145427dSRiver Riddle LogicalResult 619563879b6SRahul Joshi matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands, 6201c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 621563879b6SRahul Joshi auto loc = shuffleOp->getLoc(); 6222d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6231c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6241c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6251c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 626dcec2ca5SChristian Sigg Type llvmType = typeConverter->convertType(vectorType); 6271c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6281c81adf3SAart Bik 6291c81adf3SAart Bik // Bail if result type cannot be lowered. 6301c81adf3SAart Bik if (!llvmType) 6313145427dSRiver Riddle return failure(); 6321c81adf3SAart Bik 6331c81adf3SAart Bik // Get rank and dimension sizes. 6341c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6351c81adf3SAart Bik assert(v1Type.getRank() == rank); 6361c81adf3SAart Bik assert(v2Type.getRank() == rank); 6371c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6381c81adf3SAart Bik 6391c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6401c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6411c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 642563879b6SRahul Joshi Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( 6431c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 644563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, llvmShuffleOp); 6453145427dSRiver Riddle return success(); 646b36aaeafSAart Bik } 647b36aaeafSAart Bik 6481c81adf3SAart Bik // For all other cases, insert the individual values individually. 649e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 6501c81adf3SAart Bik int64_t insPos = 0; 6511c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 6521c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 653e62a6956SRiver Riddle Value value = adaptor.v1(); 6541c81adf3SAart Bik if (extPos >= v1Dim) { 6551c81adf3SAart Bik extPos -= v1Dim; 6561c81adf3SAart Bik value = adaptor.v2(); 657b36aaeafSAart Bik } 658dcec2ca5SChristian Sigg Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, 659dcec2ca5SChristian Sigg llvmType, rank, extPos); 660dcec2ca5SChristian Sigg insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, 6610f04384dSAlex Zinenko llvmType, rank, insPos++); 6621c81adf3SAart Bik } 663563879b6SRahul Joshi rewriter.replaceOp(shuffleOp, insert); 6643145427dSRiver Riddle return success(); 665b36aaeafSAart Bik } 666b36aaeafSAart Bik }; 667b36aaeafSAart Bik 668563879b6SRahul Joshi class VectorExtractElementOpConversion 669563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { 670cd5dab8aSAart Bik public: 671563879b6SRahul Joshi using ConvertOpToLLVMPattern< 672563879b6SRahul Joshi vector::ExtractElementOp>::ConvertOpToLLVMPattern; 673cd5dab8aSAart Bik 6743145427dSRiver Riddle LogicalResult 675563879b6SRahul Joshi matchAndRewrite(vector::ExtractElementOp extractEltOp, 676563879b6SRahul Joshi ArrayRef<Value> operands, 677cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 6782d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 679cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 680dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType.getElementType()); 681cd5dab8aSAart Bik 682cd5dab8aSAart Bik // Bail if result type cannot be lowered. 683cd5dab8aSAart Bik if (!llvmType) 6843145427dSRiver Riddle return failure(); 685cd5dab8aSAart Bik 686cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 687563879b6SRahul Joshi extractEltOp, llvmType, adaptor.vector(), adaptor.position()); 6883145427dSRiver Riddle return success(); 689cd5dab8aSAart Bik } 690cd5dab8aSAart Bik }; 691cd5dab8aSAart Bik 692563879b6SRahul Joshi class VectorExtractOpConversion 693563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::ExtractOp> { 6945c0c51a9SNicolas Vasilache public: 695563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern; 6965c0c51a9SNicolas Vasilache 6973145427dSRiver Riddle LogicalResult 698563879b6SRahul Joshi matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 6995c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 700563879b6SRahul Joshi auto loc = extractOp->getLoc(); 7012d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 7029826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7032bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 704dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(resultType); 7055c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7069826fe5cSAart Bik 7079826fe5cSAart Bik // Bail if result type cannot be lowered. 7089826fe5cSAart Bik if (!llvmResultType) 7093145427dSRiver Riddle return failure(); 7109826fe5cSAart Bik 7115c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7125c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 713e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7145c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 715563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7163145427dSRiver Riddle return success(); 7175c0c51a9SNicolas Vasilache } 7185c0c51a9SNicolas Vasilache 7199826fe5cSAart Bik // Potential extraction of 1-D vector from array. 720563879b6SRahul Joshi auto *context = extractOp->getContext(); 721e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7225c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7235c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7249826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7255c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7265c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7275c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 728dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 7295c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7305c0c51a9SNicolas Vasilache } 7315c0c51a9SNicolas Vasilache 7325c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7335c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7342230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 7351d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7365c0c51a9SNicolas Vasilache extracted = 7375c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 738563879b6SRahul Joshi rewriter.replaceOp(extractOp, extracted); 7395c0c51a9SNicolas Vasilache 7403145427dSRiver Riddle return success(); 7415c0c51a9SNicolas Vasilache } 7425c0c51a9SNicolas Vasilache }; 7435c0c51a9SNicolas Vasilache 744681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 745681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 746681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 747681f929fSNicolas Vasilache /// 748681f929fSNicolas Vasilache /// Example: 749681f929fSNicolas Vasilache /// ``` 750681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 751681f929fSNicolas Vasilache /// ``` 752681f929fSNicolas Vasilache /// is converted to: 753681f929fSNicolas Vasilache /// ``` 7543bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 755dd5165a9SAlex Zinenko /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">) 756dd5165a9SAlex Zinenko /// -> !llvm."<8 x f32>"> 757681f929fSNicolas Vasilache /// ``` 758563879b6SRahul Joshi class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> { 759681f929fSNicolas Vasilache public: 760563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern; 761681f929fSNicolas Vasilache 7623145427dSRiver Riddle LogicalResult 763563879b6SRahul Joshi matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 764681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7652d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 766681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 767681f929fSNicolas Vasilache if (vType.getRank() != 1) 7683145427dSRiver Riddle return failure(); 769563879b6SRahul Joshi rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(), 7703bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 7713145427dSRiver Riddle return success(); 772681f929fSNicolas Vasilache } 773681f929fSNicolas Vasilache }; 774681f929fSNicolas Vasilache 775563879b6SRahul Joshi class VectorInsertElementOpConversion 776563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertElementOp> { 777cd5dab8aSAart Bik public: 778563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; 779cd5dab8aSAart Bik 7803145427dSRiver Riddle LogicalResult 781563879b6SRahul Joshi matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands, 782cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7832d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 784cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 785dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType(vectorType); 786cd5dab8aSAart Bik 787cd5dab8aSAart Bik // Bail if result type cannot be lowered. 788cd5dab8aSAart Bik if (!llvmType) 7893145427dSRiver Riddle return failure(); 790cd5dab8aSAart Bik 791cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 792563879b6SRahul Joshi insertEltOp, llvmType, adaptor.dest(), adaptor.source(), 793563879b6SRahul Joshi adaptor.position()); 7943145427dSRiver Riddle return success(); 795cd5dab8aSAart Bik } 796cd5dab8aSAart Bik }; 797cd5dab8aSAart Bik 798563879b6SRahul Joshi class VectorInsertOpConversion 799563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::InsertOp> { 8009826fe5cSAart Bik public: 801563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern; 8029826fe5cSAart Bik 8033145427dSRiver Riddle LogicalResult 804563879b6SRahul Joshi matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 8059826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 806563879b6SRahul Joshi auto loc = insertOp->getLoc(); 8072d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8089826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8099826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 810dcec2ca5SChristian Sigg auto llvmResultType = typeConverter->convertType(destVectorType); 8119826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8129826fe5cSAart Bik 8139826fe5cSAart Bik // Bail if result type cannot be lowered. 8149826fe5cSAart Bik if (!llvmResultType) 8153145427dSRiver Riddle return failure(); 8169826fe5cSAart Bik 8179826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8189826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 819e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8209826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8219826fe5cSAart Bik positionArrayAttr); 822563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8233145427dSRiver Riddle return success(); 8249826fe5cSAart Bik } 8259826fe5cSAart Bik 8269826fe5cSAart Bik // Potential extraction of 1-D vector from array. 827563879b6SRahul Joshi auto *context = insertOp->getContext(); 828e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8299826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8309826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8319826fe5cSAart Bik auto oneDVectorType = destVectorType; 8329826fe5cSAart Bik if (positionAttrs.size() > 1) { 8339826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 8349826fe5cSAart Bik auto nMinusOnePositionAttrs = 8359826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8369826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 837dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8389826fe5cSAart Bik nMinusOnePositionAttrs); 8399826fe5cSAart Bik } 8409826fe5cSAart Bik 8419826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 8422230bf99SAlex Zinenko auto i64Type = IntegerType::get(rewriter.getContext(), 64); 8431d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 844e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 845dcec2ca5SChristian Sigg loc, typeConverter->convertType(oneDVectorType), extracted, 8460f04384dSAlex Zinenko adaptor.source(), constant); 8479826fe5cSAart Bik 8489826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 8499826fe5cSAart Bik if (positionAttrs.size() > 1) { 8509826fe5cSAart Bik auto nMinusOnePositionAttrs = 8519826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 8529826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 8539826fe5cSAart Bik adaptor.dest(), inserted, 8549826fe5cSAart Bik nMinusOnePositionAttrs); 8559826fe5cSAart Bik } 8569826fe5cSAart Bik 857563879b6SRahul Joshi rewriter.replaceOp(insertOp, inserted); 8583145427dSRiver Riddle return success(); 8599826fe5cSAart Bik } 8609826fe5cSAart Bik }; 8619826fe5cSAart Bik 862681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 863681f929fSNicolas Vasilache /// 864681f929fSNicolas Vasilache /// Example: 865681f929fSNicolas Vasilache /// ``` 866681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 867681f929fSNicolas Vasilache /// ``` 868681f929fSNicolas Vasilache /// is rewritten into: 869681f929fSNicolas Vasilache /// ``` 870681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 871681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 872681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 873681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 874681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 875681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 876681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 877681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 878681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 879681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 880681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 881681f929fSNicolas Vasilache /// // %r3 holds the final value. 882681f929fSNicolas Vasilache /// ``` 883681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 884681f929fSNicolas Vasilache public: 885681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 886681f929fSNicolas Vasilache 8873145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 888681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 889681f929fSNicolas Vasilache auto vType = op.getVectorType(); 890681f929fSNicolas Vasilache if (vType.getRank() < 2) 8913145427dSRiver Riddle return failure(); 892681f929fSNicolas Vasilache 893681f929fSNicolas Vasilache auto loc = op.getLoc(); 894681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 895681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 896681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 897681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 898681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 899681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 900681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 901681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 902681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 903681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 904681f929fSNicolas Vasilache } 905681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9063145427dSRiver Riddle return success(); 907681f929fSNicolas Vasilache } 908681f929fSNicolas Vasilache }; 909681f929fSNicolas Vasilache 9102d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9112d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9122d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9132d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9142d515e49SNicolas Vasilache // rank. 9152d515e49SNicolas Vasilache // 9162d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9172d515e49SNicolas Vasilache // have different ranks. In this case: 9182d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9192d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9202d515e49SNicolas Vasilache // destination subvector 9212d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9222d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9232d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9242d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9252d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9262d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9272d515e49SNicolas Vasilache public: 9282d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9292d515e49SNicolas Vasilache 9303145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9312d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9322d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9332d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9342d515e49SNicolas Vasilache 9352d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9363145427dSRiver Riddle return failure(); 9372d515e49SNicolas Vasilache 9382d515e49SNicolas Vasilache auto loc = op.getLoc(); 9392d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9402d515e49SNicolas Vasilache assert(rankDiff >= 0); 9412d515e49SNicolas Vasilache if (rankDiff == 0) 9423145427dSRiver Riddle return failure(); 9432d515e49SNicolas Vasilache 9442d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 9452d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 9462d515e49SNicolas Vasilache // on it. 9472d515e49SNicolas Vasilache Value extracted = 9482d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 9492d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 950dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9512d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 9522d515e49SNicolas Vasilache // ranks. 9532d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 9542d515e49SNicolas Vasilache loc, op.source(), extracted, 9552d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 956c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 9572d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 9582d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 9592d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 960dcec2ca5SChristian Sigg /*dropBack=*/rankRest)); 9613145427dSRiver Riddle return success(); 9622d515e49SNicolas Vasilache } 9632d515e49SNicolas Vasilache }; 9642d515e49SNicolas Vasilache 9652d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9662d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 9672d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9682d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9692d515e49SNicolas Vasilache // destination subvector 9702d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9712d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9722d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9732d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9742d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 9752d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9762d515e49SNicolas Vasilache public: 977b99bd771SRiver Riddle VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) 978b99bd771SRiver Riddle : OpRewritePattern<InsertStridedSliceOp>(ctx) { 979b99bd771SRiver Riddle // This pattern creates recursive InsertStridedSliceOp, but the recursion is 980b99bd771SRiver Riddle // bounded as the rank is strictly decreasing. 981b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 982b99bd771SRiver Riddle } 9832d515e49SNicolas Vasilache 9843145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9852d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 9862d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 9872d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 9882d515e49SNicolas Vasilache 9892d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 9903145427dSRiver Riddle return failure(); 9912d515e49SNicolas Vasilache 9922d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 9932d515e49SNicolas Vasilache assert(rankDiff >= 0); 9942d515e49SNicolas Vasilache if (rankDiff != 0) 9953145427dSRiver Riddle return failure(); 9962d515e49SNicolas Vasilache 9972d515e49SNicolas Vasilache if (srcType == dstType) { 9982d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 9993145427dSRiver Riddle return success(); 10002d515e49SNicolas Vasilache } 10012d515e49SNicolas Vasilache 10022d515e49SNicolas Vasilache int64_t offset = 10032d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10042d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10052d515e49SNicolas Vasilache int64_t stride = 10062d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10072d515e49SNicolas Vasilache 10082d515e49SNicolas Vasilache auto loc = op.getLoc(); 10092d515e49SNicolas Vasilache Value res = op.dest(); 10102d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10112d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10122d515e49SNicolas Vasilache off += stride, ++idx) { 10132d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10142d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10152d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10162d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10172d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10182d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10192d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10202d515e49SNicolas Vasilache // smaller rank. 1021bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10222d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10232d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10242d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10252d515e49SNicolas Vasilache } 10262d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10272d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10282d515e49SNicolas Vasilache } 10292d515e49SNicolas Vasilache 10302d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10313145427dSRiver Riddle return success(); 10322d515e49SNicolas Vasilache } 10332d515e49SNicolas Vasilache }; 10342d515e49SNicolas Vasilache 103530e6033bSNicolas Vasilache /// Returns the strides if the memory underlying `memRefType` has a contiguous 103630e6033bSNicolas Vasilache /// static layout. 103730e6033bSNicolas Vasilache static llvm::Optional<SmallVector<int64_t, 4>> 103830e6033bSNicolas Vasilache computeContiguousStrides(MemRefType memRefType) { 10392bf491c7SBenjamin Kramer int64_t offset; 104030e6033bSNicolas Vasilache SmallVector<int64_t, 4> strides; 104130e6033bSNicolas Vasilache if (failed(getStridesAndOffset(memRefType, strides, offset))) 104230e6033bSNicolas Vasilache return None; 104330e6033bSNicolas Vasilache if (!strides.empty() && strides.back() != 1) 104430e6033bSNicolas Vasilache return None; 104530e6033bSNicolas Vasilache // If no layout or identity layout, this is contiguous by definition. 104630e6033bSNicolas Vasilache if (memRefType.getAffineMaps().empty() || 104730e6033bSNicolas Vasilache memRefType.getAffineMaps().front().isIdentity()) 104830e6033bSNicolas Vasilache return strides; 104930e6033bSNicolas Vasilache 105030e6033bSNicolas Vasilache // Otherwise, we must determine contiguity form shapes. This can only ever 105130e6033bSNicolas Vasilache // work in static cases because MemRefType is underspecified to represent 105230e6033bSNicolas Vasilache // contiguous dynamic shapes in other ways than with just empty/identity 105330e6033bSNicolas Vasilache // layout. 10542bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 10552bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 105630e6033bSNicolas Vasilache if (ShapedType::isDynamic(sizes[index + 1]) || 105730e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index]) || 105830e6033bSNicolas Vasilache ShapedType::isDynamicStrideOrOffset(strides[index + 1])) 105930e6033bSNicolas Vasilache return None; 106030e6033bSNicolas Vasilache if (strides[index] != strides[index + 1] * sizes[index + 1]) 106130e6033bSNicolas Vasilache return None; 10622bf491c7SBenjamin Kramer } 106330e6033bSNicolas Vasilache return strides; 10642bf491c7SBenjamin Kramer } 10652bf491c7SBenjamin Kramer 1066563879b6SRahul Joshi class VectorTypeCastOpConversion 1067563879b6SRahul Joshi : public ConvertOpToLLVMPattern<vector::TypeCastOp> { 10685c0c51a9SNicolas Vasilache public: 1069563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; 10705c0c51a9SNicolas Vasilache 10713145427dSRiver Riddle LogicalResult 1072563879b6SRahul Joshi matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, 10735c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 1074563879b6SRahul Joshi auto loc = castOp->getLoc(); 10755c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 10762bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 10779eb3e564SChris Lattner MemRefType targetMemRefType = castOp.getType(); 10785c0c51a9SNicolas Vasilache 10795c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 10805c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 10815c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 10823145427dSRiver Riddle return failure(); 10835c0c51a9SNicolas Vasilache 10845c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 10858de43b92SAlex Zinenko operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); 10868de43b92SAlex Zinenko if (!llvmSourceDescriptorTy) 10873145427dSRiver Riddle return failure(); 10885c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 10895c0c51a9SNicolas Vasilache 1090dcec2ca5SChristian Sigg auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) 10918de43b92SAlex Zinenko .dyn_cast_or_null<LLVM::LLVMStructType>(); 10928de43b92SAlex Zinenko if (!llvmTargetDescriptorTy) 10933145427dSRiver Riddle return failure(); 10945c0c51a9SNicolas Vasilache 109530e6033bSNicolas Vasilache // Only contiguous source buffers supported atm. 109630e6033bSNicolas Vasilache auto sourceStrides = computeContiguousStrides(sourceMemRefType); 109730e6033bSNicolas Vasilache if (!sourceStrides) 109830e6033bSNicolas Vasilache return failure(); 109930e6033bSNicolas Vasilache auto targetStrides = computeContiguousStrides(targetMemRefType); 110030e6033bSNicolas Vasilache if (!targetStrides) 110130e6033bSNicolas Vasilache return failure(); 110230e6033bSNicolas Vasilache // Only support static strides for now, regardless of contiguity. 110330e6033bSNicolas Vasilache if (llvm::any_of(*targetStrides, [](int64_t stride) { 110430e6033bSNicolas Vasilache return ShapedType::isDynamicStrideOrOffset(stride); 110530e6033bSNicolas Vasilache })) 11063145427dSRiver Riddle return failure(); 11075c0c51a9SNicolas Vasilache 11082230bf99SAlex Zinenko auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 11095c0c51a9SNicolas Vasilache 11105c0c51a9SNicolas Vasilache // Create descriptor. 11115c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11123a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11135c0c51a9SNicolas Vasilache // Set allocated ptr. 1114e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11155c0c51a9SNicolas Vasilache allocated = 11165c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11175c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11185c0c51a9SNicolas Vasilache // Set aligned ptr. 1119e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11205c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11215c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11225c0c51a9SNicolas Vasilache // Fill offset 0. 11235c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11245c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11255c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11265c0c51a9SNicolas Vasilache 11275c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11285c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11295c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11305c0c51a9SNicolas Vasilache auto sizeAttr = 11315c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11325c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11335c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 113430e6033bSNicolas Vasilache auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), 113530e6033bSNicolas Vasilache (*targetStrides)[index]); 11365c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11375c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11385c0c51a9SNicolas Vasilache } 11395c0c51a9SNicolas Vasilache 1140563879b6SRahul Joshi rewriter.replaceOp(castOp, {desc}); 11413145427dSRiver Riddle return success(); 11425c0c51a9SNicolas Vasilache } 11435c0c51a9SNicolas Vasilache }; 11445c0c51a9SNicolas Vasilache 11458345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 11468345b86dSNicolas Vasilache /// sequence of: 1147060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1148060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1149060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1150060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1151060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 11528345b86dSNicolas Vasilache template <typename ConcreteOp> 1153563879b6SRahul Joshi class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> { 11548345b86dSNicolas Vasilache public: 1155563879b6SRahul Joshi explicit VectorTransferConversion(LLVMTypeConverter &typeConv, 1156060c9dd1Saartbik bool enableIndexOpt) 1157563879b6SRahul Joshi : ConvertOpToLLVMPattern<ConcreteOp>(typeConv), 1158060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 11598345b86dSNicolas Vasilache 11608345b86dSNicolas Vasilache LogicalResult 1161563879b6SRahul Joshi matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands, 11628345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11638345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1164b2c79c50SNicolas Vasilache 1165b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1166b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 11678345b86dSNicolas Vasilache return failure(); 11685f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 11695f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 11705f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 1171563879b6SRahul Joshi xferOp->getContext())) 11728345b86dSNicolas Vasilache return failure(); 117326c8f908SThomas Raoux auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 117426c8f908SThomas Raoux if (!memRefType) 117526c8f908SThomas Raoux return failure(); 11762bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 117726c8f908SThomas Raoux auto strides = computeContiguousStrides(memRefType); 117830e6033bSNicolas Vasilache if (!strides) 11792bf491c7SBenjamin Kramer return failure(); 11808345b86dSNicolas Vasilache 1181563879b6SRahul Joshi auto toLLVMTy = [&](Type t) { 1182563879b6SRahul Joshi return this->getTypeConverter()->convertType(t); 1183563879b6SRahul Joshi }; 11848345b86dSNicolas Vasilache 1185563879b6SRahul Joshi Location loc = xferOp->getLoc(); 11868345b86dSNicolas Vasilache 118768330ee0SThomas Raoux if (auto memrefVectorElementType = 118826c8f908SThomas Raoux memRefType.getElementType().template dyn_cast<VectorType>()) { 118968330ee0SThomas Raoux // Memref has vector element type. 119068330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 119168330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 119268330ee0SThomas Raoux return failure(); 11930de60b55SThomas Raoux #ifndef NDEBUG 119468330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 119568330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 119668330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 119768330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 119868330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 119968330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 120068330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 120168330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 120268330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 120368330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 120468330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 120568330ee0SThomas Raoux "result shape."); 12060de60b55SThomas Raoux #endif // ifndef NDEBUG 120768330ee0SThomas Raoux } 120868330ee0SThomas Raoux 12098345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1210a57def30SAart Bik VectorType vtp = xferOp.getVectorType(); 1211563879b6SRahul Joshi Value dataPtr = this->getStridedElementPtr( 121226c8f908SThomas Raoux loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); 1213a57def30SAart Bik Value vectorDataPtr = 1214a57def30SAart Bik castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); 12158345b86dSNicolas Vasilache 12161870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 1217563879b6SRahul Joshi return replaceTransferOpWithLoadOrStore(rewriter, 1218563879b6SRahul Joshi *this->getTypeConverter(), loc, 1219563879b6SRahul Joshi xferOp, operands, vectorDataPtr); 12201870e787SNicolas Vasilache 12218345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12228345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12238345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12248345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1225060c9dd1Saartbik // 1226060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1227060c9dd1Saartbik // dimensions here. 1228*bd30a796SAlex Zinenko unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue(); 1229060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12300c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 123126c8f908SThomas Raoux Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex); 1232563879b6SRahul Joshi Value mask = buildVectorComparison( 1233563879b6SRahul Joshi rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); 12348345b86dSNicolas Vasilache 12358345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 1236563879b6SRahul Joshi return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, 1237dcec2ca5SChristian Sigg xferOp, operands, vectorDataPtr, mask); 12388345b86dSNicolas Vasilache } 1239060c9dd1Saartbik 1240060c9dd1Saartbik private: 1241060c9dd1Saartbik const bool enableIndexOptimizations; 12428345b86dSNicolas Vasilache }; 12438345b86dSNicolas Vasilache 1244563879b6SRahul Joshi class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> { 1245d9b500d3SAart Bik public: 1246563879b6SRahul Joshi using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern; 1247d9b500d3SAart Bik 1248d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1249d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1250d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1251d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1252d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1253d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1254d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1255d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1256d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1257d9b500d3SAart Bik // 12589db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1259d9b500d3SAart Bik // 12603145427dSRiver Riddle LogicalResult 1261563879b6SRahul Joshi matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands, 1262d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 12632d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1264d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1265d9b500d3SAart Bik 1266dcec2ca5SChristian Sigg if (typeConverter->convertType(printType) == nullptr) 12673145427dSRiver Riddle return failure(); 1268d9b500d3SAart Bik 1269b8880f5fSAart Bik // Make sure element type has runtime support. 1270b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1271d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1272d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1273d9b500d3SAart Bik Operation *printer; 1274b8880f5fSAart Bik if (eltType.isF32()) { 1275563879b6SRahul Joshi printer = getPrintFloat(printOp); 1276b8880f5fSAart Bik } else if (eltType.isF64()) { 1277563879b6SRahul Joshi printer = getPrintDouble(printOp); 127854759cefSAart Bik } else if (eltType.isIndex()) { 1279563879b6SRahul Joshi printer = getPrintU64(printOp); 1280b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1281b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1282b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1283b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1284b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1285b8880f5fSAart Bik if (intTy.isUnsigned()) { 128654759cefSAart Bik if (width <= 64) { 1287b8880f5fSAart Bik if (width < 64) 1288b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1289563879b6SRahul Joshi printer = getPrintU64(printOp); 1290b8880f5fSAart Bik } else { 12913145427dSRiver Riddle return failure(); 1292b8880f5fSAart Bik } 1293b8880f5fSAart Bik } else { 1294b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 129554759cefSAart Bik if (width <= 64) { 1296b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1297b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1298b8880f5fSAart Bik if (width == 1) 129954759cefSAart Bik conversion = PrintConversion::ZeroExt64; 130054759cefSAart Bik else if (width < 64) 1301b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1302563879b6SRahul Joshi printer = getPrintI64(printOp); 1303b8880f5fSAart Bik } else { 1304b8880f5fSAart Bik return failure(); 1305b8880f5fSAart Bik } 1306b8880f5fSAart Bik } 1307b8880f5fSAart Bik } else { 1308b8880f5fSAart Bik return failure(); 1309b8880f5fSAart Bik } 1310d9b500d3SAart Bik 1311d9b500d3SAart Bik // Unroll vector into elementary print calls. 1312b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1313563879b6SRahul Joshi emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, 1314b8880f5fSAart Bik conversion); 1315563879b6SRahul Joshi emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); 1316563879b6SRahul Joshi rewriter.eraseOp(printOp); 13173145427dSRiver Riddle return success(); 1318d9b500d3SAart Bik } 1319d9b500d3SAart Bik 1320d9b500d3SAart Bik private: 1321b8880f5fSAart Bik enum class PrintConversion { 132230e6033bSNicolas Vasilache // clang-format off 1323b8880f5fSAart Bik None, 1324b8880f5fSAart Bik ZeroExt64, 1325b8880f5fSAart Bik SignExt64 132630e6033bSNicolas Vasilache // clang-format on 1327b8880f5fSAart Bik }; 1328b8880f5fSAart Bik 1329d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1330e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1331b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1332d9b500d3SAart Bik Location loc = op->getLoc(); 1333d9b500d3SAart Bik if (rank == 0) { 1334b8880f5fSAart Bik switch (conversion) { 1335b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1336b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 13372230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1338b8880f5fSAart Bik break; 1339b8880f5fSAart Bik case PrintConversion::SignExt64: 1340b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 13412230bf99SAlex Zinenko loc, value, IntegerType::get(rewriter.getContext(), 64)); 1342b8880f5fSAart Bik break; 1343b8880f5fSAart Bik case PrintConversion::None: 1344b8880f5fSAart Bik break; 1345c9eeeb38Saartbik } 1346d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1347d9b500d3SAart Bik return; 1348d9b500d3SAart Bik } 1349d9b500d3SAart Bik 1350d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1351d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1352d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1353d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1354d9b500d3SAart Bik auto reducedType = 1355d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 1356dcec2ca5SChristian Sigg auto llvmType = typeConverter->convertType( 1357d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1358dcec2ca5SChristian Sigg Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, 1359dcec2ca5SChristian Sigg llvmType, rank, d); 1360b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1361b8880f5fSAart Bik conversion); 1362d9b500d3SAart Bik if (d != dim - 1) 1363d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1364d9b500d3SAart Bik } 1365d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1366d9b500d3SAart Bik } 1367d9b500d3SAart Bik 1368d9b500d3SAart Bik // Helper to emit a call. 1369d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1370d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 137108e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1372d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1373d9b500d3SAart Bik } 1374d9b500d3SAart Bik 1375d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 13765446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 1377c69c9e0fSAlex Zinenko ArrayRef<Type> params) { 1378d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1379d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1380d9b500d3SAart Bik if (func) 1381d9b500d3SAart Bik return func; 1382d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1383d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1384d9b500d3SAart Bik op->getLoc(), name, 13857ed9cfc7SAlex Zinenko LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()), 13867ed9cfc7SAlex Zinenko params)); 1387d9b500d3SAart Bik } 1388d9b500d3SAart Bik 1389d9b500d3SAart Bik // Helpers for method names. 1390e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 13912230bf99SAlex Zinenko return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64)); 1392e52414b1Saartbik } 1393b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 13942230bf99SAlex Zinenko return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64)); 1395b8880f5fSAart Bik } 1396d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 1397dd5165a9SAlex Zinenko return getPrint(op, "printF32", Float32Type::get(op->getContext())); 1398d9b500d3SAart Bik } 1399d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 1400dd5165a9SAlex Zinenko return getPrint(op, "printF64", Float64Type::get(op->getContext())); 1401d9b500d3SAart Bik } 1402d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 140354759cefSAart Bik return getPrint(op, "printOpen", {}); 1404d9b500d3SAart Bik } 1405d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 140654759cefSAart Bik return getPrint(op, "printClose", {}); 1407d9b500d3SAart Bik } 1408d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 140954759cefSAart Bik return getPrint(op, "printComma", {}); 1410d9b500d3SAart Bik } 1411d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 141254759cefSAart Bik return getPrint(op, "printNewline", {}); 1413d9b500d3SAart Bik } 1414d9b500d3SAart Bik }; 1415d9b500d3SAart Bik 1416334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1417c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1418c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1419c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1420334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 142165678d93SNicolas Vasilache public: 1422b99bd771SRiver Riddle VectorExtractStridedSliceOpConversion(MLIRContext *ctx) 1423b99bd771SRiver Riddle : OpRewritePattern<ExtractStridedSliceOp>(ctx) { 1424b99bd771SRiver Riddle // This pattern creates recursive ExtractStridedSliceOp, but the recursion 1425b99bd771SRiver Riddle // is bounded as the rank is strictly decreasing. 1426b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1427b99bd771SRiver Riddle } 142865678d93SNicolas Vasilache 1429334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 143065678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 14319eb3e564SChris Lattner auto dstType = op.getType(); 143265678d93SNicolas Vasilache 143365678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 143465678d93SNicolas Vasilache 143565678d93SNicolas Vasilache int64_t offset = 143665678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 143765678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 143865678d93SNicolas Vasilache int64_t stride = 143965678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 144065678d93SNicolas Vasilache 144165678d93SNicolas Vasilache auto loc = op.getLoc(); 144265678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 144335b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1444c3c95b9cSaartbik 1445c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1446c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1447c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1448c3c95b9cSaartbik offsets.reserve(size); 1449c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1450c3c95b9cSaartbik off += stride) 1451c3c95b9cSaartbik offsets.push_back(off); 1452c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1453c3c95b9cSaartbik op.vector(), 1454c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1455c3c95b9cSaartbik return success(); 1456c3c95b9cSaartbik } 1457c3c95b9cSaartbik 1458c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 145965678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 146065678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 146165678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 146265678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 146365678d93SNicolas Vasilache off += stride, ++idx) { 1464c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1465c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1466c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 146765678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 146865678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 146965678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 147065678d93SNicolas Vasilache } 1471c3c95b9cSaartbik rewriter.replaceOp(op, res); 14723145427dSRiver Riddle return success(); 147365678d93SNicolas Vasilache } 147465678d93SNicolas Vasilache }; 147565678d93SNicolas Vasilache 1476df186507SBenjamin Kramer } // namespace 1477df186507SBenjamin Kramer 14785c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 14795c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1480ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1481060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 148265678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 14838345b86dSNicolas Vasilache // clang-format off 1484681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1485681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 14862d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1487c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1488ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1489563879b6SRahul Joshi converter, reassociateFPReductions); 1490060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1491060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1492060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1493563879b6SRahul Joshi converter, enableIndexOptimizations); 14948345b86dSNicolas Vasilache patterns 1495ceb1b327Saartbik .insert<VectorShuffleOpConversion, 14968345b86dSNicolas Vasilache VectorExtractElementOpConversion, 14978345b86dSNicolas Vasilache VectorExtractOpConversion, 14988345b86dSNicolas Vasilache VectorFMAOp1DConversion, 14998345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15008345b86dSNicolas Vasilache VectorInsertOpConversion, 15018345b86dSNicolas Vasilache VectorPrintOpConversion, 150219dbb230Saartbik VectorTypeCastOpConversion, 150339379916Saartbik VectorMaskedLoadOpConversion, 150439379916Saartbik VectorMaskedStoreOpConversion, 150519dbb230Saartbik VectorGatherOpConversion, 1506e8dcf5f8Saartbik VectorScatterOpConversion, 1507e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1508563879b6SRahul Joshi VectorCompressStoreOpConversion>(converter); 15098345b86dSNicolas Vasilache // clang-format on 15105c0c51a9SNicolas Vasilache } 15115c0c51a9SNicolas Vasilache 151263b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 151363b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1514563879b6SRahul Joshi patterns.insert<VectorMatmulOpConversion>(converter); 1515563879b6SRahul Joshi patterns.insert<VectorFlatTransposeOpConversion>(converter); 151663b683a8SNicolas Vasilache } 1517