15c0c51a9SNicolas Vasilache //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 25c0c51a9SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65c0c51a9SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 85c0c51a9SNicolas Vasilache 965678d93SNicolas Vasilache #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10870c1fd4SAlex Zinenko 111834ad4aSRiver Riddle #include "../PassDetail.h" 125c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 135c0c51a9SNicolas Vasilache #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 145c0c51a9SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 1569d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h" 164d60f47bSRob Suderman #include "mlir/Dialect/Vector/VectorOps.h" 178345b86dSNicolas Vasilache #include "mlir/IR/AffineMap.h" 185c0c51a9SNicolas Vasilache #include "mlir/IR/Attributes.h" 195c0c51a9SNicolas Vasilache #include "mlir/IR/Builders.h" 205c0c51a9SNicolas Vasilache #include "mlir/IR/MLIRContext.h" 215c0c51a9SNicolas Vasilache #include "mlir/IR/Module.h" 225c0c51a9SNicolas Vasilache #include "mlir/IR/Operation.h" 235c0c51a9SNicolas Vasilache #include "mlir/IR/PatternMatch.h" 245c0c51a9SNicolas Vasilache #include "mlir/IR/StandardTypes.h" 255c0c51a9SNicolas Vasilache #include "mlir/IR/Types.h" 26ec1f4e7cSAlex Zinenko #include "mlir/Target/LLVMIR/TypeTranslation.h" 275c0c51a9SNicolas Vasilache #include "mlir/Transforms/DialectConversion.h" 285c0c51a9SNicolas Vasilache #include "mlir/Transforms/Passes.h" 295c0c51a9SNicolas Vasilache #include "llvm/IR/DerivedTypes.h" 305c0c51a9SNicolas Vasilache #include "llvm/IR/Module.h" 315c0c51a9SNicolas Vasilache #include "llvm/IR/Type.h" 325c0c51a9SNicolas Vasilache #include "llvm/Support/Allocator.h" 335c0c51a9SNicolas Vasilache #include "llvm/Support/ErrorHandling.h" 345c0c51a9SNicolas Vasilache 355c0c51a9SNicolas Vasilache using namespace mlir; 3665678d93SNicolas Vasilache using namespace mlir::vector; 375c0c51a9SNicolas Vasilache 389826fe5cSAart Bik // Helper to reduce vector type by one rank at front. 399826fe5cSAart Bik static VectorType reducedVectorTypeFront(VectorType tp) { 409826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 419826fe5cSAart Bik return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 429826fe5cSAart Bik } 439826fe5cSAart Bik 449826fe5cSAart Bik // Helper to reduce vector type by *all* but one rank at back. 459826fe5cSAart Bik static VectorType reducedVectorTypeBack(VectorType tp) { 469826fe5cSAart Bik assert((tp.getRank() > 1) && "unlowerable vector type"); 479826fe5cSAart Bik return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 489826fe5cSAart Bik } 499826fe5cSAart Bik 501c81adf3SAart Bik // Helper that picks the proper sequence for inserting. 51e62a6956SRiver Riddle static Value insertOne(ConversionPatternRewriter &rewriter, 520f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 530f04384dSAlex Zinenko Value val1, Value val2, Type llvmType, int64_t rank, 540f04384dSAlex Zinenko int64_t pos) { 551c81adf3SAart Bik if (rank == 1) { 561c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 571c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 580f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 591c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 601c81adf3SAart Bik return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 611c81adf3SAart Bik constant); 621c81adf3SAart Bik } 631c81adf3SAart Bik return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 641c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 651c81adf3SAart Bik } 661c81adf3SAart Bik 672d515e49SNicolas Vasilache // Helper that picks the proper sequence for inserting. 682d515e49SNicolas Vasilache static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 692d515e49SNicolas Vasilache Value into, int64_t offset) { 702d515e49SNicolas Vasilache auto vectorType = into.getType().cast<VectorType>(); 712d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 722d515e49SNicolas Vasilache return rewriter.create<InsertOp>(loc, from, into, offset); 732d515e49SNicolas Vasilache return rewriter.create<vector::InsertElementOp>( 742d515e49SNicolas Vasilache loc, vectorType, from, into, 752d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 762d515e49SNicolas Vasilache } 772d515e49SNicolas Vasilache 781c81adf3SAart Bik // Helper that picks the proper sequence for extracting. 79e62a6956SRiver Riddle static Value extractOne(ConversionPatternRewriter &rewriter, 800f04384dSAlex Zinenko LLVMTypeConverter &typeConverter, Location loc, 810f04384dSAlex Zinenko Value val, Type llvmType, int64_t rank, int64_t pos) { 821c81adf3SAart Bik if (rank == 1) { 831c81adf3SAart Bik auto idxType = rewriter.getIndexType(); 841c81adf3SAart Bik auto constant = rewriter.create<LLVM::ConstantOp>( 850f04384dSAlex Zinenko loc, typeConverter.convertType(idxType), 861c81adf3SAart Bik rewriter.getIntegerAttr(idxType, pos)); 871c81adf3SAart Bik return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, 881c81adf3SAart Bik constant); 891c81adf3SAart Bik } 901c81adf3SAart Bik return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val, 911c81adf3SAart Bik rewriter.getI64ArrayAttr(pos)); 921c81adf3SAart Bik } 931c81adf3SAart Bik 942d515e49SNicolas Vasilache // Helper that picks the proper sequence for extracting. 952d515e49SNicolas Vasilache static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 962d515e49SNicolas Vasilache int64_t offset) { 972d515e49SNicolas Vasilache auto vectorType = vector.getType().cast<VectorType>(); 982d515e49SNicolas Vasilache if (vectorType.getRank() > 1) 992d515e49SNicolas Vasilache return rewriter.create<ExtractOp>(loc, vector, offset); 1002d515e49SNicolas Vasilache return rewriter.create<vector::ExtractElementOp>( 1012d515e49SNicolas Vasilache loc, vectorType.getElementType(), vector, 1022d515e49SNicolas Vasilache rewriter.create<ConstantIndexOp>(loc, offset)); 1032d515e49SNicolas Vasilache } 1042d515e49SNicolas Vasilache 1052d515e49SNicolas Vasilache // Helper that returns a subset of `arrayAttr` as a vector of int64_t. 1069db53a18SRiver Riddle // TODO: Better support for attribute subtype forwarding + slicing. 1072d515e49SNicolas Vasilache static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, 1082d515e49SNicolas Vasilache unsigned dropFront = 0, 1092d515e49SNicolas Vasilache unsigned dropBack = 0) { 1102d515e49SNicolas Vasilache assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); 1112d515e49SNicolas Vasilache auto range = arrayAttr.getAsRange<IntegerAttr>(); 1122d515e49SNicolas Vasilache SmallVector<int64_t, 4> res; 1132d515e49SNicolas Vasilache res.reserve(arrayAttr.size() - dropFront - dropBack); 1142d515e49SNicolas Vasilache for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; 1152d515e49SNicolas Vasilache it != eit; ++it) 1162d515e49SNicolas Vasilache res.push_back((*it).getValue().getSExtValue()); 1172d515e49SNicolas Vasilache return res; 1182d515e49SNicolas Vasilache } 1192d515e49SNicolas Vasilache 120060c9dd1Saartbik // Helper that returns a vector comparison that constructs a mask: 121060c9dd1Saartbik // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 122060c9dd1Saartbik // 123060c9dd1Saartbik // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 124060c9dd1Saartbik // much more compact, IR for this operation, but LLVM eventually 125060c9dd1Saartbik // generates more elaborate instructions for this intrinsic since it 126060c9dd1Saartbik // is very conservative on the boundary conditions. 127060c9dd1Saartbik static Value buildVectorComparison(ConversionPatternRewriter &rewriter, 128060c9dd1Saartbik Operation *op, bool enableIndexOptimizations, 129060c9dd1Saartbik int64_t dim, Value b, Value *off = nullptr) { 130060c9dd1Saartbik auto loc = op->getLoc(); 131060c9dd1Saartbik // If we can assume all indices fit in 32-bit, we perform the vector 132060c9dd1Saartbik // comparison in 32-bit to get a higher degree of SIMD parallelism. 133060c9dd1Saartbik // Otherwise we perform the vector comparison using 64-bit indices. 134060c9dd1Saartbik Value indices; 135060c9dd1Saartbik Type idxType; 136060c9dd1Saartbik if (enableIndexOptimizations) { 1370c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1380c2a4d3cSBenjamin Kramer loc, rewriter.getI32VectorAttr( 1390c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)))); 140060c9dd1Saartbik idxType = rewriter.getI32Type(); 141060c9dd1Saartbik } else { 1420c2a4d3cSBenjamin Kramer indices = rewriter.create<ConstantOp>( 1430c2a4d3cSBenjamin Kramer loc, rewriter.getI64VectorAttr( 1440c2a4d3cSBenjamin Kramer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)))); 145060c9dd1Saartbik idxType = rewriter.getI64Type(); 146060c9dd1Saartbik } 147060c9dd1Saartbik // Add in an offset if requested. 148060c9dd1Saartbik if (off) { 149060c9dd1Saartbik Value o = rewriter.create<IndexCastOp>(loc, idxType, *off); 150060c9dd1Saartbik Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o); 151060c9dd1Saartbik indices = rewriter.create<AddIOp>(loc, ov, indices); 152060c9dd1Saartbik } 153060c9dd1Saartbik // Construct the vector comparison. 154060c9dd1Saartbik Value bound = rewriter.create<IndexCastOp>(loc, idxType, b); 155060c9dd1Saartbik Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); 156060c9dd1Saartbik return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds); 157060c9dd1Saartbik } 158060c9dd1Saartbik 15919dbb230Saartbik // Helper that returns data layout alignment of an operation with memref. 16019dbb230Saartbik template <typename T> 16119dbb230Saartbik LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, 16219dbb230Saartbik unsigned &align) { 1635f9e0466SNicolas Vasilache Type elementTy = 16419dbb230Saartbik typeConverter.convertType(op.getMemRefType().getElementType()); 1655f9e0466SNicolas Vasilache if (!elementTy) 1665f9e0466SNicolas Vasilache return failure(); 1675f9e0466SNicolas Vasilache 168b2ab375dSAlex Zinenko // TODO: this should use the MLIR data layout when it becomes available and 169b2ab375dSAlex Zinenko // stop depending on translation. 17087a89e0fSAlex Zinenko llvm::LLVMContext llvmContext; 17187a89e0fSAlex Zinenko align = LLVM::TypeToLLVMIRTranslator(llvmContext) 172b2ab375dSAlex Zinenko .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(), 173168213f9SAlex Zinenko typeConverter.getDataLayout()); 1745f9e0466SNicolas Vasilache return success(); 1755f9e0466SNicolas Vasilache } 1765f9e0466SNicolas Vasilache 177e8dcf5f8Saartbik // Helper that returns the base address of a memref. 178b98e25b6SBenjamin Kramer static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, 179e8dcf5f8Saartbik Value memref, MemRefType memRefType, Value &base) { 18019dbb230Saartbik // Inspect stride and offset structure. 18119dbb230Saartbik // 18219dbb230Saartbik // TODO: flat memory only for now, generalize 18319dbb230Saartbik // 18419dbb230Saartbik int64_t offset; 18519dbb230Saartbik SmallVector<int64_t, 4> strides; 18619dbb230Saartbik auto successStrides = getStridesAndOffset(memRefType, strides, offset); 18719dbb230Saartbik if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || 18819dbb230Saartbik offset != 0 || memRefType.getMemorySpace() != 0) 18919dbb230Saartbik return failure(); 190e8dcf5f8Saartbik base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); 191e8dcf5f8Saartbik return success(); 192e8dcf5f8Saartbik } 19319dbb230Saartbik 194e8dcf5f8Saartbik // Helper that returns a pointer given a memref base. 195b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 196b98e25b6SBenjamin Kramer Location loc, Value memref, 197b98e25b6SBenjamin Kramer MemRefType memRefType, Value &ptr) { 198e8dcf5f8Saartbik Value base; 199e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 200e8dcf5f8Saartbik return failure(); 2013a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 202e8dcf5f8Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 203e8dcf5f8Saartbik return success(); 204e8dcf5f8Saartbik } 205e8dcf5f8Saartbik 20639379916Saartbik // Helper that returns a bit-casted pointer given a memref base. 207b98e25b6SBenjamin Kramer static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, 208b98e25b6SBenjamin Kramer Location loc, Value memref, 209b98e25b6SBenjamin Kramer MemRefType memRefType, Type type, Value &ptr) { 21039379916Saartbik Value base; 21139379916Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 21239379916Saartbik return failure(); 21339379916Saartbik auto pType = type.template cast<LLVM::LLVMType>().getPointerTo(); 21439379916Saartbik base = rewriter.create<LLVM::BitcastOp>(loc, pType, base); 21539379916Saartbik ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); 21639379916Saartbik return success(); 21739379916Saartbik } 21839379916Saartbik 219e8dcf5f8Saartbik // Helper that returns vector of pointers given a memref base and an index 220e8dcf5f8Saartbik // vector. 221b98e25b6SBenjamin Kramer static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, 222b98e25b6SBenjamin Kramer Location loc, Value memref, Value indices, 223b98e25b6SBenjamin Kramer MemRefType memRefType, VectorType vType, 224b98e25b6SBenjamin Kramer Type iType, Value &ptrs) { 225e8dcf5f8Saartbik Value base; 226e8dcf5f8Saartbik if (failed(getBase(rewriter, loc, memref, memRefType, base))) 227e8dcf5f8Saartbik return failure(); 2283a577f54SChristian Sigg auto pType = MemRefDescriptor(memref).getElementPtrType(); 229e8dcf5f8Saartbik auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); 2301485fd29Saartbik ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices); 23119dbb230Saartbik return success(); 23219dbb230Saartbik } 23319dbb230Saartbik 2345f9e0466SNicolas Vasilache static LogicalResult 2355f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2365f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2375f9e0466SNicolas Vasilache TransferReadOp xferOp, 2385f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 239affbc0cdSNicolas Vasilache unsigned align; 24019dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 241affbc0cdSNicolas Vasilache return failure(); 242affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align); 2435f9e0466SNicolas Vasilache return success(); 2445f9e0466SNicolas Vasilache } 2455f9e0466SNicolas Vasilache 2465f9e0466SNicolas Vasilache static LogicalResult 2475f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2485f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2495f9e0466SNicolas Vasilache TransferReadOp xferOp, ArrayRef<Value> operands, 2505f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2515f9e0466SNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 2525f9e0466SNicolas Vasilache VectorType fillType = xferOp.getVectorType(); 2535f9e0466SNicolas Vasilache Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding()); 2545f9e0466SNicolas Vasilache fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill); 2555f9e0466SNicolas Vasilache 2565f9e0466SNicolas Vasilache Type vecTy = typeConverter.convertType(xferOp.getVectorType()); 2575f9e0466SNicolas Vasilache if (!vecTy) 2585f9e0466SNicolas Vasilache return failure(); 2595f9e0466SNicolas Vasilache 2605f9e0466SNicolas Vasilache unsigned align; 26119dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2625f9e0466SNicolas Vasilache return failure(); 2635f9e0466SNicolas Vasilache 2645f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 2655f9e0466SNicolas Vasilache xferOp, vecTy, dataPtr, mask, ValueRange{fill}, 2665f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2675f9e0466SNicolas Vasilache return success(); 2685f9e0466SNicolas Vasilache } 2695f9e0466SNicolas Vasilache 2705f9e0466SNicolas Vasilache static LogicalResult 2715f9e0466SNicolas Vasilache replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, 2725f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2735f9e0466SNicolas Vasilache TransferWriteOp xferOp, 2745f9e0466SNicolas Vasilache ArrayRef<Value> operands, Value dataPtr) { 275affbc0cdSNicolas Vasilache unsigned align; 27619dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 277affbc0cdSNicolas Vasilache return failure(); 2782d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 279affbc0cdSNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr, 280affbc0cdSNicolas Vasilache align); 2815f9e0466SNicolas Vasilache return success(); 2825f9e0466SNicolas Vasilache } 2835f9e0466SNicolas Vasilache 2845f9e0466SNicolas Vasilache static LogicalResult 2855f9e0466SNicolas Vasilache replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, 2865f9e0466SNicolas Vasilache LLVMTypeConverter &typeConverter, Location loc, 2875f9e0466SNicolas Vasilache TransferWriteOp xferOp, ArrayRef<Value> operands, 2885f9e0466SNicolas Vasilache Value dataPtr, Value mask) { 2895f9e0466SNicolas Vasilache unsigned align; 29019dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, xferOp, align))) 2915f9e0466SNicolas Vasilache return failure(); 2925f9e0466SNicolas Vasilache 2932d2c73c5SJacques Pienaar auto adaptor = TransferWriteOpAdaptor(operands); 2945f9e0466SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 2955f9e0466SNicolas Vasilache xferOp, adaptor.vector(), dataPtr, mask, 2965f9e0466SNicolas Vasilache rewriter.getI32IntegerAttr(align)); 2975f9e0466SNicolas Vasilache return success(); 2985f9e0466SNicolas Vasilache } 2995f9e0466SNicolas Vasilache 3002d2c73c5SJacques Pienaar static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, 3012d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 3022d2c73c5SJacques Pienaar return TransferReadOpAdaptor(operands); 3035f9e0466SNicolas Vasilache } 3045f9e0466SNicolas Vasilache 3052d2c73c5SJacques Pienaar static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, 3062d2c73c5SJacques Pienaar ArrayRef<Value> operands) { 3072d2c73c5SJacques Pienaar return TransferWriteOpAdaptor(operands); 3085f9e0466SNicolas Vasilache } 3095f9e0466SNicolas Vasilache 31090c01357SBenjamin Kramer namespace { 311e83b7b99Saartbik 31263b683a8SNicolas Vasilache /// Conversion pattern for a vector.matrix_multiply. 31363b683a8SNicolas Vasilache /// This is lowered directly to the proper llvm.intr.matrix.multiply. 31463b683a8SNicolas Vasilache class VectorMatmulOpConversion : public ConvertToLLVMPattern { 31563b683a8SNicolas Vasilache public: 31663b683a8SNicolas Vasilache explicit VectorMatmulOpConversion(MLIRContext *context, 31763b683a8SNicolas Vasilache LLVMTypeConverter &typeConverter) 31863b683a8SNicolas Vasilache : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, 31963b683a8SNicolas Vasilache typeConverter) {} 32063b683a8SNicolas Vasilache 3213145427dSRiver Riddle LogicalResult 32263b683a8SNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 32363b683a8SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 32463b683a8SNicolas Vasilache auto matmulOp = cast<vector::MatmulOp>(op); 3252d2c73c5SJacques Pienaar auto adaptor = vector::MatmulOpAdaptor(operands); 32663b683a8SNicolas Vasilache rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( 32763b683a8SNicolas Vasilache op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), 32863b683a8SNicolas Vasilache adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), 32963b683a8SNicolas Vasilache matmulOp.rhs_columns()); 3303145427dSRiver Riddle return success(); 33163b683a8SNicolas Vasilache } 33263b683a8SNicolas Vasilache }; 33363b683a8SNicolas Vasilache 334c295a65dSaartbik /// Conversion pattern for a vector.flat_transpose. 335c295a65dSaartbik /// This is lowered directly to the proper llvm.intr.matrix.transpose. 336c295a65dSaartbik class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { 337c295a65dSaartbik public: 338c295a65dSaartbik explicit VectorFlatTransposeOpConversion(MLIRContext *context, 339c295a65dSaartbik LLVMTypeConverter &typeConverter) 340c295a65dSaartbik : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), 341c295a65dSaartbik context, typeConverter) {} 342c295a65dSaartbik 343c295a65dSaartbik LogicalResult 344c295a65dSaartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 345c295a65dSaartbik ConversionPatternRewriter &rewriter) const override { 346c295a65dSaartbik auto transOp = cast<vector::FlatTransposeOp>(op); 3472d2c73c5SJacques Pienaar auto adaptor = vector::FlatTransposeOpAdaptor(operands); 348c295a65dSaartbik rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( 349c295a65dSaartbik transOp, typeConverter.convertType(transOp.res().getType()), 350c295a65dSaartbik adaptor.matrix(), transOp.rows(), transOp.columns()); 351c295a65dSaartbik return success(); 352c295a65dSaartbik } 353c295a65dSaartbik }; 354c295a65dSaartbik 35539379916Saartbik /// Conversion pattern for a vector.maskedload. 35639379916Saartbik class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { 35739379916Saartbik public: 35839379916Saartbik explicit VectorMaskedLoadOpConversion(MLIRContext *context, 35939379916Saartbik LLVMTypeConverter &typeConverter) 36039379916Saartbik : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, 36139379916Saartbik typeConverter) {} 36239379916Saartbik 36339379916Saartbik LogicalResult 36439379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 36539379916Saartbik ConversionPatternRewriter &rewriter) const override { 36639379916Saartbik auto loc = op->getLoc(); 36739379916Saartbik auto load = cast<vector::MaskedLoadOp>(op); 36839379916Saartbik auto adaptor = vector::MaskedLoadOpAdaptor(operands); 36939379916Saartbik 37039379916Saartbik // Resolve alignment. 37139379916Saartbik unsigned align; 37239379916Saartbik if (failed(getMemRefAlignment(typeConverter, load, align))) 37339379916Saartbik return failure(); 37439379916Saartbik 37539379916Saartbik auto vtype = typeConverter.convertType(load.getResultVectorType()); 37639379916Saartbik Value ptr; 37739379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), 37839379916Saartbik vtype, ptr))) 37939379916Saartbik return failure(); 38039379916Saartbik 38139379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( 38239379916Saartbik load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), 38339379916Saartbik rewriter.getI32IntegerAttr(align)); 38439379916Saartbik return success(); 38539379916Saartbik } 38639379916Saartbik }; 38739379916Saartbik 38839379916Saartbik /// Conversion pattern for a vector.maskedstore. 38939379916Saartbik class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { 39039379916Saartbik public: 39139379916Saartbik explicit VectorMaskedStoreOpConversion(MLIRContext *context, 39239379916Saartbik LLVMTypeConverter &typeConverter) 39339379916Saartbik : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, 39439379916Saartbik typeConverter) {} 39539379916Saartbik 39639379916Saartbik LogicalResult 39739379916Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 39839379916Saartbik ConversionPatternRewriter &rewriter) const override { 39939379916Saartbik auto loc = op->getLoc(); 40039379916Saartbik auto store = cast<vector::MaskedStoreOp>(op); 40139379916Saartbik auto adaptor = vector::MaskedStoreOpAdaptor(operands); 40239379916Saartbik 40339379916Saartbik // Resolve alignment. 40439379916Saartbik unsigned align; 40539379916Saartbik if (failed(getMemRefAlignment(typeConverter, store, align))) 40639379916Saartbik return failure(); 40739379916Saartbik 40839379916Saartbik auto vtype = typeConverter.convertType(store.getValueVectorType()); 40939379916Saartbik Value ptr; 41039379916Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), 41139379916Saartbik vtype, ptr))) 41239379916Saartbik return failure(); 41339379916Saartbik 41439379916Saartbik rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( 41539379916Saartbik store, adaptor.value(), ptr, adaptor.mask(), 41639379916Saartbik rewriter.getI32IntegerAttr(align)); 41739379916Saartbik return success(); 41839379916Saartbik } 41939379916Saartbik }; 42039379916Saartbik 42119dbb230Saartbik /// Conversion pattern for a vector.gather. 42219dbb230Saartbik class VectorGatherOpConversion : public ConvertToLLVMPattern { 42319dbb230Saartbik public: 42419dbb230Saartbik explicit VectorGatherOpConversion(MLIRContext *context, 42519dbb230Saartbik LLVMTypeConverter &typeConverter) 42619dbb230Saartbik : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, 42719dbb230Saartbik typeConverter) {} 42819dbb230Saartbik 42919dbb230Saartbik LogicalResult 43019dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 43119dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 43219dbb230Saartbik auto loc = op->getLoc(); 43319dbb230Saartbik auto gather = cast<vector::GatherOp>(op); 43419dbb230Saartbik auto adaptor = vector::GatherOpAdaptor(operands); 43519dbb230Saartbik 43619dbb230Saartbik // Resolve alignment. 43719dbb230Saartbik unsigned align; 43819dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, gather, align))) 43919dbb230Saartbik return failure(); 44019dbb230Saartbik 44119dbb230Saartbik // Get index ptrs. 44219dbb230Saartbik VectorType vType = gather.getResultVectorType(); 44319dbb230Saartbik Type iType = gather.getIndicesVectorType().getElementType(); 44419dbb230Saartbik Value ptrs; 445e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 446e8dcf5f8Saartbik gather.getMemRefType(), vType, iType, ptrs))) 44719dbb230Saartbik return failure(); 44819dbb230Saartbik 44919dbb230Saartbik // Replace with the gather intrinsic. 45019dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_gather>( 4510c2a4d3cSBenjamin Kramer gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), 4520c2a4d3cSBenjamin Kramer adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); 45319dbb230Saartbik return success(); 45419dbb230Saartbik } 45519dbb230Saartbik }; 45619dbb230Saartbik 45719dbb230Saartbik /// Conversion pattern for a vector.scatter. 45819dbb230Saartbik class VectorScatterOpConversion : public ConvertToLLVMPattern { 45919dbb230Saartbik public: 46019dbb230Saartbik explicit VectorScatterOpConversion(MLIRContext *context, 46119dbb230Saartbik LLVMTypeConverter &typeConverter) 46219dbb230Saartbik : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, 46319dbb230Saartbik typeConverter) {} 46419dbb230Saartbik 46519dbb230Saartbik LogicalResult 46619dbb230Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 46719dbb230Saartbik ConversionPatternRewriter &rewriter) const override { 46819dbb230Saartbik auto loc = op->getLoc(); 46919dbb230Saartbik auto scatter = cast<vector::ScatterOp>(op); 47019dbb230Saartbik auto adaptor = vector::ScatterOpAdaptor(operands); 47119dbb230Saartbik 47219dbb230Saartbik // Resolve alignment. 47319dbb230Saartbik unsigned align; 47419dbb230Saartbik if (failed(getMemRefAlignment(typeConverter, scatter, align))) 47519dbb230Saartbik return failure(); 47619dbb230Saartbik 47719dbb230Saartbik // Get index ptrs. 47819dbb230Saartbik VectorType vType = scatter.getValueVectorType(); 47919dbb230Saartbik Type iType = scatter.getIndicesVectorType().getElementType(); 48019dbb230Saartbik Value ptrs; 481e8dcf5f8Saartbik if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), 482e8dcf5f8Saartbik scatter.getMemRefType(), vType, iType, ptrs))) 48319dbb230Saartbik return failure(); 48419dbb230Saartbik 48519dbb230Saartbik // Replace with the scatter intrinsic. 48619dbb230Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( 48719dbb230Saartbik scatter, adaptor.value(), ptrs, adaptor.mask(), 48819dbb230Saartbik rewriter.getI32IntegerAttr(align)); 48919dbb230Saartbik return success(); 49019dbb230Saartbik } 49119dbb230Saartbik }; 49219dbb230Saartbik 493e8dcf5f8Saartbik /// Conversion pattern for a vector.expandload. 494e8dcf5f8Saartbik class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { 495e8dcf5f8Saartbik public: 496e8dcf5f8Saartbik explicit VectorExpandLoadOpConversion(MLIRContext *context, 497e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 498e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context, 499e8dcf5f8Saartbik typeConverter) {} 500e8dcf5f8Saartbik 501e8dcf5f8Saartbik LogicalResult 502e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 503e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 504e8dcf5f8Saartbik auto loc = op->getLoc(); 505e8dcf5f8Saartbik auto expand = cast<vector::ExpandLoadOp>(op); 506e8dcf5f8Saartbik auto adaptor = vector::ExpandLoadOpAdaptor(operands); 507e8dcf5f8Saartbik 508e8dcf5f8Saartbik Value ptr; 509e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), 510e8dcf5f8Saartbik ptr))) 511e8dcf5f8Saartbik return failure(); 512e8dcf5f8Saartbik 513e8dcf5f8Saartbik auto vType = expand.getResultVectorType(); 514e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( 515e8dcf5f8Saartbik op, typeConverter.convertType(vType), ptr, adaptor.mask(), 516e8dcf5f8Saartbik adaptor.pass_thru()); 517e8dcf5f8Saartbik return success(); 518e8dcf5f8Saartbik } 519e8dcf5f8Saartbik }; 520e8dcf5f8Saartbik 521e8dcf5f8Saartbik /// Conversion pattern for a vector.compressstore. 522e8dcf5f8Saartbik class VectorCompressStoreOpConversion : public ConvertToLLVMPattern { 523e8dcf5f8Saartbik public: 524e8dcf5f8Saartbik explicit VectorCompressStoreOpConversion(MLIRContext *context, 525e8dcf5f8Saartbik LLVMTypeConverter &typeConverter) 526e8dcf5f8Saartbik : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(), 527e8dcf5f8Saartbik context, typeConverter) {} 528e8dcf5f8Saartbik 529e8dcf5f8Saartbik LogicalResult 530e8dcf5f8Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 531e8dcf5f8Saartbik ConversionPatternRewriter &rewriter) const override { 532e8dcf5f8Saartbik auto loc = op->getLoc(); 533e8dcf5f8Saartbik auto compress = cast<vector::CompressStoreOp>(op); 534e8dcf5f8Saartbik auto adaptor = vector::CompressStoreOpAdaptor(operands); 535e8dcf5f8Saartbik 536e8dcf5f8Saartbik Value ptr; 537e8dcf5f8Saartbik if (failed(getBasePtr(rewriter, loc, adaptor.base(), 538e8dcf5f8Saartbik compress.getMemRefType(), ptr))) 539e8dcf5f8Saartbik return failure(); 540e8dcf5f8Saartbik 541e8dcf5f8Saartbik rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( 542e8dcf5f8Saartbik op, adaptor.value(), ptr, adaptor.mask()); 543e8dcf5f8Saartbik return success(); 544e8dcf5f8Saartbik } 545e8dcf5f8Saartbik }; 546e8dcf5f8Saartbik 54719dbb230Saartbik /// Conversion pattern for all vector reductions. 548870c1fd4SAlex Zinenko class VectorReductionOpConversion : public ConvertToLLVMPattern { 549e83b7b99Saartbik public: 550e83b7b99Saartbik explicit VectorReductionOpConversion(MLIRContext *context, 551ceb1b327Saartbik LLVMTypeConverter &typeConverter, 552060c9dd1Saartbik bool reassociateFPRed) 553870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, 554ceb1b327Saartbik typeConverter), 555060c9dd1Saartbik reassociateFPReductions(reassociateFPRed) {} 556e83b7b99Saartbik 5573145427dSRiver Riddle LogicalResult 558e83b7b99Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 559e83b7b99Saartbik ConversionPatternRewriter &rewriter) const override { 560e83b7b99Saartbik auto reductionOp = cast<vector::ReductionOp>(op); 561e83b7b99Saartbik auto kind = reductionOp.kind(); 562e83b7b99Saartbik Type eltType = reductionOp.dest().getType(); 5630f04384dSAlex Zinenko Type llvmType = typeConverter.convertType(eltType); 564e9628955SAart Bik if (eltType.isIntOrIndex()) { 565e83b7b99Saartbik // Integer reductions: add/mul/min/max/and/or/xor. 566e83b7b99Saartbik if (kind == "add") 567*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>( 568e83b7b99Saartbik op, llvmType, operands[0]); 569e83b7b99Saartbik else if (kind == "mul") 570*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>( 571e83b7b99Saartbik op, llvmType, operands[0]); 572e9628955SAart Bik else if (kind == "min" && 573e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 574*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>( 575e9628955SAart Bik op, llvmType, operands[0]); 576e83b7b99Saartbik else if (kind == "min") 577*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>( 578e83b7b99Saartbik op, llvmType, operands[0]); 579e9628955SAart Bik else if (kind == "max" && 580e9628955SAart Bik (eltType.isIndex() || eltType.isUnsignedInteger())) 581*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>( 582e9628955SAart Bik op, llvmType, operands[0]); 583e83b7b99Saartbik else if (kind == "max") 584*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>( 585e83b7b99Saartbik op, llvmType, operands[0]); 586e83b7b99Saartbik else if (kind == "and") 587*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>( 588e83b7b99Saartbik op, llvmType, operands[0]); 589e83b7b99Saartbik else if (kind == "or") 590*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>( 591e83b7b99Saartbik op, llvmType, operands[0]); 592e83b7b99Saartbik else if (kind == "xor") 593*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>( 594e83b7b99Saartbik op, llvmType, operands[0]); 595e83b7b99Saartbik else 5963145427dSRiver Riddle return failure(); 5973145427dSRiver Riddle return success(); 598e83b7b99Saartbik 5992d76274bSBenjamin Kramer } else if (eltType.isa<FloatType>()) { 600e83b7b99Saartbik // Floating-point reductions: add/mul/min/max 601e83b7b99Saartbik if (kind == "add") { 6020d924700Saartbik // Optional accumulator (or zero). 6030d924700Saartbik Value acc = operands.size() > 1 ? operands[1] 6040d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 6050d924700Saartbik op->getLoc(), llvmType, 6060d924700Saartbik rewriter.getZeroAttr(eltType)); 607*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( 608ceb1b327Saartbik op, llvmType, acc, operands[0], 609ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 610e83b7b99Saartbik } else if (kind == "mul") { 6110d924700Saartbik // Optional accumulator (or one). 6120d924700Saartbik Value acc = operands.size() > 1 6130d924700Saartbik ? operands[1] 6140d924700Saartbik : rewriter.create<LLVM::ConstantOp>( 6150d924700Saartbik op->getLoc(), llvmType, 6160d924700Saartbik rewriter.getFloatAttr(eltType, 1.0)); 617*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( 618ceb1b327Saartbik op, llvmType, acc, operands[0], 619ceb1b327Saartbik rewriter.getBoolAttr(reassociateFPReductions)); 620e83b7b99Saartbik } else if (kind == "min") 621*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( 622e83b7b99Saartbik op, llvmType, operands[0]); 623e83b7b99Saartbik else if (kind == "max") 624*322d0afdSAmara Emerson rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( 625e83b7b99Saartbik op, llvmType, operands[0]); 626e83b7b99Saartbik else 6273145427dSRiver Riddle return failure(); 6283145427dSRiver Riddle return success(); 629e83b7b99Saartbik } 6303145427dSRiver Riddle return failure(); 631e83b7b99Saartbik } 632ceb1b327Saartbik 633ceb1b327Saartbik private: 634ceb1b327Saartbik const bool reassociateFPReductions; 635e83b7b99Saartbik }; 636e83b7b99Saartbik 637060c9dd1Saartbik /// Conversion pattern for a vector.create_mask (1-D only). 638060c9dd1Saartbik class VectorCreateMaskOpConversion : public ConvertToLLVMPattern { 639060c9dd1Saartbik public: 640060c9dd1Saartbik explicit VectorCreateMaskOpConversion(MLIRContext *context, 641060c9dd1Saartbik LLVMTypeConverter &typeConverter, 642060c9dd1Saartbik bool enableIndexOpt) 643060c9dd1Saartbik : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context, 644060c9dd1Saartbik typeConverter), 645060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 646060c9dd1Saartbik 647060c9dd1Saartbik LogicalResult 648060c9dd1Saartbik matchAndRewrite(Operation *op, ArrayRef<Value> operands, 649060c9dd1Saartbik ConversionPatternRewriter &rewriter) const override { 650060c9dd1Saartbik auto dstType = op->getResult(0).getType().cast<VectorType>(); 651060c9dd1Saartbik int64_t rank = dstType.getRank(); 652060c9dd1Saartbik if (rank == 1) { 653060c9dd1Saartbik rewriter.replaceOp( 654060c9dd1Saartbik op, buildVectorComparison(rewriter, op, enableIndexOptimizations, 655060c9dd1Saartbik dstType.getDimSize(0), operands[0])); 656060c9dd1Saartbik return success(); 657060c9dd1Saartbik } 658060c9dd1Saartbik return failure(); 659060c9dd1Saartbik } 660060c9dd1Saartbik 661060c9dd1Saartbik private: 662060c9dd1Saartbik const bool enableIndexOptimizations; 663060c9dd1Saartbik }; 664060c9dd1Saartbik 665870c1fd4SAlex Zinenko class VectorShuffleOpConversion : public ConvertToLLVMPattern { 6661c81adf3SAart Bik public: 6671c81adf3SAart Bik explicit VectorShuffleOpConversion(MLIRContext *context, 6681c81adf3SAart Bik LLVMTypeConverter &typeConverter) 669870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, 6701c81adf3SAart Bik typeConverter) {} 6711c81adf3SAart Bik 6723145427dSRiver Riddle LogicalResult 673e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 6741c81adf3SAart Bik ConversionPatternRewriter &rewriter) const override { 6751c81adf3SAart Bik auto loc = op->getLoc(); 6762d2c73c5SJacques Pienaar auto adaptor = vector::ShuffleOpAdaptor(operands); 6771c81adf3SAart Bik auto shuffleOp = cast<vector::ShuffleOp>(op); 6781c81adf3SAart Bik auto v1Type = shuffleOp.getV1VectorType(); 6791c81adf3SAart Bik auto v2Type = shuffleOp.getV2VectorType(); 6801c81adf3SAart Bik auto vectorType = shuffleOp.getVectorType(); 6810f04384dSAlex Zinenko Type llvmType = typeConverter.convertType(vectorType); 6821c81adf3SAart Bik auto maskArrayAttr = shuffleOp.mask(); 6831c81adf3SAart Bik 6841c81adf3SAart Bik // Bail if result type cannot be lowered. 6851c81adf3SAart Bik if (!llvmType) 6863145427dSRiver Riddle return failure(); 6871c81adf3SAart Bik 6881c81adf3SAart Bik // Get rank and dimension sizes. 6891c81adf3SAart Bik int64_t rank = vectorType.getRank(); 6901c81adf3SAart Bik assert(v1Type.getRank() == rank); 6911c81adf3SAart Bik assert(v2Type.getRank() == rank); 6921c81adf3SAart Bik int64_t v1Dim = v1Type.getDimSize(0); 6931c81adf3SAart Bik 6941c81adf3SAart Bik // For rank 1, where both operands have *exactly* the same vector type, 6951c81adf3SAart Bik // there is direct shuffle support in LLVM. Use it! 6961c81adf3SAart Bik if (rank == 1 && v1Type == v2Type) { 697e62a6956SRiver Riddle Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>( 6981c81adf3SAart Bik loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); 6991c81adf3SAart Bik rewriter.replaceOp(op, shuffle); 7003145427dSRiver Riddle return success(); 701b36aaeafSAart Bik } 702b36aaeafSAart Bik 7031c81adf3SAart Bik // For all other cases, insert the individual values individually. 704e62a6956SRiver Riddle Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType); 7051c81adf3SAart Bik int64_t insPos = 0; 7061c81adf3SAart Bik for (auto en : llvm::enumerate(maskArrayAttr)) { 7071c81adf3SAart Bik int64_t extPos = en.value().cast<IntegerAttr>().getInt(); 708e62a6956SRiver Riddle Value value = adaptor.v1(); 7091c81adf3SAart Bik if (extPos >= v1Dim) { 7101c81adf3SAart Bik extPos -= v1Dim; 7111c81adf3SAart Bik value = adaptor.v2(); 712b36aaeafSAart Bik } 7130f04384dSAlex Zinenko Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, 7140f04384dSAlex Zinenko rank, extPos); 7150f04384dSAlex Zinenko insert = insertOne(rewriter, typeConverter, loc, insert, extract, 7160f04384dSAlex Zinenko llvmType, rank, insPos++); 7171c81adf3SAart Bik } 7181c81adf3SAart Bik rewriter.replaceOp(op, insert); 7193145427dSRiver Riddle return success(); 720b36aaeafSAart Bik } 721b36aaeafSAart Bik }; 722b36aaeafSAart Bik 723870c1fd4SAlex Zinenko class VectorExtractElementOpConversion : public ConvertToLLVMPattern { 724cd5dab8aSAart Bik public: 725cd5dab8aSAart Bik explicit VectorExtractElementOpConversion(MLIRContext *context, 726cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 727870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), 728870c1fd4SAlex Zinenko context, typeConverter) {} 729cd5dab8aSAart Bik 7303145427dSRiver Riddle LogicalResult 731e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 732cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 7332d2c73c5SJacques Pienaar auto adaptor = vector::ExtractElementOpAdaptor(operands); 734cd5dab8aSAart Bik auto extractEltOp = cast<vector::ExtractElementOp>(op); 735cd5dab8aSAart Bik auto vectorType = extractEltOp.getVectorType(); 7360f04384dSAlex Zinenko auto llvmType = typeConverter.convertType(vectorType.getElementType()); 737cd5dab8aSAart Bik 738cd5dab8aSAart Bik // Bail if result type cannot be lowered. 739cd5dab8aSAart Bik if (!llvmType) 7403145427dSRiver Riddle return failure(); 741cd5dab8aSAart Bik 742cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 743cd5dab8aSAart Bik op, llvmType, adaptor.vector(), adaptor.position()); 7443145427dSRiver Riddle return success(); 745cd5dab8aSAart Bik } 746cd5dab8aSAart Bik }; 747cd5dab8aSAart Bik 748870c1fd4SAlex Zinenko class VectorExtractOpConversion : public ConvertToLLVMPattern { 7495c0c51a9SNicolas Vasilache public: 7509826fe5cSAart Bik explicit VectorExtractOpConversion(MLIRContext *context, 7515c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 752870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, 7535c0c51a9SNicolas Vasilache typeConverter) {} 7545c0c51a9SNicolas Vasilache 7553145427dSRiver Riddle LogicalResult 756e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 7575c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 7585c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 7592d2c73c5SJacques Pienaar auto adaptor = vector::ExtractOpAdaptor(operands); 760d37f2725SAart Bik auto extractOp = cast<vector::ExtractOp>(op); 7619826fe5cSAart Bik auto vectorType = extractOp.getVectorType(); 7622bdf33ccSRiver Riddle auto resultType = extractOp.getResult().getType(); 7630f04384dSAlex Zinenko auto llvmResultType = typeConverter.convertType(resultType); 7645c0c51a9SNicolas Vasilache auto positionArrayAttr = extractOp.position(); 7659826fe5cSAart Bik 7669826fe5cSAart Bik // Bail if result type cannot be lowered. 7679826fe5cSAart Bik if (!llvmResultType) 7683145427dSRiver Riddle return failure(); 7699826fe5cSAart Bik 7705c0c51a9SNicolas Vasilache // One-shot extraction of vector from array (only requires extractvalue). 7715c0c51a9SNicolas Vasilache if (resultType.isa<VectorType>()) { 772e62a6956SRiver Riddle Value extracted = rewriter.create<LLVM::ExtractValueOp>( 7735c0c51a9SNicolas Vasilache loc, llvmResultType, adaptor.vector(), positionArrayAttr); 7745c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7753145427dSRiver Riddle return success(); 7765c0c51a9SNicolas Vasilache } 7775c0c51a9SNicolas Vasilache 7789826fe5cSAart Bik // Potential extraction of 1-D vector from array. 7795c0c51a9SNicolas Vasilache auto *context = op->getContext(); 780e62a6956SRiver Riddle Value extracted = adaptor.vector(); 7815c0c51a9SNicolas Vasilache auto positionAttrs = positionArrayAttr.getValue(); 7825c0c51a9SNicolas Vasilache if (positionAttrs.size() > 1) { 7839826fe5cSAart Bik auto oneDVectorType = reducedVectorTypeBack(vectorType); 7845c0c51a9SNicolas Vasilache auto nMinusOnePositionAttrs = 7855c0c51a9SNicolas Vasilache ArrayAttr::get(positionAttrs.drop_back(), context); 7865c0c51a9SNicolas Vasilache extracted = rewriter.create<LLVM::ExtractValueOp>( 7870f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 7885c0c51a9SNicolas Vasilache nMinusOnePositionAttrs); 7895c0c51a9SNicolas Vasilache } 7905c0c51a9SNicolas Vasilache 7915c0c51a9SNicolas Vasilache // Remaining extraction of element from 1-D LLVM vector 7925c0c51a9SNicolas Vasilache auto position = positionAttrs.back().cast<IntegerAttr>(); 7935446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 7941d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 7955c0c51a9SNicolas Vasilache extracted = 7965c0c51a9SNicolas Vasilache rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 7975c0c51a9SNicolas Vasilache rewriter.replaceOp(op, extracted); 7985c0c51a9SNicolas Vasilache 7993145427dSRiver Riddle return success(); 8005c0c51a9SNicolas Vasilache } 8015c0c51a9SNicolas Vasilache }; 8025c0c51a9SNicolas Vasilache 803681f929fSNicolas Vasilache /// Conversion pattern that turns a vector.fma on a 1-D vector 804681f929fSNicolas Vasilache /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. 805681f929fSNicolas Vasilache /// This does not match vectors of n >= 2 rank. 806681f929fSNicolas Vasilache /// 807681f929fSNicolas Vasilache /// Example: 808681f929fSNicolas Vasilache /// ``` 809681f929fSNicolas Vasilache /// vector.fma %a, %a, %a : vector<8xf32> 810681f929fSNicolas Vasilache /// ``` 811681f929fSNicolas Vasilache /// is converted to: 812681f929fSNicolas Vasilache /// ``` 8133bffe602SBenjamin Kramer /// llvm.intr.fmuladd %va, %va, %va: 814681f929fSNicolas Vasilache /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) 815681f929fSNicolas Vasilache /// -> !llvm<"<8 x float>"> 816681f929fSNicolas Vasilache /// ``` 817870c1fd4SAlex Zinenko class VectorFMAOp1DConversion : public ConvertToLLVMPattern { 818681f929fSNicolas Vasilache public: 819681f929fSNicolas Vasilache explicit VectorFMAOp1DConversion(MLIRContext *context, 820681f929fSNicolas Vasilache LLVMTypeConverter &typeConverter) 821870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, 822681f929fSNicolas Vasilache typeConverter) {} 823681f929fSNicolas Vasilache 8243145427dSRiver Riddle LogicalResult 825681f929fSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 826681f929fSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 8272d2c73c5SJacques Pienaar auto adaptor = vector::FMAOpAdaptor(operands); 828681f929fSNicolas Vasilache vector::FMAOp fmaOp = cast<vector::FMAOp>(op); 829681f929fSNicolas Vasilache VectorType vType = fmaOp.getVectorType(); 830681f929fSNicolas Vasilache if (vType.getRank() != 1) 8313145427dSRiver Riddle return failure(); 8323bffe602SBenjamin Kramer rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(), 8333bffe602SBenjamin Kramer adaptor.rhs(), adaptor.acc()); 8343145427dSRiver Riddle return success(); 835681f929fSNicolas Vasilache } 836681f929fSNicolas Vasilache }; 837681f929fSNicolas Vasilache 838870c1fd4SAlex Zinenko class VectorInsertElementOpConversion : public ConvertToLLVMPattern { 839cd5dab8aSAart Bik public: 840cd5dab8aSAart Bik explicit VectorInsertElementOpConversion(MLIRContext *context, 841cd5dab8aSAart Bik LLVMTypeConverter &typeConverter) 842870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), 843870c1fd4SAlex Zinenko context, typeConverter) {} 844cd5dab8aSAart Bik 8453145427dSRiver Riddle LogicalResult 846e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 847cd5dab8aSAart Bik ConversionPatternRewriter &rewriter) const override { 8482d2c73c5SJacques Pienaar auto adaptor = vector::InsertElementOpAdaptor(operands); 849cd5dab8aSAart Bik auto insertEltOp = cast<vector::InsertElementOp>(op); 850cd5dab8aSAart Bik auto vectorType = insertEltOp.getDestVectorType(); 8510f04384dSAlex Zinenko auto llvmType = typeConverter.convertType(vectorType); 852cd5dab8aSAart Bik 853cd5dab8aSAart Bik // Bail if result type cannot be lowered. 854cd5dab8aSAart Bik if (!llvmType) 8553145427dSRiver Riddle return failure(); 856cd5dab8aSAart Bik 857cd5dab8aSAart Bik rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 858cd5dab8aSAart Bik op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); 8593145427dSRiver Riddle return success(); 860cd5dab8aSAart Bik } 861cd5dab8aSAart Bik }; 862cd5dab8aSAart Bik 863870c1fd4SAlex Zinenko class VectorInsertOpConversion : public ConvertToLLVMPattern { 8649826fe5cSAart Bik public: 8659826fe5cSAart Bik explicit VectorInsertOpConversion(MLIRContext *context, 8669826fe5cSAart Bik LLVMTypeConverter &typeConverter) 867870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, 8689826fe5cSAart Bik typeConverter) {} 8699826fe5cSAart Bik 8703145427dSRiver Riddle LogicalResult 871e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 8729826fe5cSAart Bik ConversionPatternRewriter &rewriter) const override { 8739826fe5cSAart Bik auto loc = op->getLoc(); 8742d2c73c5SJacques Pienaar auto adaptor = vector::InsertOpAdaptor(operands); 8759826fe5cSAart Bik auto insertOp = cast<vector::InsertOp>(op); 8769826fe5cSAart Bik auto sourceType = insertOp.getSourceType(); 8779826fe5cSAart Bik auto destVectorType = insertOp.getDestVectorType(); 8780f04384dSAlex Zinenko auto llvmResultType = typeConverter.convertType(destVectorType); 8799826fe5cSAart Bik auto positionArrayAttr = insertOp.position(); 8809826fe5cSAart Bik 8819826fe5cSAart Bik // Bail if result type cannot be lowered. 8829826fe5cSAart Bik if (!llvmResultType) 8833145427dSRiver Riddle return failure(); 8849826fe5cSAart Bik 8859826fe5cSAart Bik // One-shot insertion of a vector into an array (only requires insertvalue). 8869826fe5cSAart Bik if (sourceType.isa<VectorType>()) { 887e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertValueOp>( 8889826fe5cSAart Bik loc, llvmResultType, adaptor.dest(), adaptor.source(), 8899826fe5cSAart Bik positionArrayAttr); 8909826fe5cSAart Bik rewriter.replaceOp(op, inserted); 8913145427dSRiver Riddle return success(); 8929826fe5cSAart Bik } 8939826fe5cSAart Bik 8949826fe5cSAart Bik // Potential extraction of 1-D vector from array. 8959826fe5cSAart Bik auto *context = op->getContext(); 896e62a6956SRiver Riddle Value extracted = adaptor.dest(); 8979826fe5cSAart Bik auto positionAttrs = positionArrayAttr.getValue(); 8989826fe5cSAart Bik auto position = positionAttrs.back().cast<IntegerAttr>(); 8999826fe5cSAart Bik auto oneDVectorType = destVectorType; 9009826fe5cSAart Bik if (positionAttrs.size() > 1) { 9019826fe5cSAart Bik oneDVectorType = reducedVectorTypeBack(destVectorType); 9029826fe5cSAart Bik auto nMinusOnePositionAttrs = 9039826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 9049826fe5cSAart Bik extracted = rewriter.create<LLVM::ExtractValueOp>( 9050f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 9069826fe5cSAart Bik nMinusOnePositionAttrs); 9079826fe5cSAart Bik } 9089826fe5cSAart Bik 9099826fe5cSAart Bik // Insertion of an element into a 1-D LLVM vector. 9105446ec85SAlex Zinenko auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 9111d47564aSAart Bik auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); 912e62a6956SRiver Riddle Value inserted = rewriter.create<LLVM::InsertElementOp>( 9130f04384dSAlex Zinenko loc, typeConverter.convertType(oneDVectorType), extracted, 9140f04384dSAlex Zinenko adaptor.source(), constant); 9159826fe5cSAart Bik 9169826fe5cSAart Bik // Potential insertion of resulting 1-D vector into array. 9179826fe5cSAart Bik if (positionAttrs.size() > 1) { 9189826fe5cSAart Bik auto nMinusOnePositionAttrs = 9199826fe5cSAart Bik ArrayAttr::get(positionAttrs.drop_back(), context); 9209826fe5cSAart Bik inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 9219826fe5cSAart Bik adaptor.dest(), inserted, 9229826fe5cSAart Bik nMinusOnePositionAttrs); 9239826fe5cSAart Bik } 9249826fe5cSAart Bik 9259826fe5cSAart Bik rewriter.replaceOp(op, inserted); 9263145427dSRiver Riddle return success(); 9279826fe5cSAart Bik } 9289826fe5cSAart Bik }; 9299826fe5cSAart Bik 930681f929fSNicolas Vasilache /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 931681f929fSNicolas Vasilache /// 932681f929fSNicolas Vasilache /// Example: 933681f929fSNicolas Vasilache /// ``` 934681f929fSNicolas Vasilache /// %d = vector.fma %a, %b, %c : vector<2x4xf32> 935681f929fSNicolas Vasilache /// ``` 936681f929fSNicolas Vasilache /// is rewritten into: 937681f929fSNicolas Vasilache /// ``` 938681f929fSNicolas Vasilache /// %r = splat %f0: vector<2x4xf32> 939681f929fSNicolas Vasilache /// %va = vector.extractvalue %a[0] : vector<2x4xf32> 940681f929fSNicolas Vasilache /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> 941681f929fSNicolas Vasilache /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> 942681f929fSNicolas Vasilache /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> 943681f929fSNicolas Vasilache /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> 944681f929fSNicolas Vasilache /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> 945681f929fSNicolas Vasilache /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> 946681f929fSNicolas Vasilache /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> 947681f929fSNicolas Vasilache /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> 948681f929fSNicolas Vasilache /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> 949681f929fSNicolas Vasilache /// // %r3 holds the final value. 950681f929fSNicolas Vasilache /// ``` 951681f929fSNicolas Vasilache class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> { 952681f929fSNicolas Vasilache public: 953681f929fSNicolas Vasilache using OpRewritePattern<FMAOp>::OpRewritePattern; 954681f929fSNicolas Vasilache 9553145427dSRiver Riddle LogicalResult matchAndRewrite(FMAOp op, 956681f929fSNicolas Vasilache PatternRewriter &rewriter) const override { 957681f929fSNicolas Vasilache auto vType = op.getVectorType(); 958681f929fSNicolas Vasilache if (vType.getRank() < 2) 9593145427dSRiver Riddle return failure(); 960681f929fSNicolas Vasilache 961681f929fSNicolas Vasilache auto loc = op.getLoc(); 962681f929fSNicolas Vasilache auto elemType = vType.getElementType(); 963681f929fSNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 964681f929fSNicolas Vasilache rewriter.getZeroAttr(elemType)); 965681f929fSNicolas Vasilache Value desc = rewriter.create<SplatOp>(loc, vType, zero); 966681f929fSNicolas Vasilache for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { 967681f929fSNicolas Vasilache Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i); 968681f929fSNicolas Vasilache Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i); 969681f929fSNicolas Vasilache Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i); 970681f929fSNicolas Vasilache Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); 971681f929fSNicolas Vasilache desc = rewriter.create<InsertOp>(loc, fma, desc, i); 972681f929fSNicolas Vasilache } 973681f929fSNicolas Vasilache rewriter.replaceOp(op, desc); 9743145427dSRiver Riddle return success(); 975681f929fSNicolas Vasilache } 976681f929fSNicolas Vasilache }; 977681f929fSNicolas Vasilache 9782d515e49SNicolas Vasilache // When ranks are different, InsertStridedSlice needs to extract a properly 9792d515e49SNicolas Vasilache // ranked vector from the destination vector into which to insert. This pattern 9802d515e49SNicolas Vasilache // only takes care of this part and forwards the rest of the conversion to 9812d515e49SNicolas Vasilache // another pattern that converts InsertStridedSlice for operands of the same 9822d515e49SNicolas Vasilache // rank. 9832d515e49SNicolas Vasilache // 9842d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 9852d515e49SNicolas Vasilache // have different ranks. In this case: 9862d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 9872d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 9882d515e49SNicolas Vasilache // destination subvector 9892d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 9902d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 9912d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 9922d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 9932d515e49SNicolas Vasilache class VectorInsertStridedSliceOpDifferentRankRewritePattern 9942d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 9952d515e49SNicolas Vasilache public: 9962d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 9972d515e49SNicolas Vasilache 9983145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 9992d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10002d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10012d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10022d515e49SNicolas Vasilache 10032d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10043145427dSRiver Riddle return failure(); 10052d515e49SNicolas Vasilache 10062d515e49SNicolas Vasilache auto loc = op.getLoc(); 10072d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10082d515e49SNicolas Vasilache assert(rankDiff >= 0); 10092d515e49SNicolas Vasilache if (rankDiff == 0) 10103145427dSRiver Riddle return failure(); 10112d515e49SNicolas Vasilache 10122d515e49SNicolas Vasilache int64_t rankRest = dstType.getRank() - rankDiff; 10132d515e49SNicolas Vasilache // Extract / insert the subvector of matching rank and InsertStridedSlice 10142d515e49SNicolas Vasilache // on it. 10152d515e49SNicolas Vasilache Value extracted = 10162d515e49SNicolas Vasilache rewriter.create<ExtractOp>(loc, op.dest(), 10172d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 10182d515e49SNicolas Vasilache /*dropFront=*/rankRest)); 10192d515e49SNicolas Vasilache // A different pattern will kick in for InsertStridedSlice with matching 10202d515e49SNicolas Vasilache // ranks. 10212d515e49SNicolas Vasilache auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 10222d515e49SNicolas Vasilache loc, op.source(), extracted, 10232d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 1024c8fc76a9Saartbik getI64SubArray(op.strides(), /*dropFront=*/0)); 10252d515e49SNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOp>( 10262d515e49SNicolas Vasilache op, stridedSliceInnerOp.getResult(), op.dest(), 10272d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /*dropFront=*/0, 10282d515e49SNicolas Vasilache /*dropFront=*/rankRest)); 10293145427dSRiver Riddle return success(); 10302d515e49SNicolas Vasilache } 10312d515e49SNicolas Vasilache }; 10322d515e49SNicolas Vasilache 10332d515e49SNicolas Vasilache // RewritePattern for InsertStridedSliceOp where source and destination vectors 10342d515e49SNicolas Vasilache // have the same rank. In this case, we reduce 10352d515e49SNicolas Vasilache // 1. the proper subvector is extracted from the destination vector 10362d515e49SNicolas Vasilache // 2. a new InsertStridedSlice op is created to insert the source in the 10372d515e49SNicolas Vasilache // destination subvector 10382d515e49SNicolas Vasilache // 3. the destination subvector is inserted back in the proper place 10392d515e49SNicolas Vasilache // 4. the op is replaced by the result of step 3. 10402d515e49SNicolas Vasilache // The new InsertStridedSlice from step 2. will be picked up by a 10412d515e49SNicolas Vasilache // `VectorInsertStridedSliceOpSameRankRewritePattern`. 10422d515e49SNicolas Vasilache class VectorInsertStridedSliceOpSameRankRewritePattern 10432d515e49SNicolas Vasilache : public OpRewritePattern<InsertStridedSliceOp> { 10442d515e49SNicolas Vasilache public: 10452d515e49SNicolas Vasilache using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 10462d515e49SNicolas Vasilache 10473145427dSRiver Riddle LogicalResult matchAndRewrite(InsertStridedSliceOp op, 10482d515e49SNicolas Vasilache PatternRewriter &rewriter) const override { 10492d515e49SNicolas Vasilache auto srcType = op.getSourceVectorType(); 10502d515e49SNicolas Vasilache auto dstType = op.getDestVectorType(); 10512d515e49SNicolas Vasilache 10522d515e49SNicolas Vasilache if (op.offsets().getValue().empty()) 10533145427dSRiver Riddle return failure(); 10542d515e49SNicolas Vasilache 10552d515e49SNicolas Vasilache int64_t rankDiff = dstType.getRank() - srcType.getRank(); 10562d515e49SNicolas Vasilache assert(rankDiff >= 0); 10572d515e49SNicolas Vasilache if (rankDiff != 0) 10583145427dSRiver Riddle return failure(); 10592d515e49SNicolas Vasilache 10602d515e49SNicolas Vasilache if (srcType == dstType) { 10612d515e49SNicolas Vasilache rewriter.replaceOp(op, op.source()); 10623145427dSRiver Riddle return success(); 10632d515e49SNicolas Vasilache } 10642d515e49SNicolas Vasilache 10652d515e49SNicolas Vasilache int64_t offset = 10662d515e49SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 10672d515e49SNicolas Vasilache int64_t size = srcType.getShape().front(); 10682d515e49SNicolas Vasilache int64_t stride = 10692d515e49SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 10702d515e49SNicolas Vasilache 10712d515e49SNicolas Vasilache auto loc = op.getLoc(); 10722d515e49SNicolas Vasilache Value res = op.dest(); 10732d515e49SNicolas Vasilache // For each slice of the source vector along the most major dimension. 10742d515e49SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 10752d515e49SNicolas Vasilache off += stride, ++idx) { 10762d515e49SNicolas Vasilache // 1. extract the proper subvector (or element) from source 10772d515e49SNicolas Vasilache Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 10782d515e49SNicolas Vasilache if (extractedSource.getType().isa<VectorType>()) { 10792d515e49SNicolas Vasilache // 2. If we have a vector, extract the proper subvector from destination 10802d515e49SNicolas Vasilache // Otherwise we are at the element level and no need to recurse. 10812d515e49SNicolas Vasilache Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 10822d515e49SNicolas Vasilache // 3. Reduce the problem to lowering a new InsertStridedSlice op with 10832d515e49SNicolas Vasilache // smaller rank. 1084bd1ccfe6SRiver Riddle extractedSource = rewriter.create<InsertStridedSliceOp>( 10852d515e49SNicolas Vasilache loc, extractedSource, extractedDest, 10862d515e49SNicolas Vasilache getI64SubArray(op.offsets(), /* dropFront=*/1), 10872d515e49SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 10882d515e49SNicolas Vasilache } 10892d515e49SNicolas Vasilache // 4. Insert the extractedSource into the res vector. 10902d515e49SNicolas Vasilache res = insertOne(rewriter, loc, extractedSource, res, off); 10912d515e49SNicolas Vasilache } 10922d515e49SNicolas Vasilache 10932d515e49SNicolas Vasilache rewriter.replaceOp(op, res); 10943145427dSRiver Riddle return success(); 10952d515e49SNicolas Vasilache } 1096bd1ccfe6SRiver Riddle /// This pattern creates recursive InsertStridedSliceOp, but the recursion is 1097bd1ccfe6SRiver Riddle /// bounded as the rank is strictly decreasing. 1098bd1ccfe6SRiver Riddle bool hasBoundedRewriteRecursion() const final { return true; } 10992d515e49SNicolas Vasilache }; 11002d515e49SNicolas Vasilache 11012bf491c7SBenjamin Kramer /// Returns true if the memory underlying `memRefType` has a contiguous layout. 11022bf491c7SBenjamin Kramer /// Strides are written to `strides`. 11032bf491c7SBenjamin Kramer static bool isContiguous(MemRefType memRefType, 11042bf491c7SBenjamin Kramer SmallVectorImpl<int64_t> &strides) { 11052bf491c7SBenjamin Kramer int64_t offset; 11062bf491c7SBenjamin Kramer auto successStrides = getStridesAndOffset(memRefType, strides, offset); 1107239eff50SBenjamin Kramer bool isContiguous = strides.empty() || strides.back() == 1; 11082bf491c7SBenjamin Kramer if (isContiguous) { 11092bf491c7SBenjamin Kramer auto sizes = memRefType.getShape(); 11102bf491c7SBenjamin Kramer for (int index = 0, e = strides.size() - 2; index < e; ++index) { 11112bf491c7SBenjamin Kramer if (strides[index] != strides[index + 1] * sizes[index + 1]) { 11122bf491c7SBenjamin Kramer isContiguous = false; 11132bf491c7SBenjamin Kramer break; 11142bf491c7SBenjamin Kramer } 11152bf491c7SBenjamin Kramer } 11162bf491c7SBenjamin Kramer } 11172bf491c7SBenjamin Kramer return succeeded(successStrides) && isContiguous; 11182bf491c7SBenjamin Kramer } 11192bf491c7SBenjamin Kramer 1120870c1fd4SAlex Zinenko class VectorTypeCastOpConversion : public ConvertToLLVMPattern { 11215c0c51a9SNicolas Vasilache public: 11225c0c51a9SNicolas Vasilache explicit VectorTypeCastOpConversion(MLIRContext *context, 11235c0c51a9SNicolas Vasilache LLVMTypeConverter &typeConverter) 1124870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, 11255c0c51a9SNicolas Vasilache typeConverter) {} 11265c0c51a9SNicolas Vasilache 11273145427dSRiver Riddle LogicalResult 1128e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 11295c0c51a9SNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 11305c0c51a9SNicolas Vasilache auto loc = op->getLoc(); 11315c0c51a9SNicolas Vasilache vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 11325c0c51a9SNicolas Vasilache MemRefType sourceMemRefType = 11332bdf33ccSRiver Riddle castOp.getOperand().getType().cast<MemRefType>(); 11345c0c51a9SNicolas Vasilache MemRefType targetMemRefType = 11352bdf33ccSRiver Riddle castOp.getResult().getType().cast<MemRefType>(); 11365c0c51a9SNicolas Vasilache 11375c0c51a9SNicolas Vasilache // Only static shape casts supported atm. 11385c0c51a9SNicolas Vasilache if (!sourceMemRefType.hasStaticShape() || 11395c0c51a9SNicolas Vasilache !targetMemRefType.hasStaticShape()) 11403145427dSRiver Riddle return failure(); 11415c0c51a9SNicolas Vasilache 11425c0c51a9SNicolas Vasilache auto llvmSourceDescriptorTy = 11432bdf33ccSRiver Riddle operands[0].getType().dyn_cast<LLVM::LLVMType>(); 11445c0c51a9SNicolas Vasilache if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 11453145427dSRiver Riddle return failure(); 11465c0c51a9SNicolas Vasilache MemRefDescriptor sourceMemRef(operands[0]); 11475c0c51a9SNicolas Vasilache 11480f04384dSAlex Zinenko auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) 11495c0c51a9SNicolas Vasilache .dyn_cast_or_null<LLVM::LLVMType>(); 11505c0c51a9SNicolas Vasilache if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 11513145427dSRiver Riddle return failure(); 11525c0c51a9SNicolas Vasilache 11535c0c51a9SNicolas Vasilache // Only contiguous source tensors supported atm. 11542bf491c7SBenjamin Kramer SmallVector<int64_t, 4> strides; 11552bf491c7SBenjamin Kramer if (!isContiguous(sourceMemRefType, strides)) 11563145427dSRiver Riddle return failure(); 11575c0c51a9SNicolas Vasilache 11585446ec85SAlex Zinenko auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); 11595c0c51a9SNicolas Vasilache 11605c0c51a9SNicolas Vasilache // Create descriptor. 11615c0c51a9SNicolas Vasilache auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 11623a577f54SChristian Sigg Type llvmTargetElementTy = desc.getElementPtrType(); 11635c0c51a9SNicolas Vasilache // Set allocated ptr. 1164e62a6956SRiver Riddle Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); 11655c0c51a9SNicolas Vasilache allocated = 11665c0c51a9SNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 11675c0c51a9SNicolas Vasilache desc.setAllocatedPtr(rewriter, loc, allocated); 11685c0c51a9SNicolas Vasilache // Set aligned ptr. 1169e62a6956SRiver Riddle Value ptr = sourceMemRef.alignedPtr(rewriter, loc); 11705c0c51a9SNicolas Vasilache ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 11715c0c51a9SNicolas Vasilache desc.setAlignedPtr(rewriter, loc, ptr); 11725c0c51a9SNicolas Vasilache // Fill offset 0. 11735c0c51a9SNicolas Vasilache auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 11745c0c51a9SNicolas Vasilache auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 11755c0c51a9SNicolas Vasilache desc.setOffset(rewriter, loc, zero); 11765c0c51a9SNicolas Vasilache 11775c0c51a9SNicolas Vasilache // Fill size and stride descriptors in memref. 11785c0c51a9SNicolas Vasilache for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 11795c0c51a9SNicolas Vasilache int64_t index = indexedSize.index(); 11805c0c51a9SNicolas Vasilache auto sizeAttr = 11815c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 11825c0c51a9SNicolas Vasilache auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 11835c0c51a9SNicolas Vasilache desc.setSize(rewriter, loc, index, size); 11845c0c51a9SNicolas Vasilache auto strideAttr = 11855c0c51a9SNicolas Vasilache rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 11865c0c51a9SNicolas Vasilache auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 11875c0c51a9SNicolas Vasilache desc.setStride(rewriter, loc, index, stride); 11885c0c51a9SNicolas Vasilache } 11895c0c51a9SNicolas Vasilache 11905c0c51a9SNicolas Vasilache rewriter.replaceOp(op, {desc}); 11913145427dSRiver Riddle return success(); 11925c0c51a9SNicolas Vasilache } 11935c0c51a9SNicolas Vasilache }; 11945c0c51a9SNicolas Vasilache 11958345b86dSNicolas Vasilache /// Conversion pattern that converts a 1-D vector transfer read/write op in a 11968345b86dSNicolas Vasilache /// sequence of: 1197060c9dd1Saartbik /// 1. Get the source/dst address as an LLVM vector pointer. 1198060c9dd1Saartbik /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 1199060c9dd1Saartbik /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 1200060c9dd1Saartbik /// 4. Create a mask where offsetVector is compared against memref upper bound. 1201060c9dd1Saartbik /// 5. Rewrite op as a masked read or write. 12028345b86dSNicolas Vasilache template <typename ConcreteOp> 12038345b86dSNicolas Vasilache class VectorTransferConversion : public ConvertToLLVMPattern { 12048345b86dSNicolas Vasilache public: 12058345b86dSNicolas Vasilache explicit VectorTransferConversion(MLIRContext *context, 1206060c9dd1Saartbik LLVMTypeConverter &typeConv, 1207060c9dd1Saartbik bool enableIndexOpt) 1208060c9dd1Saartbik : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv), 1209060c9dd1Saartbik enableIndexOptimizations(enableIndexOpt) {} 12108345b86dSNicolas Vasilache 12118345b86dSNicolas Vasilache LogicalResult 12128345b86dSNicolas Vasilache matchAndRewrite(Operation *op, ArrayRef<Value> operands, 12138345b86dSNicolas Vasilache ConversionPatternRewriter &rewriter) const override { 12148345b86dSNicolas Vasilache auto xferOp = cast<ConcreteOp>(op); 12158345b86dSNicolas Vasilache auto adaptor = getTransferOpAdapter(xferOp, operands); 1216b2c79c50SNicolas Vasilache 1217b2c79c50SNicolas Vasilache if (xferOp.getVectorType().getRank() > 1 || 1218b2c79c50SNicolas Vasilache llvm::size(xferOp.indices()) == 0) 12198345b86dSNicolas Vasilache return failure(); 12205f9e0466SNicolas Vasilache if (xferOp.permutation_map() != 12215f9e0466SNicolas Vasilache AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), 12225f9e0466SNicolas Vasilache xferOp.getVectorType().getRank(), 12235f9e0466SNicolas Vasilache op->getContext())) 12248345b86dSNicolas Vasilache return failure(); 12252bf491c7SBenjamin Kramer // Only contiguous source tensors supported atm. 12262bf491c7SBenjamin Kramer SmallVector<int64_t, 4> strides; 12272bf491c7SBenjamin Kramer if (!isContiguous(xferOp.getMemRefType(), strides)) 12282bf491c7SBenjamin Kramer return failure(); 12298345b86dSNicolas Vasilache 12308345b86dSNicolas Vasilache auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; 12318345b86dSNicolas Vasilache 12328345b86dSNicolas Vasilache Location loc = op->getLoc(); 12338345b86dSNicolas Vasilache MemRefType memRefType = xferOp.getMemRefType(); 12348345b86dSNicolas Vasilache 123568330ee0SThomas Raoux if (auto memrefVectorElementType = 123668330ee0SThomas Raoux memRefType.getElementType().dyn_cast<VectorType>()) { 123768330ee0SThomas Raoux // Memref has vector element type. 123868330ee0SThomas Raoux if (memrefVectorElementType.getElementType() != 123968330ee0SThomas Raoux xferOp.getVectorType().getElementType()) 124068330ee0SThomas Raoux return failure(); 12410de60b55SThomas Raoux #ifndef NDEBUG 124268330ee0SThomas Raoux // Check that memref vector type is a suffix of 'vectorType. 124368330ee0SThomas Raoux unsigned memrefVecEltRank = memrefVectorElementType.getRank(); 124468330ee0SThomas Raoux unsigned resultVecRank = xferOp.getVectorType().getRank(); 124568330ee0SThomas Raoux assert(memrefVecEltRank <= resultVecRank); 124668330ee0SThomas Raoux // TODO: Move this to isSuffix in Vector/Utils.h. 124768330ee0SThomas Raoux unsigned rankOffset = resultVecRank - memrefVecEltRank; 124868330ee0SThomas Raoux auto memrefVecEltShape = memrefVectorElementType.getShape(); 124968330ee0SThomas Raoux auto resultVecShape = xferOp.getVectorType().getShape(); 125068330ee0SThomas Raoux for (unsigned i = 0; i < memrefVecEltRank; ++i) 125168330ee0SThomas Raoux assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && 125268330ee0SThomas Raoux "memref vector element shape should match suffix of vector " 125368330ee0SThomas Raoux "result shape."); 12540de60b55SThomas Raoux #endif // ifndef NDEBUG 125568330ee0SThomas Raoux } 125668330ee0SThomas Raoux 12578345b86dSNicolas Vasilache // 1. Get the source/dst address as an LLVM vector pointer. 1258be16075bSWen-Heng (Jack) Chung // The vector pointer would always be on address space 0, therefore 1259be16075bSWen-Heng (Jack) Chung // addrspacecast shall be used when source/dst memrefs are not on 1260be16075bSWen-Heng (Jack) Chung // address space 0. 12618345b86dSNicolas Vasilache // TODO: support alignment when possible. 12628345b86dSNicolas Vasilache Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), 1263d3a98076SAlex Zinenko adaptor.indices(), rewriter); 12648345b86dSNicolas Vasilache auto vecTy = 12658345b86dSNicolas Vasilache toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); 1266be16075bSWen-Heng (Jack) Chung Value vectorDataPtr; 1267be16075bSWen-Heng (Jack) Chung if (memRefType.getMemorySpace() == 0) 1268be16075bSWen-Heng (Jack) Chung vectorDataPtr = 12698345b86dSNicolas Vasilache rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr); 1270be16075bSWen-Heng (Jack) Chung else 1271be16075bSWen-Heng (Jack) Chung vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>( 1272be16075bSWen-Heng (Jack) Chung loc, vecTy.getPointerTo(), dataPtr); 12738345b86dSNicolas Vasilache 12741870e787SNicolas Vasilache if (!xferOp.isMaskedDim(0)) 12751870e787SNicolas Vasilache return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc, 12761870e787SNicolas Vasilache xferOp, operands, vectorDataPtr); 12771870e787SNicolas Vasilache 12788345b86dSNicolas Vasilache // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 12798345b86dSNicolas Vasilache // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 12808345b86dSNicolas Vasilache // 4. Let dim the memref dimension, compute the vector comparison mask: 12818345b86dSNicolas Vasilache // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 1282060c9dd1Saartbik // 1283060c9dd1Saartbik // TODO: when the leaf transfer rank is k > 1, we need the last `k` 1284060c9dd1Saartbik // dimensions here. 1285060c9dd1Saartbik unsigned vecWidth = vecTy.getVectorNumElements(); 1286060c9dd1Saartbik unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 12870c2a4d3cSBenjamin Kramer Value off = xferOp.indices()[lastIndex]; 1288b2c79c50SNicolas Vasilache Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex); 1289060c9dd1Saartbik Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations, 1290060c9dd1Saartbik vecWidth, dim, &off); 12918345b86dSNicolas Vasilache 12928345b86dSNicolas Vasilache // 5. Rewrite as a masked read / write. 12931870e787SNicolas Vasilache return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp, 1294a99f62c4SAlex Zinenko operands, vectorDataPtr, mask); 12958345b86dSNicolas Vasilache } 1296060c9dd1Saartbik 1297060c9dd1Saartbik private: 1298060c9dd1Saartbik const bool enableIndexOptimizations; 12998345b86dSNicolas Vasilache }; 13008345b86dSNicolas Vasilache 1301870c1fd4SAlex Zinenko class VectorPrintOpConversion : public ConvertToLLVMPattern { 1302d9b500d3SAart Bik public: 1303d9b500d3SAart Bik explicit VectorPrintOpConversion(MLIRContext *context, 1304d9b500d3SAart Bik LLVMTypeConverter &typeConverter) 1305870c1fd4SAlex Zinenko : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, 1306d9b500d3SAart Bik typeConverter) {} 1307d9b500d3SAart Bik 1308d9b500d3SAart Bik // Proof-of-concept lowering implementation that relies on a small 1309d9b500d3SAart Bik // runtime support library, which only needs to provide a few 1310d9b500d3SAart Bik // printing methods (single value for all data types, opening/closing 1311d9b500d3SAart Bik // bracket, comma, newline). The lowering fully unrolls a vector 1312d9b500d3SAart Bik // in terms of these elementary printing operations. The advantage 1313d9b500d3SAart Bik // of this approach is that the library can remain unaware of all 1314d9b500d3SAart Bik // low-level implementation details of vectors while still supporting 1315d9b500d3SAart Bik // output of any shaped and dimensioned vector. Due to full unrolling, 1316d9b500d3SAart Bik // this approach is less suited for very large vectors though. 1317d9b500d3SAart Bik // 13189db53a18SRiver Riddle // TODO: rely solely on libc in future? something else? 1319d9b500d3SAart Bik // 13203145427dSRiver Riddle LogicalResult 1321e62a6956SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1322d9b500d3SAart Bik ConversionPatternRewriter &rewriter) const override { 1323d9b500d3SAart Bik auto printOp = cast<vector::PrintOp>(op); 13242d2c73c5SJacques Pienaar auto adaptor = vector::PrintOpAdaptor(operands); 1325d9b500d3SAart Bik Type printType = printOp.getPrintType(); 1326d9b500d3SAart Bik 13270f04384dSAlex Zinenko if (typeConverter.convertType(printType) == nullptr) 13283145427dSRiver Riddle return failure(); 1329d9b500d3SAart Bik 1330b8880f5fSAart Bik // Make sure element type has runtime support. 1331b8880f5fSAart Bik PrintConversion conversion = PrintConversion::None; 1332d9b500d3SAart Bik VectorType vectorType = printType.dyn_cast<VectorType>(); 1333d9b500d3SAart Bik Type eltType = vectorType ? vectorType.getElementType() : printType; 1334d9b500d3SAart Bik Operation *printer; 1335b8880f5fSAart Bik if (eltType.isF32()) { 1336d9b500d3SAart Bik printer = getPrintFloat(op); 1337b8880f5fSAart Bik } else if (eltType.isF64()) { 1338d9b500d3SAart Bik printer = getPrintDouble(op); 133954759cefSAart Bik } else if (eltType.isIndex()) { 134054759cefSAart Bik printer = getPrintU64(op); 1341b8880f5fSAart Bik } else if (auto intTy = eltType.dyn_cast<IntegerType>()) { 1342b8880f5fSAart Bik // Integers need a zero or sign extension on the operand 1343b8880f5fSAart Bik // (depending on the source type) as well as a signed or 1344b8880f5fSAart Bik // unsigned print method. Up to 64-bit is supported. 1345b8880f5fSAart Bik unsigned width = intTy.getWidth(); 1346b8880f5fSAart Bik if (intTy.isUnsigned()) { 134754759cefSAart Bik if (width <= 64) { 1348b8880f5fSAart Bik if (width < 64) 1349b8880f5fSAart Bik conversion = PrintConversion::ZeroExt64; 1350b8880f5fSAart Bik printer = getPrintU64(op); 1351b8880f5fSAart Bik } else { 13523145427dSRiver Riddle return failure(); 1353b8880f5fSAart Bik } 1354b8880f5fSAart Bik } else { 1355b8880f5fSAart Bik assert(intTy.isSignless() || intTy.isSigned()); 135654759cefSAart Bik if (width <= 64) { 1357b8880f5fSAart Bik // Note that we *always* zero extend booleans (1-bit integers), 1358b8880f5fSAart Bik // so that true/false is printed as 1/0 rather than -1/0. 1359b8880f5fSAart Bik if (width == 1) 136054759cefSAart Bik conversion = PrintConversion::ZeroExt64; 136154759cefSAart Bik else if (width < 64) 1362b8880f5fSAart Bik conversion = PrintConversion::SignExt64; 1363b8880f5fSAart Bik printer = getPrintI64(op); 1364b8880f5fSAart Bik } else { 1365b8880f5fSAart Bik return failure(); 1366b8880f5fSAart Bik } 1367b8880f5fSAart Bik } 1368b8880f5fSAart Bik } else { 1369b8880f5fSAart Bik return failure(); 1370b8880f5fSAart Bik } 1371d9b500d3SAart Bik 1372d9b500d3SAart Bik // Unroll vector into elementary print calls. 1373b8880f5fSAart Bik int64_t rank = vectorType ? vectorType.getRank() : 0; 1374b8880f5fSAart Bik emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, 1375b8880f5fSAart Bik conversion); 1376d9b500d3SAart Bik emitCall(rewriter, op->getLoc(), getPrintNewline(op)); 1377d9b500d3SAart Bik rewriter.eraseOp(op); 13783145427dSRiver Riddle return success(); 1379d9b500d3SAart Bik } 1380d9b500d3SAart Bik 1381d9b500d3SAart Bik private: 1382b8880f5fSAart Bik enum class PrintConversion { 1383b8880f5fSAart Bik None, 1384b8880f5fSAart Bik ZeroExt64, 1385b8880f5fSAart Bik SignExt64 1386b8880f5fSAart Bik }; 1387b8880f5fSAart Bik 1388d9b500d3SAart Bik void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, 1389e62a6956SRiver Riddle Value value, VectorType vectorType, Operation *printer, 1390b8880f5fSAart Bik int64_t rank, PrintConversion conversion) const { 1391d9b500d3SAart Bik Location loc = op->getLoc(); 1392d9b500d3SAart Bik if (rank == 0) { 1393b8880f5fSAart Bik switch (conversion) { 1394b8880f5fSAart Bik case PrintConversion::ZeroExt64: 1395b8880f5fSAart Bik value = rewriter.create<ZeroExtendIOp>( 1396b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1397b8880f5fSAart Bik break; 1398b8880f5fSAart Bik case PrintConversion::SignExt64: 1399b8880f5fSAart Bik value = rewriter.create<SignExtendIOp>( 1400b8880f5fSAart Bik loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); 1401b8880f5fSAart Bik break; 1402b8880f5fSAart Bik case PrintConversion::None: 1403b8880f5fSAart Bik break; 1404c9eeeb38Saartbik } 1405d9b500d3SAart Bik emitCall(rewriter, loc, printer, value); 1406d9b500d3SAart Bik return; 1407d9b500d3SAart Bik } 1408d9b500d3SAart Bik 1409d9b500d3SAart Bik emitCall(rewriter, loc, getPrintOpen(op)); 1410d9b500d3SAart Bik Operation *printComma = getPrintComma(op); 1411d9b500d3SAart Bik int64_t dim = vectorType.getDimSize(0); 1412d9b500d3SAart Bik for (int64_t d = 0; d < dim; ++d) { 1413d9b500d3SAart Bik auto reducedType = 1414d9b500d3SAart Bik rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; 14150f04384dSAlex Zinenko auto llvmType = typeConverter.convertType( 1416d9b500d3SAart Bik rank > 1 ? reducedType : vectorType.getElementType()); 1417e62a6956SRiver Riddle Value nestedVal = 14180f04384dSAlex Zinenko extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); 1419b8880f5fSAart Bik emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, 1420b8880f5fSAart Bik conversion); 1421d9b500d3SAart Bik if (d != dim - 1) 1422d9b500d3SAart Bik emitCall(rewriter, loc, printComma); 1423d9b500d3SAart Bik } 1424d9b500d3SAart Bik emitCall(rewriter, loc, getPrintClose(op)); 1425d9b500d3SAart Bik } 1426d9b500d3SAart Bik 1427d9b500d3SAart Bik // Helper to emit a call. 1428d9b500d3SAart Bik static void emitCall(ConversionPatternRewriter &rewriter, Location loc, 1429d9b500d3SAart Bik Operation *ref, ValueRange params = ValueRange()) { 143008e4f078SRahul Joshi rewriter.create<LLVM::CallOp>(loc, TypeRange(), 1431d9b500d3SAart Bik rewriter.getSymbolRefAttr(ref), params); 1432d9b500d3SAart Bik } 1433d9b500d3SAart Bik 1434d9b500d3SAart Bik // Helper for printer method declaration (first hit) and lookup. 14355446ec85SAlex Zinenko static Operation *getPrint(Operation *op, StringRef name, 14365446ec85SAlex Zinenko ArrayRef<LLVM::LLVMType> params) { 1437d9b500d3SAart Bik auto module = op->getParentOfType<ModuleOp>(); 1438d9b500d3SAart Bik auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); 1439d9b500d3SAart Bik if (func) 1440d9b500d3SAart Bik return func; 1441d9b500d3SAart Bik OpBuilder moduleBuilder(module.getBodyRegion()); 1442d9b500d3SAart Bik return moduleBuilder.create<LLVM::LLVMFuncOp>( 1443d9b500d3SAart Bik op->getLoc(), name, 14445446ec85SAlex Zinenko LLVM::LLVMType::getFunctionTy( 14455446ec85SAlex Zinenko LLVM::LLVMType::getVoidTy(op->getContext()), params, 14465446ec85SAlex Zinenko /*isVarArg=*/false)); 1447d9b500d3SAart Bik } 1448d9b500d3SAart Bik 1449d9b500d3SAart Bik // Helpers for method names. 1450e52414b1Saartbik Operation *getPrintI64(Operation *op) const { 145154759cefSAart Bik return getPrint(op, "printI64", 14525446ec85SAlex Zinenko LLVM::LLVMType::getInt64Ty(op->getContext())); 1453e52414b1Saartbik } 1454b8880f5fSAart Bik Operation *getPrintU64(Operation *op) const { 1455b8880f5fSAart Bik return getPrint(op, "printU64", 1456b8880f5fSAart Bik LLVM::LLVMType::getInt64Ty(op->getContext())); 1457b8880f5fSAart Bik } 1458d9b500d3SAart Bik Operation *getPrintFloat(Operation *op) const { 145954759cefSAart Bik return getPrint(op, "printF32", 14605446ec85SAlex Zinenko LLVM::LLVMType::getFloatTy(op->getContext())); 1461d9b500d3SAart Bik } 1462d9b500d3SAart Bik Operation *getPrintDouble(Operation *op) const { 146354759cefSAart Bik return getPrint(op, "printF64", 14645446ec85SAlex Zinenko LLVM::LLVMType::getDoubleTy(op->getContext())); 1465d9b500d3SAart Bik } 1466d9b500d3SAart Bik Operation *getPrintOpen(Operation *op) const { 146754759cefSAart Bik return getPrint(op, "printOpen", {}); 1468d9b500d3SAart Bik } 1469d9b500d3SAart Bik Operation *getPrintClose(Operation *op) const { 147054759cefSAart Bik return getPrint(op, "printClose", {}); 1471d9b500d3SAart Bik } 1472d9b500d3SAart Bik Operation *getPrintComma(Operation *op) const { 147354759cefSAart Bik return getPrint(op, "printComma", {}); 1474d9b500d3SAart Bik } 1475d9b500d3SAart Bik Operation *getPrintNewline(Operation *op) const { 147654759cefSAart Bik return getPrint(op, "printNewline", {}); 1477d9b500d3SAart Bik } 1478d9b500d3SAart Bik }; 1479d9b500d3SAart Bik 1480334a4159SReid Tatge /// Progressive lowering of ExtractStridedSliceOp to either: 1481c3c95b9cSaartbik /// 1. express single offset extract as a direct shuffle. 1482c3c95b9cSaartbik /// 2. extract + lower rank strided_slice + insert for the n-D case. 1483c3c95b9cSaartbik class VectorExtractStridedSliceOpConversion 1484334a4159SReid Tatge : public OpRewritePattern<ExtractStridedSliceOp> { 148565678d93SNicolas Vasilache public: 1486334a4159SReid Tatge using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 148765678d93SNicolas Vasilache 1488334a4159SReid Tatge LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 148965678d93SNicolas Vasilache PatternRewriter &rewriter) const override { 149065678d93SNicolas Vasilache auto dstType = op.getResult().getType().cast<VectorType>(); 149165678d93SNicolas Vasilache 149265678d93SNicolas Vasilache assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 149365678d93SNicolas Vasilache 149465678d93SNicolas Vasilache int64_t offset = 149565678d93SNicolas Vasilache op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 149665678d93SNicolas Vasilache int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 149765678d93SNicolas Vasilache int64_t stride = 149865678d93SNicolas Vasilache op.strides().getValue().front().cast<IntegerAttr>().getInt(); 149965678d93SNicolas Vasilache 150065678d93SNicolas Vasilache auto loc = op.getLoc(); 150165678d93SNicolas Vasilache auto elemType = dstType.getElementType(); 150235b68527SLei Zhang assert(elemType.isSignlessIntOrIndexOrFloat()); 1503c3c95b9cSaartbik 1504c3c95b9cSaartbik // Single offset can be more efficiently shuffled. 1505c3c95b9cSaartbik if (op.offsets().getValue().size() == 1) { 1506c3c95b9cSaartbik SmallVector<int64_t, 4> offsets; 1507c3c95b9cSaartbik offsets.reserve(size); 1508c3c95b9cSaartbik for (int64_t off = offset, e = offset + size * stride; off < e; 1509c3c95b9cSaartbik off += stride) 1510c3c95b9cSaartbik offsets.push_back(off); 1511c3c95b9cSaartbik rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 1512c3c95b9cSaartbik op.vector(), 1513c3c95b9cSaartbik rewriter.getI64ArrayAttr(offsets)); 1514c3c95b9cSaartbik return success(); 1515c3c95b9cSaartbik } 1516c3c95b9cSaartbik 1517c3c95b9cSaartbik // Extract/insert on a lower ranked extract strided slice op. 151865678d93SNicolas Vasilache Value zero = rewriter.create<ConstantOp>(loc, elemType, 151965678d93SNicolas Vasilache rewriter.getZeroAttr(elemType)); 152065678d93SNicolas Vasilache Value res = rewriter.create<SplatOp>(loc, dstType, zero); 152165678d93SNicolas Vasilache for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 152265678d93SNicolas Vasilache off += stride, ++idx) { 1523c3c95b9cSaartbik Value one = extractOne(rewriter, loc, op.vector(), off); 1524c3c95b9cSaartbik Value extracted = rewriter.create<ExtractStridedSliceOp>( 1525c3c95b9cSaartbik loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 152665678d93SNicolas Vasilache getI64SubArray(op.sizes(), /* dropFront=*/1), 152765678d93SNicolas Vasilache getI64SubArray(op.strides(), /* dropFront=*/1)); 152865678d93SNicolas Vasilache res = insertOne(rewriter, loc, extracted, res, idx); 152965678d93SNicolas Vasilache } 1530c3c95b9cSaartbik rewriter.replaceOp(op, res); 15313145427dSRiver Riddle return success(); 153265678d93SNicolas Vasilache } 1533334a4159SReid Tatge /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is 1534bd1ccfe6SRiver Riddle /// bounded as the rank is strictly decreasing. 1535bd1ccfe6SRiver Riddle bool hasBoundedRewriteRecursion() const final { return true; } 153665678d93SNicolas Vasilache }; 153765678d93SNicolas Vasilache 1538df186507SBenjamin Kramer } // namespace 1539df186507SBenjamin Kramer 15405c0c51a9SNicolas Vasilache /// Populate the given list with patterns that convert from Vector to LLVM. 15415c0c51a9SNicolas Vasilache void mlir::populateVectorToLLVMConversionPatterns( 1542ceb1b327Saartbik LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 1543060c9dd1Saartbik bool reassociateFPReductions, bool enableIndexOptimizations) { 154465678d93SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 15458345b86dSNicolas Vasilache // clang-format off 1546681f929fSNicolas Vasilache patterns.insert<VectorFMAOpNDRewritePattern, 1547681f929fSNicolas Vasilache VectorInsertStridedSliceOpDifferentRankRewritePattern, 15482d515e49SNicolas Vasilache VectorInsertStridedSliceOpSameRankRewritePattern, 1549c3c95b9cSaartbik VectorExtractStridedSliceOpConversion>(ctx); 1550ceb1b327Saartbik patterns.insert<VectorReductionOpConversion>( 1551ceb1b327Saartbik ctx, converter, reassociateFPReductions); 1552060c9dd1Saartbik patterns.insert<VectorCreateMaskOpConversion, 1553060c9dd1Saartbik VectorTransferConversion<TransferReadOp>, 1554060c9dd1Saartbik VectorTransferConversion<TransferWriteOp>>( 1555060c9dd1Saartbik ctx, converter, enableIndexOptimizations); 15568345b86dSNicolas Vasilache patterns 1557ceb1b327Saartbik .insert<VectorShuffleOpConversion, 15588345b86dSNicolas Vasilache VectorExtractElementOpConversion, 15598345b86dSNicolas Vasilache VectorExtractOpConversion, 15608345b86dSNicolas Vasilache VectorFMAOp1DConversion, 15618345b86dSNicolas Vasilache VectorInsertElementOpConversion, 15628345b86dSNicolas Vasilache VectorInsertOpConversion, 15638345b86dSNicolas Vasilache VectorPrintOpConversion, 156419dbb230Saartbik VectorTypeCastOpConversion, 156539379916Saartbik VectorMaskedLoadOpConversion, 156639379916Saartbik VectorMaskedStoreOpConversion, 156719dbb230Saartbik VectorGatherOpConversion, 1568e8dcf5f8Saartbik VectorScatterOpConversion, 1569e8dcf5f8Saartbik VectorExpandLoadOpConversion, 1570e8dcf5f8Saartbik VectorCompressStoreOpConversion>(ctx, converter); 15718345b86dSNicolas Vasilache // clang-format on 15725c0c51a9SNicolas Vasilache } 15735c0c51a9SNicolas Vasilache 157463b683a8SNicolas Vasilache void mlir::populateVectorToLLVMMatrixConversionPatterns( 157563b683a8SNicolas Vasilache LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 157663b683a8SNicolas Vasilache MLIRContext *ctx = converter.getDialect()->getContext(); 157763b683a8SNicolas Vasilache patterns.insert<VectorMatmulOpConversion>(ctx, converter); 1578c295a65dSaartbik patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter); 157963b683a8SNicolas Vasilache } 158063b683a8SNicolas Vasilache 15815c0c51a9SNicolas Vasilache namespace { 1582722f909fSRiver Riddle struct LowerVectorToLLVMPass 15831834ad4aSRiver Riddle : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 15841bfdf7c7Saartbik LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 15851bfdf7c7Saartbik this->reassociateFPReductions = options.reassociateFPReductions; 1586060c9dd1Saartbik this->enableIndexOptimizations = options.enableIndexOptimizations; 15871bfdf7c7Saartbik } 1588722f909fSRiver Riddle void runOnOperation() override; 15895c0c51a9SNicolas Vasilache }; 15905c0c51a9SNicolas Vasilache } // namespace 15915c0c51a9SNicolas Vasilache 1592722f909fSRiver Riddle void LowerVectorToLLVMPass::runOnOperation() { 1593078776a6Saartbik // Perform progressive lowering of operations on slices and 1594b21c7999Saartbik // all contraction operations. Also applies folding and DCE. 1595459cf6e5Saartbik { 15965c0c51a9SNicolas Vasilache OwningRewritePatternList patterns; 1597b1c688dbSaartbik populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); 1598459cf6e5Saartbik populateVectorSlicesLoweringPatterns(patterns, &getContext()); 1599b21c7999Saartbik populateVectorContractLoweringPatterns(patterns, &getContext()); 1600a5b9316bSUday Bondhugula applyPatternsAndFoldGreedily(getOperation(), patterns); 1601459cf6e5Saartbik } 1602459cf6e5Saartbik 1603459cf6e5Saartbik // Convert to the LLVM IR dialect. 16045c0c51a9SNicolas Vasilache LLVMTypeConverter converter(&getContext()); 1605459cf6e5Saartbik OwningRewritePatternList patterns; 160663b683a8SNicolas Vasilache populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 1607060c9dd1Saartbik populateVectorToLLVMConversionPatterns( 1608060c9dd1Saartbik converter, patterns, reassociateFPReductions, enableIndexOptimizations); 1609bbf3ef85SNicolas Vasilache populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 16105c0c51a9SNicolas Vasilache populateStdToLLVMConversionPatterns(converter, patterns); 16115c0c51a9SNicolas Vasilache 16122a00ae39STim Shen LLVMConversionTarget target(getContext()); 1613060c9dd1Saartbik if (failed(applyPartialConversion(getOperation(), target, patterns))) 16145c0c51a9SNicolas Vasilache signalPassFailure(); 16155c0c51a9SNicolas Vasilache } 16165c0c51a9SNicolas Vasilache 16171bfdf7c7Saartbik std::unique_ptr<OperationPass<ModuleOp>> 16181bfdf7c7Saartbik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 16191bfdf7c7Saartbik return std::make_unique<LowerVectorToLLVMPass>(options); 16205c0c51a9SNicolas Vasilache } 1621