1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to do vector unrolling and vector distribution. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Utils/IndexingUtils.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/Interfaces/VectorInterfaces.h" 18 #include "llvm/ADT/MapVector.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "vector-unrolling" 22 23 using namespace mlir; 24 using namespace mlir::vector; 25 26 /// During unrolling from `originalShape` to `targetShape` return the offset for 27 /// the slice `index`. 28 static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape, 29 ArrayRef<int64_t> targetShape, 30 int64_t index) { 31 SmallVector<int64_t, 4> dstSliceStrides = 32 computeStrides(originalShape, targetShape); 33 SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index); 34 SmallVector<int64_t, 4> elementOffsets = 35 computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets); 36 return elementOffsets; 37 } 38 39 /// Compute the indices of the slice `index` for a tranfer op. 
40 static SmallVector<Value> 41 sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape, 42 ArrayRef<int64_t> targetShape, ArrayRef<Value> indices, 43 AffineMap permutationMap, Location loc, 44 OpBuilder &builder) { 45 MLIRContext *ctx = builder.getContext(); 46 auto isBroadcast = [](AffineExpr expr) { 47 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) 48 return constExpr.getValue() == 0; 49 return false; 50 }; 51 SmallVector<int64_t, 4> elementOffsets = 52 getVectorOffset(originalShape, targetShape, index); 53 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 54 SmallVector<Value> slicedIndices(indices.begin(), indices.end()); 55 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { 56 if (isBroadcast(dim.value())) 57 continue; 58 unsigned pos = dim.value().cast<AffineDimExpr>().getPosition(); 59 auto expr = getAffineDimExpr(0, builder.getContext()) + 60 getAffineConstantExpr(elementOffsets[dim.index()], ctx); 61 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); 62 slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]); 63 } 64 return slicedIndices; 65 } 66 67 // Clones `op` into a new operations that takes `operands` and returns 68 // `resultTypes`. 69 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, 70 Operation *op, 71 ArrayRef<Value> operands, 72 ArrayRef<Type> resultTypes) { 73 return builder.create(loc, op->getName().getIdentifier(), operands, 74 resultTypes, op->getAttrs()); 75 } 76 77 /// Return the target shape for unrolling for the given `op`. Return llvm::None 78 /// if the op shouldn't be or cannot be unrolled. 
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native"
         "shape call back function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  // Bail out when the native shape does not divide the op shape, or when the
  // ratio is all-ones (unrolling would be a no-op).
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

namespace {

/// Unrolls a vector.transfer_read into smaller reads over the target shape,
/// reassembled into the full result with insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector (zero-filled; every element is overwritten).
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());
    for (int64_t i = 0; i < sliceCount; i++) {
      // Shift the original memref/tensor indices by this slice's offsets.
      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls a vector.transfer_write into smaller writes over the target shape.
/// For tensor semantics each sliced write feeds its result into the next.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // Masked transfers are not supported.
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());
    // Stays null for memref destinations (writes produce no result).
    Value resultTensor;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);

      SmallVector<Value, 4> indices =
          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// DenseMapInfo for using SmallVector<int64_t> offset keys in the
/// accumulator caches below.
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

/// Unrolls a vector.contract into smaller contractions, chaining slices that
/// share the same accumulator offsets through `accCache`.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    // Maps accumulator offsets to the most recent partial result; MapVector
    // keeps insertion order so reassembly below is deterministic.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
      // If there is a mask associated to lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
      // If there is a mask associated to rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls a vector.multi_reduction into reductions over smaller slices,
/// combining slices that reduce into the same destination offsets.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Maps destination (non-reduced) offsets to the current partial result.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);

      // Keep only non-reduced dimensions for the destination shape/offsets.
      // Note: the loop variable below shadows the outer slice index `i`.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value result = newOp->getResult(0);
      // Save the accumulated value until all the loops are unrolled since
      // reduction loop keeps updating the accumulator.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                    result, accIt->second);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls any single-result elementwise-mappable op on vectors into copies
/// of the op applied to extracted slices.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector (zero-filled; every element is overwritten).
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        // Scalar (non-vector) operands are passed through unchanged.
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    // Only fire on single-result elementwise-mappable producers.
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      // Scalar (non-vector) operands are passed through unchanged.
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    // Only fire when the read's sole user is an extract_map.
    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // New index = old index + id * distributed-dimension size.
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};

/// Converts a TransferWrite op fed by an InsertMap into a smaller dimension
/// TransferWrite, shifting the indices by the distribution ids.
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // New index = old index + id * distributed-dimension size.
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

/// Unrolls a 1-D vector.reduction into reductions over smaller slices that
/// are combined with the appropriate arithmetic reduction.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Reduction is 1-D, so the first ratio entry is the slice count.
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

/// Unrolls a vector.transpose into transposes of smaller slices.
/// NOTE(review): "Tranpose" is a misspelling of "Transpose"; renaming would
/// also touch populateVectorUnrollPatterns below, so it is left as-is here.
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTranposePattern(MLIRContext *context,
                        const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
                                PatternRewriter &rewriter) const override {
    if (tranposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, tranposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = tranposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = tranposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Compute the total number of slices from the per-dimension ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector (zero-filled; every element is overwritten).
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    tranposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape by inverting the permutation on
      // the destination-slice offsets/shape.
      for (auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
      Value tranposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tranposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(tranposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTranposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}