//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;

// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element (i.e. every dim is
// 1): a vector type must keep at least one dimension, so the last unit dim is
// retained in that case.
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  // Drop the maximal leading run of 1s from the shape.
  ArrayRef<int64_t> newShape =
      oldShape.drop_while([](int64_t dim) { return dim == 1; });
  // Make sure we have at least 1 dimension per vector type requirements.
  if (newShape.empty())
    newShape = oldShape.take_back();
  return VectorType::get(newShape, oldType.getElementType());
}

/// Return a SmallVector of size `rank` containing all zeros. Used as the
/// position list for vector.extract ops that peel off `rank` leading unit
/// dimensions.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}
namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input. The unit dims are peeled off with a zero-offset vector.extract and
// the original result type is restored with a vector.broadcast.
41 struct CastAwayExtractStridedSliceLeadingOneDim 42 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 43 using OpRewritePattern::OpRewritePattern; 44 45 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 46 PatternRewriter &rewriter) const override { 47 // vector.extract_strided_slice requires the input and output vector to have 48 // the same rank. Here we drop leading one dimensions from the input vector 49 // type to make sure we don't cause mismatch. 50 VectorType oldSrcType = extractOp.getVectorType(); 51 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 52 53 if (newSrcType.getRank() == oldSrcType.getRank()) 54 return failure(); 55 56 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); 57 58 VectorType oldDstType = extractOp.getType(); 59 VectorType newDstType = 60 VectorType::get(oldDstType.getShape().drop_front(dropCount), 61 oldDstType.getElementType()); 62 63 Location loc = extractOp.getLoc(); 64 65 Value newSrcVector = rewriter.create<vector::ExtractOp>( 66 loc, extractOp.vector(), splatZero(dropCount)); 67 68 // The offsets/sizes/strides attribute can have a less number of elements 69 // than the input vector's rank: it is meant for the leading dimensions. 70 auto newOffsets = rewriter.getArrayAttr( 71 extractOp.offsets().getValue().drop_front(dropCount)); 72 auto newSizes = rewriter.getArrayAttr( 73 extractOp.sizes().getValue().drop_front(dropCount)); 74 auto newStrides = rewriter.getArrayAttr( 75 extractOp.strides().getValue().drop_front(dropCount)); 76 77 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 78 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); 79 80 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, 81 newExtractOp); 82 83 return success(); 84 } 85 }; 86 87 // Casts away leading one dimensions in vector.extract_strided_slice's vector 88 // inputs by inserting vector.shape_cast. 
89 struct CastAwayInsertStridedSliceLeadingOneDim 90 : public OpRewritePattern<vector::InsertStridedSliceOp> { 91 using OpRewritePattern::OpRewritePattern; 92 93 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, 94 PatternRewriter &rewriter) const override { 95 VectorType oldSrcType = insertOp.getSourceVectorType(); 96 VectorType newSrcType = trimLeadingOneDims(oldSrcType); 97 VectorType oldDstType = insertOp.getDestVectorType(); 98 VectorType newDstType = trimLeadingOneDims(oldDstType); 99 100 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); 101 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); 102 if (srcDropCount == 0 && dstDropCount == 0) 103 return failure(); 104 105 // Trim leading one dimensions from both operands. 106 Location loc = insertOp.getLoc(); 107 108 Value newSrcVector = rewriter.create<vector::ExtractOp>( 109 loc, insertOp.source(), splatZero(srcDropCount)); 110 Value newDstVector = rewriter.create<vector::ExtractOp>( 111 loc, insertOp.dest(), splatZero(dstDropCount)); 112 113 auto newOffsets = rewriter.getArrayAttr( 114 insertOp.offsets().getValue().take_back(newDstType.getRank())); 115 auto newStrides = rewriter.getArrayAttr( 116 insertOp.strides().getValue().take_back(newSrcType.getRank())); 117 118 auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( 119 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); 120 121 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, 122 newInsertOp); 123 124 return success(); 125 } 126 }; 127 128 // Turns vector.transfer_read on vector with leading 1 dimensions into 129 // vector.shape_cast followed by vector.transfer_read on vector without leading 130 // 1 dimensions. 
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Masked transfers are not handled: the mask would also need trimming.
    if (read.mask())
      return failure();

    auto shapedType = read.source().getType().cast<ShapedType>();
    // Bail out on element-type-mismatched (e.g. vector-element) transfers.
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    // Keep only the trailing permutation-map results corresponding to the
    // dims that survive trimming; dim count and symbols are unchanged.
    AffineMap oldMap = read.permutation_map();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    // Trim in_bounds the same way, if present.
    ArrayAttr inBoundsAttr;
    if (read.in_bounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.in_boundsAttr().getValue().take_back(newType.getRank()));

    // Read at the trimmed type, then broadcast back to the original type.
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.source(), read.indices(),
        AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
        inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};

// Turns vector.transfer_write of a vector with leading 1 dimensions into a
// zero-offset vector.extract followed by vector.transfer_write on the vector
// without leading 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    // Masked transfers are not handled: the mask would also need trimming.
    if (write.mask())
      return failure();

    // NOTE(review): unlike the read pattern above (which uses `cast`), this
    // dyn_cast result is dereferenced without a null check; presumably a
    // transfer_write source is always shaped — confirm, or use `cast` here.
    auto shapedType = write.source().getType().dyn_cast<ShapedType>();
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    // Number of leading unit dims being removed.
    int64_t dropDim = oldType.getRank() - newType.getRank();

    // Keep only the trailing permutation-map results for the surviving dims.
    AffineMap oldMap = write.permutation_map();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    // Trim in_bounds the same way, if present.
    ArrayAttr inBoundsAttr;
    if (write.in_bounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.in_boundsAttr().getValue().take_back(newType.getRank()));

    // Peel off the unit dims from the written vector, then write at the
    // trimmed type.
    auto newVector = rewriter.create<vector::ExtractOp>(
        write.getLoc(), write.vector(), splatZero(dropDim));
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.source(), write.indices(),
        AffineMapAttr::get(newMap), inBoundsAttr);

    return success();
  }
};

/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
/// prior to extract.
struct CastAwayContractionLeadingOneDim
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // Only vector accumulators of rank >= 2 with a leading unit dim qualify.
    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
    if (oldAccType == nullptr)
      return failure();
    if (oldAccType.getRank() < 2)
      return failure();
    // TODO: implement masks.
    if (llvm::size(contractOp.masks()) != 0)
      return failure();
    if (oldAccType.getShape()[0] != 1)
      return failure();
    // currently we support only dropping one dim but the pattern can be applied
    // greedily to drop more.
    int64_t dropDim = 1;

    auto oldIndexingMaps = contractOp.getIndexingMaps();
    SmallVector<AffineMap> newIndexingMaps;

    auto oldIteratorTypes = contractOp.iterator_types();
    SmallVector<Attribute> newIteratorTypes;

    // The iteration-space dim to drop is the one the accumulator map (index
    // 2) uses for its leading (unit) result dim.
    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
      // only parallel type iterators can be dropped.
      return failure();

    // Rebuild the iterator-type list without the dropped dim.
    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
      int64_t currDim = it.index();
      if (currDim == dimToDrop)
        continue;
      newIteratorTypes.push_back(it.value());
    }

    SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(),
                                   contractOp.acc()};
    SmallVector<Value> newOperands;

    // Process lhs, rhs, acc in turn: make the dropped dim leading (via
    // transpose if needed), extract it away, and rebuild the indexing map
    // over the shrunk iteration space.
    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
      // Check if the dim to be dropped exists as a leading dim in the operand
      // if it does then we use vector.extract to drop it.
      bool validExtract = false;
      SmallVector<AffineExpr> results;
      auto map = it.value();
      int64_t orginalZeroDim = it.value().getDimPosition(0);
      if (orginalZeroDim != dimToDrop) {
        // There are two reasons to be in this path, 1. We need to
        // tranpose the operand to make the dim to be dropped
        // leading. 2. The dim to be dropped does not exist and in
        // that case we dont want to add a unit tranpose but we must
        // check all the indices to make sure this is the case.
        bool tranposeNeeded = false;
        SmallVector<int64_t> perm;
        SmallVector<AffineExpr> transposeResults;

        // Build a permutation moving the dropped dim to the front while
        // keeping the relative order of the remaining dims.
        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
          int64_t currDim = map.getDimPosition(i);
          if (currDim == dimToDrop) {
            tranposeNeeded = true;
            perm.insert(perm.begin(), i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.insert(transposeResults.begin(), targetExpr);
          } else {
            perm.push_back(i);
            auto targetExpr = rewriter.getAffineDimExpr(currDim);
            transposeResults.push_back(targetExpr);
          }
        }
        // Do the tranpose now if needed so that we can drop the
        // correct dim using extract later.
        if (tranposeNeeded) {
          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                               contractOp.getContext());
          operands[it.index()] = rewriter.create<vector::TransposeOp>(
              contractOp.getLoc(), operands[it.index()], perm);
        }
      }
      // We have taken care to have the dim to be dropped be
      // the leading dim. If its still not leading that means it
      // does not exist in this operand and hence we do not need
      // an extract.
      if (map.getDimPosition(0) == dimToDrop)
        validExtract = true;

      // Rebuild the map's results, renumbering dims above the dropped one
      // down by one to account for the shrunk iteration space.
      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop)
          // This is the dim we are dropping.
          continue;
        auto targetExpr = rewriter.getAffineDimExpr(
            currDim < dimToDrop ? currDim : currDim - 1);
        results.push_back(targetExpr);
      }
      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                               contractOp.getContext()));
      // Extract if its a valid extraction, otherwise use the operand
      // without extraction.
      newOperands.push_back(validExtract
                                ? rewriter.create<vector::ExtractOp>(
                                      contractOp.getLoc(), operands[it.index()],
                                      splatZero(dropDim))
                                : operands[it.index()]);
    }
    // Emit the shrunk contraction and broadcast back to the original result
    // type for existing users.
    auto newContractOp = rewriter.create<vector::ContractionOp>(
        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        rewriter.getArrayAttr(newIteratorTypes), contractOp.kind());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        contractOp, contractOp->getResultTypes()[0], newContractOp);
    return success();
  }
};

// Rewrites any single-result elementwise op producing a vector with leading
// unit dims: vector operands are trimmed with zero-offset vector.extract ops,
// the op is re-created at the trimmed type, and the original result type is
// restored with a vector.broadcast.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only elementwise-mappable ops with exactly one result are handled.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    // Number of leading unit dims being removed.
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
        // Vector operands get their leading unit dims peeled off.
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        // Scalar (non-vector) operands pass through unchanged.
        newOperands.push_back(operand);
      }
    }
    // Re-create the same op with the trimmed operands and result type,
    // preserving all attributes.
    OperationState state(op->getLoc(), op->getName());
    state.addAttributes(op->getAttrs());
    state.addOperands(newOperands);
    state.addTypes(newVecType);
    Operation *newOp = rewriter.createOperation(state);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

} // namespace

// Registers all the cast-away patterns above, plus shape_cast folding to
// clean up the casts these rewrites introduce.
void
mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim,
           CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext());
  populateShapeCastFoldingPatterns(patterns);
}