//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to perform vector unrolling and vector
// distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape` return the offset for
/// the slice `index`.
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}

/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() == targetShape.size());
    assert(loopOrder.size() == targetShape.size());

    // Compute the slice count for each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
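    // Illustrative example (assumed values): for originalShape = [4, 4],
    // targetShape = [2, 2] and loopOrder = [1, 0], sliceDimCounts = [2, 2]
    // and sliceStrides = [1, 2], so index 1 delinearizes to [1, 0], i.e.
    // dimension 0 varies fastest.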
    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the element offsets of slice `index` within the d-tuple, based on
  /// the ordering given by `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
} // namespace

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
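/// Illustrative example (assumed shapes): for an op whose shape for unroll is
/// [4, 8] and a native shape of [2, 4], the returned target shape is [2, 4].
/// If every per-dimension ratio is 1 there is nothing to unroll, and
/// llvm::None is returned instead.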
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native shape callback "
         "function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

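/// Unrolls a vector.transfer_read into reads of the native vector size and
/// reassembles the result with insert_strided_slice. Illustrative example,
/// assuming a native shape of <2x2>:
/// ```
/// %r = vector.transfer_read %A[%c0, %c0], %cf0 :
///   memref<4x4xf32>, vector<4x4xf32>
/// ```
/// becomes four vector<2x2xf32> reads at element offsets [0, 0], [0, 2],
/// [2, 0] and [2, 2], inserted into a zero vector<4x4xf32>.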
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

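/// Unrolls a vector.transfer_write into writes of the native vector size, each
/// storing an extract_strided_slice of the original vector. For writes on
/// tensors, each sliced write is threaded through the previous write's result
/// so the updates chain correctly. (Illustrative description; see
/// UnrollTransferReadPattern above for the read-side analogue.)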
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

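/// Unrolls a vector.contract into contractions on the native shape. Partial
/// accumulators are cached by their offsets in the destination so that slices
/// sharing the same output tile are chained through the reduction dimension
/// before being reassembled. Illustrative example (assumed shapes): a 4x4x4
/// matmul with a native shape of 2x2x2 yields eight vector.contract ops, two
/// per 2x2 output tile.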
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the
      // slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);
      // If there is a mask associated with the lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);
      // If there is a mask associated with the rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

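/// Unrolls a vector.multi_reduction into multi_reductions on the native
/// shape, caching the partial accumulator for each destination offset.
/// Illustrative example (assumed shapes): reducing dim 1 of a vector<4x8xf32>
/// with a native shape of <2x4> creates four multi_reductions on
/// vector<2x4xf32>; the two slices that share the same destination offset are
/// chained through their accumulator.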
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // The number of slices is the product of the per-dimension unroll ratios.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<Value> operands;
      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc,
                                                     reductionOp, operands,
                                                     targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

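/// Unrolls elementwise ops with a single vector result into ops on the native
/// shape. Illustrative example (assumed shapes): an arith.addf on
/// vector<4x4xf32> with a native shape of <2x2> becomes four arith.addf ops
/// on vector<2x2xf32> whose results are reassembled with
/// insert_strided_slice.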
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
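/// Only the parallel dimensions named by the accumulator's indexing map are
/// shrunk to the distributed size; reduction dimensions keep their original
/// extent on every operand.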
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions
    // are untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
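    // The rewrite only applies to unmasked reads whose single use is the
    // extract_map.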
    if (read.getTransferRank() == 0)
      return failure();

    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};

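/// Converts a TransferWrite op fed by an InsertMap into a smaller dimension
/// TransferWrite at the distributed indices. This is the write-side mirror of
/// TransferReadExtractPattern: each index covered by the insert map is scaled
/// by the distributed vector size and offset by the corresponding id.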
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
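        // makeArithReduction builds the arith op matching the reduction kind,
        // e.g. arith.addf for an 'add' reduction on floats.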
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand,
                                               permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}
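
// Illustrative usage (not part of this file): clients typically populate
// these patterns and apply them greedily, assuming a configured
// UnrollVectorOptions:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorUnrollPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));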