1 //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 10 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/ImplicitLocOpBuilder.h" 13 #include "mlir/IR/TypeUtilities.h" 14 15 #define DEBUG_TYPE "vector-drop-unit-dim" 16 17 using namespace mlir; 18 using namespace mlir::vector; 19 20 // Trims leading one dimensions from `oldType` and returns the result type. 21 // Returns `vector<1xT>` if `oldType` only has one element. 22 static VectorType trimLeadingOneDims(VectorType oldType) { 23 ArrayRef<int64_t> oldShape = oldType.getShape(); 24 ArrayRef<int64_t> newShape = 25 oldShape.drop_while([](int64_t dim) { return dim == 1; }); 26 // Make sure we have at least 1 dimension per vector type requirements. 27 if (newShape.empty()) 28 newShape = oldShape.take_back(); 29 return VectorType::get(newShape, oldType.getElementType()); 30 } 31 32 /// Return a smallVector of size `rank` containing all zeros. 33 static SmallVector<int64_t> splatZero(int64_t rank) { 34 return SmallVector<int64_t>(rank, 0); 35 } 36 namespace { 37 38 // Casts away leading one dimensions in vector.extract_strided_slice's vector 39 // input by inserting vector.shape_cast. 40 struct CastAwayExtractStridedSliceLeadingOneDim 41 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 42 using OpRewritePattern::OpRewritePattern; 43 44 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 45 PatternRewriter &rewriter) const override { 46 // vector.extract_strided_slice requires the input and output vector to have 47 // the same rank. Here we drop leading one dimensions from the input vector 48 // type to make sure we don't cause mismatch. 49 VectorType oldSrcType = extractOp.getVectorType(); 50 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 51 52 if (newSrcType.getRank() == oldSrcType.getRank()) 53 return failure(); 54 55 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 56 57 VectorType oldDstType = extractOp.getType(); 58 VectorType newDstType = 59 VectorType::get(oldDstType.getShape().drop_front(dropCount), 60 oldDstType.getElementType()); 61 62 Location loc = extractOp.getLoc(); 63 64 Value newSrcVector = rewriter.create<vector::ExtractOp>( 65 loc, extractOp.vector(), splatZero(dropCount)); 66 67 // The offsets/sizes/strides attribute can have a less number of elements 68 // than the input vector's rank: it is meant for the leading dimensions. 69 auto newOffsets = rewriter.getArrayAttr( 70 extractOp.offsets().getValue().drop_front(dropCount)); 71 auto newSizes = rewriter.getArrayAttr( 72 extractOp.sizes().getValue().drop_front(dropCount)); 73 auto newStrides = rewriter.getArrayAttr( 74 extractOp.strides().getValue().drop_front(dropCount)); 75 76 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 77 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 78 79 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 80 newExtractOp); 81 82 return success(); 83 } 84 }; 85 86 // Casts away leading one dimensions in vector.extract_strided_slice's vector 87 // inputs by inserting vector.shape_cast. 88 struct CastAwayInsertStridedSliceLeadingOneDim 89 : public OpRewritePattern<vector::InsertStridedSliceOp> { 90 using OpRewritePattern::OpRewritePattern; 91 92 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 93 PatternRewriter &rewriter) const override { 94 VectorType oldSrcType = insertOp.getSourceVectorType(); 95 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 96 VectorType oldDstType = insertOp.getDestVectorType(); 97 VectorType newDstType = trimLeadingOneDims(oldDstType); 98 99 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 100 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 101 if (srcDropCount == 0 && dstDropCount == 0) 102 return failure(); 103 104 // Trim leading one dimensions from both operands. 105 Location loc = insertOp.getLoc(); 106 107 Value newSrcVector = rewriter.create<vector::ExtractOp>( 108 loc, insertOp.source(), splatZero(srcDropCount)); 109 Value newDstVector = rewriter.create<vector::ExtractOp>( 110 loc, insertOp.dest(), splatZero(dstDropCount)); 111 112 auto newOffsets = rewriter.getArrayAttr( 113 insertOp.offsets().getValue().take_back(newDstType.getRank())); 114 auto newStrides = rewriter.getArrayAttr( 115 insertOp.strides().getValue().take_back(newSrcType.getRank())); 116 117 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 118 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 119 120 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 121 newInsertOp); 122 123 return success(); 124 } 125 }; 126 127 // Turns vector.transfer_read on vector with leading 1 dimensions into 128 // vector.shape_cast followed by vector.transfer_read on vector without leading 129 // 1 dimensions. 130 struct CastAwayTransferReadLeadingOneDim 131 : public OpRewritePattern<vector::TransferReadOp> { 132 using OpRewritePattern::OpRewritePattern; 133 134 LogicalResult matchAndRewrite(vector::TransferReadOp read, 135 PatternRewriter &rewriter) const override { 136 // TODO: support 0-d corner case. 137 if (read.getTransferRank() == 0) 138 return failure(); 139 140 if (read.mask()) 141 return failure(); 142 143 auto shapedType = read.source().getType().cast<ShapedType>(); 144 if (shapedType.getElementType() != read.getVectorType().getElementType()) 145 return failure(); 146 147 VectorType oldType = read.getVectorType(); 148 VectorType newType = trimLeadingOneDims(oldType); 149 150 if (newType == oldType) 151 return failure(); 152 153 AffineMap oldMap = read.permutation_map(); 154 ArrayRef<AffineExpr> newResults = 155 oldMap.getResults().take_back(newType.getRank()); 156 AffineMap newMap = 157 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 158 rewriter.getContext()); 159 160 ArrayAttr inBoundsAttr; 161 if (read.in_bounds()) 162 inBoundsAttr = rewriter.getArrayAttr( 163 read.in_boundsAttr().getValue().take_back(newType.getRank())); 164 165 auto newRead = rewriter.create<vector::TransferReadOp>( 166 read.getLoc(), newType, read.source(), read.indices(), 167 AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(), 168 inBoundsAttr); 169 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); 170 171 return success(); 172 } 173 }; 174 175 // Turns vector.transfer_write on vector with leading 1 dimensions into 176 // vector.shape_cast followed by vector.transfer_write on vector without leading 177 // 1 dimensions. 178 struct CastAwayTransferWriteLeadingOneDim 179 : public OpRewritePattern<vector::TransferWriteOp> { 180 using OpRewritePattern::OpRewritePattern; 181 182 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 183 PatternRewriter &rewriter) const override { 184 // TODO: support 0-d corner case. 185 if (write.getTransferRank() == 0) 186 return failure(); 187 188 if (write.mask()) 189 return failure(); 190 191 auto shapedType = write.source().getType().dyn_cast<ShapedType>(); 192 if (shapedType.getElementType() != write.getVectorType().getElementType()) 193 return failure(); 194 195 VectorType oldType = write.getVectorType(); 196 VectorType newType = trimLeadingOneDims(oldType); 197 if (newType == oldType) 198 return failure(); 199 int64_t dropDim = oldType.getRank() - newType.getRank(); 200 201 AffineMap oldMap = write.permutation_map(); 202 ArrayRef<AffineExpr> newResults = 203 oldMap.getResults().take_back(newType.getRank()); 204 AffineMap newMap = 205 AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, 206 rewriter.getContext()); 207 208 ArrayAttr inBoundsAttr; 209 if (write.in_bounds()) 210 inBoundsAttr = rewriter.getArrayAttr( 211 write.in_boundsAttr().getValue().take_back(newType.getRank())); 212 213 auto newVector = rewriter.create<vector::ExtractOp>( 214 write.getLoc(), write.vector(), splatZero(dropDim)); 215 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 216 write, newVector, write.source(), write.indices(), 217 AffineMapAttr::get(newMap), inBoundsAttr); 218 219 return success(); 220 } 221 }; 222 223 class CastAwayElementwiseLeadingOneDim : public RewritePattern { 224 public: 225 CastAwayElementwiseLeadingOneDim(MLIRContext *context) 226 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 227 228 LogicalResult matchAndRewrite(Operation *op, 229 PatternRewriter &rewriter) const override { 230 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 231 return failure(); 232 auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>(); 233 if (!vecType) 234 return failure(); 235 VectorType newVecType = trimLeadingOneDims(vecType); 236 if (newVecType == vecType) 237 return failure(); 238 int64_t dropDim = vecType.getRank() - newVecType.getRank(); 239 SmallVector<Value, 4> newOperands; 240 for (Value operand : op->getOperands()) { 241 if (auto opVecType = operand.getType().dyn_cast<VectorType>()) { 242 newOperands.push_back(rewriter.create<vector::ExtractOp>( 243 op->getLoc(), operand, splatZero(dropDim))); 244 } else { 245 newOperands.push_back(operand); 246 } 247 } 248 OperationState state(op->getLoc(), op->getName()); 249 state.addAttributes(op->getAttrs()); 250 state.addOperands(newOperands); 251 state.addTypes(newVecType); 252 Operation *newOp = rewriter.createOperation(state); 253 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, 254 newOp->getResult(0)); 255 return success(); 256 } 257 }; 258 259 } // namespace 260 261 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( 262 RewritePatternSet &patterns) { 263 patterns.add<CastAwayExtractStridedSliceLeadingOneDim, 264 CastAwayInsertStridedSliceLeadingOneDim, 265 CastAwayTransferReadLeadingOneDim, 266 CastAwayTransferWriteLeadingOneDim, 267 CastAwayElementwiseLeadingOneDim>(patterns.getContext()); 268 populateShapeCastFoldingPatterns(patterns); 269 } 270