1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/Utils/IndexingUtils.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 17 using namespace mlir; 18 using namespace mlir::vector; 19 20 // Helper that picks the proper sequence for inserting. 21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 22 Value into, int64_t offset) { 23 auto vectorType = into.getType().cast<VectorType>(); 24 if (vectorType.getRank() > 1) 25 return rewriter.create<InsertOp>(loc, from, into, offset); 26 return rewriter.create<vector::InsertElementOp>( 27 loc, vectorType, from, into, 28 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 29 } 30 31 // Helper that picks the proper sequence for extracting. 32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 33 int64_t offset) { 34 auto vectorType = vector.getType().cast<VectorType>(); 35 if (vectorType.getRank() > 1) 36 return rewriter.create<ExtractOp>(loc, vector, offset); 37 return rewriter.create<vector::ExtractElementOp>( 38 loc, vectorType.getElementType(), vector, 39 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 40 } 41 42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 43 /// have different ranks. 44 /// 45 /// When ranks are different, InsertStridedSlice needs to extract a properly 46 /// ranked vector from the destination vector into which to insert. This pattern 47 /// only takes care of this extraction part and forwards the rest to 48 /// [VectorInsertStridedSliceOpSameRankRewritePattern]. 49 /// 50 /// For a k-D source and n-D destination vector (k < n), we emit: 51 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 52 /// insert the k-D source. 53 /// 2. k-D -> (n-1)-D InsertStridedSlice op 54 /// 3. InsertOp that is the reverse of 1. 55 class VectorInsertStridedSliceOpDifferentRankRewritePattern 56 : public OpRewritePattern<InsertStridedSliceOp> { 57 public: 58 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 59 60 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 61 PatternRewriter &rewriter) const override { 62 auto srcType = op.getSourceVectorType(); 63 auto dstType = op.getDestVectorType(); 64 65 if (op.offsets().getValue().empty()) 66 return failure(); 67 68 auto loc = op.getLoc(); 69 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 70 assert(rankDiff >= 0); 71 if (rankDiff == 0) 72 return failure(); 73 74 int64_t rankRest = dstType.getRank() - rankDiff; 75 // Extract / insert the subvector of matching rank and InsertStridedSlice 76 // on it. 77 Value extracted = 78 rewriter.create<ExtractOp>(loc, op.dest(), 79 getI64SubArray(op.offsets(), /*dropFront=*/0, 80 /*dropBack=*/rankRest)); 81 82 // A different pattern will kick in for InsertStridedSlice with matching 83 // ranks. 84 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 85 loc, op.source(), extracted, 86 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 87 getI64SubArray(op.strides(), /*dropFront=*/0)); 88 89 rewriter.replaceOpWithNewOp<InsertOp>( 90 op, stridedSliceInnerOp.getResult(), op.dest(), 91 getI64SubArray(op.offsets(), /*dropFront=*/0, 92 /*dropBack=*/rankRest)); 93 return success(); 94 } 95 }; 96 97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 98 /// have the same rank. For each outermost index in the slice: 99 /// begin end stride 100 /// [offset : offset+size*stride : stride] 101 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 102 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 103 /// 3. the destination subvector is inserted back in the proper place 104 /// 3. InsertOp that is the reverse of 1. 105 class VectorInsertStridedSliceOpSameRankRewritePattern 106 : public OpRewritePattern<InsertStridedSliceOp> { 107 public: 108 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 109 110 void initialize() { 111 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 112 // bounded as the rank is strictly decreasing. 113 setHasBoundedRewriteRecursion(); 114 } 115 116 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 117 PatternRewriter &rewriter) const override { 118 auto srcType = op.getSourceVectorType(); 119 auto dstType = op.getDestVectorType(); 120 121 if (op.offsets().getValue().empty()) 122 return failure(); 123 124 int64_t srcRank = srcType.getRank(); 125 int64_t dstRank = dstType.getRank(); 126 assert(dstRank >= srcRank); 127 if (dstRank != srcRank) 128 return failure(); 129 130 if (srcType == dstType) { 131 rewriter.replaceOp(op, op.source()); 132 return success(); 133 } 134 135 int64_t offset = 136 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 137 int64_t size = srcType.getShape().front(); 138 int64_t stride = 139 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 140 141 auto loc = op.getLoc(); 142 Value res = op.dest(); 143 144 if (srcRank == 1) { 145 int nSrc = srcType.getShape().front(); 146 int nDest = dstType.getShape().front(); 147 // 1. Scale source to destType so we can shufflevector them together. 148 SmallVector<int64_t> offsets(nDest, 0); 149 for (int64_t i = 0; i < nSrc; ++i) 150 offsets[i] = i; 151 Value scaledSource = 152 rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets); 153 154 // 2. Create a mask where we take the value from scaledSource of dest 155 // depending on the offset. 156 offsets.clear(); 157 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { 158 if (i < offset || i >= e || (i - offset) % stride != 0) 159 offsets.push_back(nDest + i); 160 else 161 offsets.push_back((i - offset) / stride); 162 } 163 164 // 3. Replace with a ShuffleOp. 165 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(), 166 offsets); 167 168 return success(); 169 } 170 171 // For each slice of the source vector along the most major dimension. 172 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 173 off += stride, ++idx) { 174 // 1. extract the proper subvector (or element) from source 175 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 176 if (extractedSource.getType().isa<VectorType>()) { 177 // 2. If we have a vector, extract the proper subvector from destination 178 // Otherwise we are at the element level and no need to recurse. 179 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 180 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 181 // smaller rank. 182 extractedSource = rewriter.create<InsertStridedSliceOp>( 183 loc, extractedSource, extractedDest, 184 getI64SubArray(op.offsets(), /* dropFront=*/1), 185 getI64SubArray(op.strides(), /* dropFront=*/1)); 186 } 187 // 4. Insert the extractedSource into the res vector. 188 res = insertOne(rewriter, loc, extractedSource, res, off); 189 } 190 191 rewriter.replaceOp(op, res); 192 return success(); 193 } 194 }; 195 196 /// Progressive lowering of ExtractStridedSliceOp to either: 197 /// 1. single offset extract as a direct vector::ShuffleOp. 198 /// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + 199 /// InsertOp/InsertElementOp for the n-D case. 200 class VectorExtractStridedSliceOpRewritePattern 201 : public OpRewritePattern<ExtractStridedSliceOp> { 202 public: 203 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 204 205 void initialize() { 206 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 207 // is bounded as the rank is strictly decreasing. 208 setHasBoundedRewriteRecursion(); 209 } 210 211 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 212 PatternRewriter &rewriter) const override { 213 auto dstType = op.getType(); 214 215 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 216 217 int64_t offset = 218 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 219 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 220 int64_t stride = 221 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 222 223 auto loc = op.getLoc(); 224 auto elemType = dstType.getElementType(); 225 assert(elemType.isSignlessIntOrIndexOrFloat()); 226 227 // Single offset can be more efficiently shuffled. 228 if (op.offsets().getValue().size() == 1) { 229 SmallVector<int64_t, 4> offsets; 230 offsets.reserve(size); 231 for (int64_t off = offset, e = offset + size * stride; off < e; 232 off += stride) 233 offsets.push_back(off); 234 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 235 op.vector(), 236 rewriter.getI64ArrayAttr(offsets)); 237 return success(); 238 } 239 240 // Extract/insert on a lower ranked extract strided slice op. 241 Value zero = rewriter.create<arith::ConstantOp>( 242 loc, elemType, rewriter.getZeroAttr(elemType)); 243 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 244 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 245 off += stride, ++idx) { 246 Value one = extractOne(rewriter, loc, op.vector(), off); 247 Value extracted = rewriter.create<ExtractStridedSliceOp>( 248 loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 249 getI64SubArray(op.sizes(), /* dropFront=*/1), 250 getI64SubArray(op.strides(), /* dropFront=*/1)); 251 res = insertOne(rewriter, loc, extracted, res, idx); 252 } 253 rewriter.replaceOp(op, res); 254 return success(); 255 } 256 }; 257 258 /// Populate the given list with patterns that convert from Vector to LLVM. 259 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( 260 RewritePatternSet &patterns) { 261 patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern, 262 VectorInsertStridedSliceOpSameRankRewritePattern, 263 VectorExtractStridedSliceOpRewritePattern>( 264 patterns.getContext()); 265 } 266