1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to do vector unrolling and vector distribution. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Utils/IndexingUtils.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/Interfaces/VectorInterfaces.h" 18 #include "llvm/ADT/MapVector.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "vector-unrolling" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 /// During unrolling from `originalShape` to `targetShape` return the offset for 27 /// the slice `index`. 28 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape, 29 ArrayRef<int64_t> targetShape, 30 int64_t index) { 31 SmallVector<int64_t, 4> dstSliceStrides = 32 computeStrides(originalShape, targetShape); 33 SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index); 34 SmallVector<int64_t, 4> elementOffsets = 35 computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets); 36 return elementOffsets; 37 } 38 39 /// Compute the indices of the slice `index` for a tranfer op. 40 static SmallVector<Value> 41 sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape, 42 ArrayRef<int64_t> targetShape, ArrayRef<Value> indices, 43 AffineMap permutationMap, Location loc, 44 OpBuilder &builder) { 45 MLIRContext *ctx = builder.getContext(); 46 auto isBroadcast = [](AffineExpr expr) { 47 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) 48 return constExpr.getValue() == 0; 49 return false; 50 }; 51 SmallVector<int64_t, 4> elementOffsets = 52 getVectorOffset(originalShape, targetShape, index); 53 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 54 SmallVector<Value> slicedIndices(indices.begin(), indices.end()); 55 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { 56 if (isBroadcast(dim.value())) 57 continue; 58 unsigned pos = dim.value().cast<AffineDimExpr>().getPosition(); 59 auto expr = getAffineDimExpr(0, builder.getContext()) + 60 getAffineConstantExpr(elementOffsets[dim.index()], ctx); 61 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); 62 slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]); 63 } 64 return slicedIndices; 65 } 66 67 // Clones `op` into a new operations that takes `operands` and returns 68 // `resultTypes`. 69 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, 70 Operation *op, 71 ArrayRef<Value> operands, 72 ArrayRef<Type> resultTypes) { 73 OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs()); 74 return builder.createOperation(res); 75 } 76 77 /// Return the target shape for unrolling for the given `op`. Return llvm::None 78 /// if the op shouldn't be or cannot be unrolled. 79 static Optional<SmallVector<int64_t, 4>> 80 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { 81 if (options.filterConstraint && failed(options.filterConstraint(op))) 82 return llvm::None; 83 assert(options.nativeShape && 84 "vector unrolling expects the native shape or native" 85 "shape call back function to be set"); 86 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); 87 if (!unrollableVectorOp) 88 return llvm::None; 89 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); 90 if (!maybeUnrollShape) 91 return llvm::None; 92 Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op); 93 if (!targetShape) 94 return llvm::None; 95 auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape); 96 if (!maybeShapeRatio || 97 llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) 98 return llvm::None; 99 return targetShape; 100 } 101 102 namespace { 103 104 struct UnrollTransferReadPattern 105 : public OpRewritePattern<vector::TransferReadOp> { 106 UnrollTransferReadPattern(MLIRContext *context, 107 const vector::UnrollVectorOptions &options) 108 : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1), 109 options(options) {} 110 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 111 PatternRewriter &rewriter) const override { 112 // TODO: support 0-d corner case. 113 if (readOp.getTransferRank() == 0) 114 return failure(); 115 if (readOp.mask()) 116 return failure(); 117 auto targetShape = getTargetShape(options, readOp); 118 if (!targetShape) 119 return failure(); 120 auto sourceVectorType = readOp.getVectorType(); 121 SmallVector<int64_t, 4> strides(targetShape->size(), 1); 122 Location loc = readOp.getLoc(); 123 ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); 124 SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape); 125 // Compute shape ratio of 'shape' and 'sizes'. 126 int64_t sliceCount = computeMaxLinearIndex(ratio); 127 // Prepare the result vector; 128 Value result = rewriter.create<arith::ConstantOp>( 129 loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); 130 auto targetType = 131 VectorType::get(*targetShape, sourceVectorType.getElementType()); 132 SmallVector<Value, 4> originalIndices(readOp.indices().begin(), 133 readOp.indices().end()); 134 for (int64_t i = 0; i < sliceCount; i++) { 135 SmallVector<Value, 4> indices = 136 sliceTransferIndices(i, originalSize, *targetShape, originalIndices, 137 readOp.permutation_map(), loc, rewriter); 138 auto slicedRead = rewriter.create<vector::TransferReadOp>( 139 loc, targetType, readOp.source(), indices, 140 readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(), 141 readOp.in_boundsAttr()); 142 143 SmallVector<int64_t, 4> elementOffsets = 144 getVectorOffset(originalSize, *targetShape, i); 145 result = rewriter.create<vector::InsertStridedSliceOp>( 146 loc, slicedRead, result, elementOffsets, strides); 147 } 148 rewriter.replaceOp(readOp, result); 149 return success(); 150 } 151 152 private: 153 vector::UnrollVectorOptions options; 154 }; 155 156 struct UnrollTransferWritePattern 157 : public OpRewritePattern<vector::TransferWriteOp> { 158 UnrollTransferWritePattern(MLIRContext *context, 159 const vector::UnrollVectorOptions &options) 160 : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1), 161 options(options) {} 162 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 163 PatternRewriter &rewriter) const override { 164 // TODO: support 0-d corner case. 165 if (writeOp.getTransferRank() == 0) 166 return failure(); 167 168 if (writeOp.mask()) 169 return failure(); 170 auto targetShape = getTargetShape(options, writeOp); 171 if (!targetShape) 172 return failure(); 173 auto sourceVectorType = writeOp.getVectorType(); 174 SmallVector<int64_t, 4> strides(targetShape->size(), 1); 175 Location loc = writeOp.getLoc(); 176 ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); 177 SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape); 178 // Compute shape ratio of 'shape' and 'sizes'. 179 int64_t sliceCount = computeMaxLinearIndex(ratio); 180 SmallVector<Value, 4> originalIndices(writeOp.indices().begin(), 181 writeOp.indices().end()); 182 Value resultTensor; 183 for (int64_t i = 0; i < sliceCount; i++) { 184 SmallVector<int64_t, 4> elementOffsets = 185 getVectorOffset(originalSize, *targetShape, i); 186 Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>( 187 loc, writeOp.vector(), elementOffsets, *targetShape, strides); 188 189 SmallVector<Value, 4> indices = 190 sliceTransferIndices(i, originalSize, *targetShape, originalIndices, 191 writeOp.permutation_map(), loc, rewriter); 192 Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( 193 loc, slicedVector, resultTensor ? resultTensor : writeOp.source(), 194 indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); 195 // For the tensor case update the destination for the next transfer write. 196 if (!slicedWrite->getResults().empty()) 197 resultTensor = slicedWrite->getResult(0); 198 } 199 if (resultTensor) 200 rewriter.replaceOp(writeOp, resultTensor); 201 else 202 rewriter.eraseOp(writeOp); 203 return success(); 204 } 205 206 private: 207 vector::UnrollVectorOptions options; 208 }; 209 210 struct OffsetMapInfo { 211 static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } 212 213 static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } 214 215 static unsigned getHashValue(const SmallVector<int64_t> &v) { 216 return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end())); 217 } 218 219 static bool isEqual(const SmallVector<int64_t> &lhs, 220 const SmallVector<int64_t> &rhs) { 221 return lhs == rhs; 222 } 223 }; 224 225 struct UnrollContractionPattern 226 : public OpRewritePattern<vector::ContractionOp> { 227 UnrollContractionPattern(MLIRContext *context, 228 const vector::UnrollVectorOptions &options) 229 : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1), 230 options(options) {} 231 232 LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 233 PatternRewriter &rewriter) const override { 234 auto targetShape = getTargetShape(options, contractOp); 235 if (!targetShape) 236 return failure(); 237 auto dstVecType = contractOp.getResultType().cast<VectorType>(); 238 SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll(); 239 SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape); 240 241 // Compute shape ratio of 'shape' and 'sizes'. 242 int64_t sliceCount = computeMaxLinearIndex(ratio); 243 Location loc = contractOp.getLoc(); 244 unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); 245 AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex]; 246 llvm::MapVector< 247 SmallVector<int64_t>, Value, 248 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> 249 accCache; 250 for (int64_t i = 0; i < sliceCount; i++) { 251 SmallVector<int64_t, 4> offsets = 252 getVectorOffset(originalSize, *targetShape, i); 253 SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands()); 254 255 // Helper to coompute the new shape of each operand and extract the slice. 256 auto extractOperand = [&](unsigned index, Value operand, 257 AffineMap permutationMap, 258 ArrayRef<int64_t> operandOffets) { 259 SmallVector<int64_t> operandShape = applyPermutationMap( 260 permutationMap, ArrayRef<int64_t>(*targetShape)); 261 SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1); 262 slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>( 263 loc, operand, operandOffets, operandShape, operandStrides); 264 }; 265 266 // Extract the new lhs operand. 267 AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0]; 268 SmallVector<int64_t> lhsOffets = 269 applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); 270 extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets); 271 // If there is a mask associated to lhs, extract it as well. 272 if (slicesOperands.size() > 3) 273 extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets); 274 275 // Extract the new rhs operand. 276 AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1]; 277 SmallVector<int64_t> rhsOffets = 278 applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); 279 extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets); 280 // If there is a mask associated to rhs, extract it as well. 281 if (slicesOperands.size() > 4) 282 extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets); 283 284 AffineMap accPermutationMap = contractOp.getIndexingMaps()[2]; 285 SmallVector<int64_t> accOffets = 286 applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); 287 // If a version of the accumulator has already been computed, use it 288 // otherwise extract the first version from the original operand. 289 auto accIt = accCache.find(accOffets); 290 if (accIt != accCache.end()) 291 slicesOperands[2] = accIt->second; 292 else 293 extractOperand(2, contractOp.acc(), accPermutationMap, accOffets); 294 295 SmallVector<int64_t> dstShape = 296 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape)); 297 auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); 298 Operation *newOp = cloneOpWithOperandsAndTypes( 299 rewriter, loc, contractOp, slicesOperands, targetType); 300 301 SmallVector<int64_t> dstOffets = 302 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets)); 303 // Save the accumulated value untill all the loops are unrolled since 304 // reduction loop keep updating the accumulator. 305 accCache[dstOffets] = newOp->getResult(0); 306 } 307 // Assemble back the accumulator into a single vector. 308 Value result = rewriter.create<arith::ConstantOp>( 309 loc, dstVecType, rewriter.getZeroAttr(dstVecType)); 310 for (const auto &it : accCache) { 311 SmallVector<int64_t> dstStrides(it.first.size(), 1); 312 result = rewriter.create<vector::InsertStridedSliceOp>( 313 loc, it.second, result, it.first, dstStrides); 314 } 315 rewriter.replaceOp(contractOp, result); 316 return success(); 317 } 318 319 private: 320 vector::UnrollVectorOptions options; 321 }; 322 323 struct UnrollMultiReductionPattern 324 : public OpRewritePattern<vector::MultiDimReductionOp> { 325 UnrollMultiReductionPattern(MLIRContext *context, 326 const vector::UnrollVectorOptions &options) 327 : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1), 328 options(options) {} 329 330 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, 331 PatternRewriter &rewriter) const override { 332 Optional<SmallVector<int64_t, 4>> targetShape = 333 getTargetShape(options, reductionOp); 334 if (!targetShape) 335 return failure(); 336 SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll(); 337 SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape); 338 llvm::MapVector< 339 SmallVector<int64_t>, Value, 340 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> 341 accCache; 342 // Compute shape ratio of 'shape' and 'sizes'. 343 int64_t sliceCount = computeMaxLinearIndex(ratio); 344 Location loc = reductionOp.getLoc(); 345 for (int64_t i = 0; i < sliceCount; i++) { 346 SmallVector<int64_t, 4> offsets = 347 getVectorOffset(originalSize, *targetShape, i); 348 349 SmallVector<int64_t, 4> operandStrides(offsets.size(), 1); 350 Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>( 351 loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides); 352 353 SmallVector<int64_t> dstShape; 354 SmallVector<int64_t> destOffset; 355 for (size_t i : llvm::seq(size_t(0), targetShape->size())) { 356 if (!reductionOp.isReducedDim(i)) { 357 destOffset.push_back(offsets[i]); 358 dstShape.push_back((*targetShape)[i]); 359 } 360 } 361 auto targetType = VectorType::get( 362 dstShape, reductionOp.getSourceVectorType().getElementType()); 363 Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, 364 slicedOperand, targetType); 365 Value result = newOp->getResult(0); 366 // Save the accumulated value until all the loops are unrolled since 367 // reduction loop keeps updating the accumulator. 368 auto accIt = accCache.find(destOffset); 369 if (accIt != accCache.end()) 370 result = makeArithReduction(rewriter, loc, reductionOp.kind(), result, 371 accIt->second); 372 accCache[destOffset] = result; 373 } 374 // Assemble back the accumulator into a single vector. 375 Value result = rewriter.create<arith::ConstantOp>( 376 loc, reductionOp.getDestType(), 377 rewriter.getZeroAttr(reductionOp.getDestType())); 378 for (const auto &it : accCache) { 379 SmallVector<int64_t> dstStrides(it.first.size(), 1); 380 result = rewriter.create<vector::InsertStridedSliceOp>( 381 loc, it.second, result, it.first, dstStrides); 382 } 383 rewriter.replaceOp(reductionOp, result); 384 return success(); 385 } 386 387 private: 388 vector::UnrollVectorOptions options; 389 }; 390 391 struct UnrollElementwisePattern : public RewritePattern { 392 UnrollElementwisePattern(MLIRContext *context, 393 const vector::UnrollVectorOptions &options) 394 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), 395 options(options) {} 396 LogicalResult matchAndRewrite(Operation *op, 397 PatternRewriter &rewriter) const override { 398 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 399 return failure(); 400 auto targetShape = getTargetShape(options, op); 401 if (!targetShape) 402 return failure(); 403 auto dstVecType = op->getResult(0).getType().cast<VectorType>(); 404 SmallVector<int64_t, 4> originalSize = 405 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); 406 SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape); 407 int64_t sliceCount = computeMaxLinearIndex(ratio); 408 Location loc = op->getLoc(); 409 // Prepare the result vector. 410 Value result = rewriter.create<arith::ConstantOp>( 411 loc, dstVecType, rewriter.getZeroAttr(dstVecType)); 412 SmallVector<int64_t, 4> strides(targetShape->size(), 1); 413 VectorType newVecType = 414 VectorType::get(*targetShape, dstVecType.getElementType()); 415 for (int64_t i = 0; i < sliceCount; i++) { 416 SmallVector<int64_t, 4> offsets = 417 getVectorOffset(originalSize, *targetShape, i); 418 SmallVector<Value, 4> extractOperands; 419 for (OpOperand &operand : op->getOpOperands()) { 420 auto vecType = operand.get().getType().template dyn_cast<VectorType>(); 421 if (!vecType) { 422 extractOperands.push_back(operand.get()); 423 continue; 424 } 425 extractOperands.push_back( 426 rewriter.create<vector::ExtractStridedSliceOp>( 427 loc, operand.get(), offsets, *targetShape, strides)); 428 } 429 Operation *newOp = cloneOpWithOperandsAndTypes( 430 rewriter, loc, op, extractOperands, newVecType); 431 result = rewriter.create<vector::InsertStridedSliceOp>( 432 loc, newOp->getResult(0), result, offsets, strides); 433 } 434 rewriter.replaceOp(op, result); 435 return success(); 436 } 437 438 private: 439 vector::UnrollVectorOptions options; 440 }; 441 442 /// Canonicalize an extract_map using the result of a pointwise operation. 443 /// Transforms: 444 /// %v = arith.addf %a, %b : vector32xf32> 445 /// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32> 446 /// to: 447 /// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32> 448 /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32> 449 /// %dv = arith.addf %da, %db : vector<1xf32> 450 struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> { 451 using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern; 452 LogicalResult matchAndRewrite(vector::ExtractMapOp extract, 453 PatternRewriter &rewriter) const override { 454 Operation *definedOp = extract.vector().getDefiningOp(); 455 if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) || 456 definedOp->getNumResults() != 1) 457 return failure(); 458 Location loc = extract.getLoc(); 459 SmallVector<Value, 4> extractOperands; 460 for (OpOperand &operand : definedOp->getOpOperands()) { 461 auto vecType = operand.get().getType().template dyn_cast<VectorType>(); 462 if (!vecType) { 463 extractOperands.push_back(operand.get()); 464 continue; 465 } 466 extractOperands.push_back(rewriter.create<vector::ExtractMapOp>( 467 loc, 468 VectorType::get(extract.getResultType().getShape(), 469 vecType.getElementType()), 470 operand.get(), extract.ids())); 471 } 472 Operation *newOp = cloneOpWithOperandsAndTypes( 473 rewriter, loc, definedOp, extractOperands, extract.getResultType()); 474 rewriter.replaceOp(extract, newOp->getResult(0)); 475 return success(); 476 } 477 }; 478 479 /// Canonicalize an extract_map using the result of a contract operation. 480 /// This propagate the extract_map to operands. 481 struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> { 482 using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern; 483 LogicalResult matchAndRewrite(vector::ExtractMapOp extract, 484 PatternRewriter &rewriter) const override { 485 Operation *definedOp = extract.vector().getDefiningOp(); 486 auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp); 487 if (!contract) 488 return failure(); 489 Location loc = contract.getLoc(); 490 unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); 491 AffineMap affineMap = contract.getIndexingMaps()[accIndex]; 492 // Create a map of the dimensions distributed based on the acc affine map. 493 // Only parallel dimensions are being distributed, reduction dimensions are 494 // untouched. 495 DenseMap<int64_t, int64_t> map; 496 for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults())) 497 map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i); 498 SmallVector<Value, 4> extractOperands; 499 for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) { 500 // For each operands calculate the new vector type after distribution. 501 Value operand = contract->getOperand(it.index()); 502 auto vecType = operand.getType().cast<VectorType>(); 503 SmallVector<int64_t> operandShape(vecType.getShape().begin(), 504 vecType.getShape().end()); 505 for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) { 506 unsigned dim = it.value().getDimPosition(i); 507 auto distributedDim = map.find(dim); 508 // If the dimension is not in the map it means it is a reduction and 509 // doesn't get distributed. 510 if (distributedDim == map.end()) 511 continue; 512 operandShape[i] = distributedDim->second; 513 } 514 VectorType newVecType = 515 VectorType::get(operandShape, vecType.getElementType()); 516 extractOperands.push_back(rewriter.create<vector::ExtractMapOp>( 517 loc, newVecType, operand, extract.ids())); 518 } 519 Operation *newOp = 520 cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands, 521 extract.getResult().getType()); 522 rewriter.replaceOp(extract, newOp->getResult(0)); 523 return success(); 524 } 525 }; 526 527 /// Converts TransferRead op used by ExtractMap op into a smaller dimension 528 /// TransferRead. 529 /// Example: 530 /// ``` 531 /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: 532 /// memref<64x64x64xf32>, vector<64x4x32xf32> 533 /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32> 534 /// ``` 535 /// to: 536 /// ``` 537 /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id) 538 /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 : 539 /// memref<64x64x64xf32>, vector<2x4x1xf32> 540 /// ``` 541 struct TransferReadExtractPattern 542 : public OpRewritePattern<vector::TransferReadOp> { 543 TransferReadExtractPattern(MLIRContext *context) 544 : OpRewritePattern<vector::TransferReadOp>(context) {} 545 LogicalResult matchAndRewrite(vector::TransferReadOp read, 546 PatternRewriter &rewriter) const override { 547 // TODO: support 0-d corner case. 548 if (read.getTransferRank() == 0) 549 return failure(); 550 551 if (!read.getResult().hasOneUse()) 552 return failure(); 553 auto extract = 554 dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin()); 555 if (!extract) 556 return failure(); 557 if (read.mask()) 558 return failure(); 559 560 SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end()); 561 AffineMap indexMap = extract.map().compose(read.permutation_map()); 562 unsigned idCount = 0; 563 ImplicitLocOpBuilder lb(read.getLoc(), rewriter); 564 for (auto it : 565 llvm::zip(indexMap.getResults(), extract.map().getResults())) { 566 AffineExpr d0, d1; 567 bindDims(read.getContext(), d0, d1); 568 auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>(); 569 if (!indexExpr) 570 continue; 571 unsigned indexPos = indexExpr.getPosition(); 572 unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition(); 573 auto scale = getAffineConstantExpr( 574 extract.getResultType().getDimSize(vectorPos), read.getContext()); 575 indices[indexPos] = makeComposedAffineApply( 576 rewriter, read.getLoc(), d0 + scale * d1, 577 {indices[indexPos], extract.ids()[idCount++]}); 578 } 579 Value newRead = lb.create<vector::TransferReadOp>( 580 extract.getType(), read.source(), indices, read.permutation_mapAttr(), 581 read.padding(), read.mask(), read.in_boundsAttr()); 582 Value dest = lb.create<arith::ConstantOp>( 583 read.getType(), rewriter.getZeroAttr(read.getType())); 584 newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids()); 585 rewriter.replaceOp(read, newRead); 586 return success(); 587 } 588 }; 589 590 struct TransferWriteInsertPattern 591 : public OpRewritePattern<vector::TransferWriteOp> { 592 TransferWriteInsertPattern(MLIRContext *context) 593 : OpRewritePattern<vector::TransferWriteOp>(context) {} 594 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 595 PatternRewriter &rewriter) const override { 596 // TODO: support 0-d corner case. 597 if (write.getTransferRank() == 0) 598 return failure(); 599 600 auto insert = write.vector().getDefiningOp<vector::InsertMapOp>(); 601 if (!insert) 602 return failure(); 603 if (write.mask()) 604 return failure(); 605 SmallVector<Value, 4> indices(write.indices().begin(), 606 write.indices().end()); 607 AffineMap indexMap = insert.map().compose(write.permutation_map()); 608 unsigned idCount = 0; 609 Location loc = write.getLoc(); 610 for (auto it : 611 llvm::zip(indexMap.getResults(), insert.map().getResults())) { 612 AffineExpr d0, d1; 613 bindDims(write.getContext(), d0, d1); 614 auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>(); 615 if (!indexExpr) 616 continue; 617 unsigned indexPos = indexExpr.getPosition(); 618 unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition(); 619 auto scale = getAffineConstantExpr( 620 insert.getSourceVectorType().getDimSize(vectorPos), 621 write.getContext()); 622 indices[indexPos] = 623 makeComposedAffineApply(rewriter, loc, d0 + scale * d1, 624 {indices[indexPos], insert.ids()[idCount++]}); 625 } 626 rewriter.create<vector::TransferWriteOp>( 627 loc, insert.vector(), write.source(), indices, 628 write.permutation_mapAttr(), write.in_boundsAttr()); 629 rewriter.eraseOp(write); 630 return success(); 631 } 632 }; 633 634 struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { 635 UnrollReductionPattern(MLIRContext *context, 636 const vector::UnrollVectorOptions &options) 637 : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1), 638 options(options) {} 639 640 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, 641 PatternRewriter &rewriter) const override { 642 Optional<SmallVector<int64_t, 4>> targetShape = 643 getTargetShape(options, reductionOp); 644 if (!targetShape) 645 return failure(); 646 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); 647 int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0]; 648 649 // Create unrolled vector reduction. 650 Location loc = reductionOp.getLoc(); 651 Value accumulator = nullptr; 652 for (int64_t i = 0; i < ratio; ++i) { 653 SmallVector<int64_t> offsets = 654 getVectorOffset(originalSize, *targetShape, i); 655 SmallVector<int64_t> strides(offsets.size(), 1); 656 Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>( 657 loc, reductionOp.vector(), offsets, *targetShape, strides); 658 Operation *newOp = cloneOpWithOperandsAndTypes( 659 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); 660 Value result = newOp->getResult(0); 661 662 if (!accumulator) { 663 // This is the first reduction. 664 accumulator = result; 665 } else { 666 // On subsequent reduction, combine with the accumulator. 667 accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(), 668 accumulator, result); 669 } 670 } 671 672 rewriter.replaceOp(reductionOp, accumulator); 673 return success(); 674 } 675 676 private: 677 const vector::UnrollVectorOptions options; 678 }; 679 680 } // namespace 681 682 void mlir::vector::populateVectorUnrollPatterns( 683 RewritePatternSet &patterns, const UnrollVectorOptions &options) { 684 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern, 685 UnrollContractionPattern, UnrollElementwisePattern, 686 UnrollReductionPattern, UnrollMultiReductionPattern>( 687 patterns.getContext(), options); 688 } 689 690 void mlir::vector::populatePropagateVectorDistributionPatterns( 691 RewritePatternSet &patterns) { 692 patterns.add<PointwiseExtractPattern, ContractExtractPattern, 693 TransferReadExtractPattern, TransferWriteInsertPattern>( 694 patterns.getContext()); 695 } 696