//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape` return the offset
/// for the slice `index`.
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value>
sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
                     AffineMap permutationMap, Location loc,
                     OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  SmallVector<int64_t, 4> elementOffsets =
      getVectorOffset(originalShape, targetShape, index);
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, ctx) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
  return builder.createOperation(res);
}
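
// For example, unrolling an original shape [4, 6] with a target shape [2, 3]
// produces a 2x2 grid of slices; getVectorOffset([4, 6], [2, 3], 3) returns
// the element offsets [2, 3] of the last slice, and sliceTransferIndices adds
// those offsets to the transfer indices through affine.apply ops, skipping
// broadcast dimensions.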

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native shape callback "
         "function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.mask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the number of slices to unroll.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.indices().begin(),
                                          readOp.indices().end());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               readOp.permutation_map(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.source(), indices,
          readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
          readOp.in_boundsAttr());

      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
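
// For example, with a native shape of 2x2, a transfer_read producing a
// vector<4x4xf32> is rewritten into four vector<2x2xf32> transfer_reads whose
// results are recombined into a vector<4x4xf32> with insert_strided_slice.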

struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.mask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the number of slices to unroll.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    SmallVector<Value, 4> originalIndices(writeOp.indices().begin(),
                                          writeOp.indices().end());
    Value resultTensor;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.vector(), elementOffsets, *targetShape, strides);

      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               writeOp.permutation_map(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
          indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
      // For the tensor case, update the destination of the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
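
// On tensors, each sliced transfer_write above consumes the tensor produced
// by the previous one, so the unrolled writes form a chain that threads the
// destination through all slices; on memrefs the writes are independent.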

struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  struct OffsetMapInfo {
    static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

    static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

    static unsigned getHashValue(const SmallVector<int64_t> &v) {
      return static_cast<unsigned>(
          llvm::hash_combine_range(v.begin(), v.end()));
    }

    static bool isEqual(const SmallVector<int64_t> &lhs,
                        const SmallVector<int64_t> &rhs) {
      return lhs == rhs;
    }
  };
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Compute the number of slices to unroll.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffsets);
      // If there is a mask associated with the lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.masks()[0], lhsPermutationMap,
                       lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffsets);
      // If there is a mask associated with the rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.masks()[1], rhsPermutationMap,
                       rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.acc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
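    // At this point accCache maps each destination offset to the last partial
    // contraction computed for it; slices that differ only in reduction
    // dimensions share a single entry.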
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
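
// For example, with a native shape of 2x2, `%r = arith.addf %a, %b :
// vector<4x4xf32>` becomes four vector<2x2xf32> arith.addf ops on
// extract_strided_slice'd operands, reassembled with insert_strided_slice.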

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.vector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.ids()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
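
// For a matmul-shaped contract, distributing along the parallel result
// dimensions shrinks the matching lhs/rhs dimensions, while the shared
// reduction dimension keeps its full size on both operands, as the next
// pattern implements.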

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.vector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.ids()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.mask())
      return failure();

    SmallVector<Value, 4> indices(read.indices().begin(),
                                  read.indices().end());
    AffineMap indexMap = extract.map().compose(read.permutation_map());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.ids()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.source(), indices, read.permutation_mapAttr(),
        read.padding(), read.mask(), read.in_boundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};
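
// The write-side analogue of the pattern above: a transfer_write whose vector
// comes from an insert_map is rewritten to write only the distributed slice,
// with each index offset by id * (distributed dimension size).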

struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.mask())
      return failure();
    SmallVector<Value, 4> indices(write.indices().begin(),
                                  write.indices().end());
    AffineMap indexMap = insert.map().compose(write.permutation_map());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], insert.ids()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.vector(), write.source(), indices,
        write.permutation_mapAttr(), write.in_boundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern>(
      patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}
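
// A minimal usage sketch (hypothetical; `funcOp` and the 4x4 native shape are
// assumptions for illustration, not part of this file):
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorUnrollPatterns(
//       patterns, UnrollVectorOptions().setNativeShape({4, 4}));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));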