//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector
// distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape` return the offset for
/// the slice `index`.
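/// For example, unrolling an original shape of [4, 6] to a target shape of
/// [2, 3] yields 4 slices; slice index 3 delinearizes to vector offsets
/// [1, 1], i.e. element offsets [2, 3].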
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}

/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    // Compute the slice count for each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
    SmallVector<int64_t, 4> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the element offsets for slice `index`, based on the ordering
  /// given by `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
} // namespace

/// Compute the indices of the slice `index` for a transfer op.
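/// For example (illustrative), with permutation map (d0, d1, d2) -> (d2, d1)
/// and element offsets [a, b], the new indices are
/// [indices[0], indices[1] + b, indices[2] + a]. Broadcast dimensions
/// (constant-0 results of the map) are skipped.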
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order.hasValue()) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

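/// Unrolls a vector.transfer_read into reads of the target shape, each
/// inserted into a zero-initialized result vector with
/// vector.insert_strided_slice. For example (illustrative), with a native
/// shape of 2x2:
/// ```
/// %v = vector.transfer_read %A[%c0, %c0], %cf0 :
///   memref<4x4xf32>, vector<4x4xf32>
/// ```
/// becomes four transfer_reads of vector<2x2xf32> at element offsets
/// {0, 2} x {0, 2}.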
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(ratio.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

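/// Unrolls a vector.transfer_write into writes of extracted slices of the
/// target shape. For tensor destinations the sliced writes are chained, each
/// consuming the tensor produced by the previous one (illustrative sketch):
/// ```
/// %s = vector.extract_strided_slice %v
///   {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
///   : vector<4x4xf32> to vector<2x2xf32>
/// %t1 = vector.transfer_write %s, %t0[%c0, %c0] :
///   vector<2x2xf32>, tensor<4x4xf32>
/// ```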
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalIndices.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

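/// Unrolls vector.contract along its iteration space. Partial results are
/// cached per destination offset so that slices along reduction dimensions
/// chain through the accumulator. For example (illustrative), a contraction
/// with iteration bounds [2, 2, 4] (i, j, k) and a native shape of [2, 2, 2]
/// becomes two vector<2x2xf32> contractions, the second consuming the first's
/// result as its acc operand.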
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIndexingMaps().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the
      // slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);
      // If there is a mask associated with the lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);
      // If there is a mask associated with the rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

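/// Unrolls vector.multi_reduction into reductions of target-shape slices.
/// Slices sharing the same offsets along the non-reduced dimensions are
/// combined elementwise. For example (illustrative), reducing dimension 1 of
/// a vector<4x8xf32> with a native shape of [2, 4] creates four
/// vector<2x4xf32> -> vector<2xf32> reductions whose partial results are
/// pairwise combined and reassembled into the vector<4xf32> result.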
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // The number of slices is the product of the shape ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape,
          operandStrides);

      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value result = newOp->getResult(0);
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                    result, accIt->second);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

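/// Unrolls any op with the elementwise-mappable traits and a single vector
/// result. For example (illustrative), with a native shape of 2x2:
/// ```
/// %r = arith.addf %a, %b : vector<4x4xf32>
/// ```
/// becomes four arith.addf ops on vector<2x2xf32> slices extracted with
/// vector.extract_strided_slice and reassembled with
/// vector.insert_strided_slice.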
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
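/// For example (illustrative), distributing the parallel dimension i of a
/// matmul across ids:
/// ```
/// %c = vector.contract {...} %a, %b, %acc
///   : vector<32x16xf32>, vector<16x32xf32> into vector<32x32xf32>
/// %e = vector.extract_map %c[%id] : vector<32x32xf32> to vector<4x32xf32>
/// ```
/// Each operand is rewritten with a vector.extract_map whose shape
/// distributes only its parallel dimensions: %a becomes vector<4x16xf32>,
/// while the reduction dimension k keeps its original size.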
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are being distributed, reduction dimensions are
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map it means it is a reduction and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts a TransferRead op used by an ExtractMap op into a smaller
/// dimension TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0 :
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id0, %id1] : vector<64x4x32xf32>
///   to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id2 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id0)
/// %e = vector.transfer_read %A[%id2, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};

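/// Converts a TransferWrite op whose vector comes from an InsertMap op into a
/// smaller dimension TransferWrite of the distributed vector, offsetting the
/// indices in the same way as TransferReadExtractPattern.
/// Example (illustrative):
/// ```
/// %w = vector.insert_map %v, %dest[%id] : vector<2xf32> into vector<64xf32>
/// vector.transfer_write %w, %A[%c0] : vector<64xf32>, memref<64xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// vector.transfer_write %v, %A[%id1] : vector<2xf32>, memref<64xf32>
/// ```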
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

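/// Unrolls a 1-D vector.reduction into reductions over slices of the target
/// shape whose scalar results are combined pairwise. For example
/// (illustrative), with a native shape of [4]:
/// ```
/// %r = vector.reduction <add>, %v : vector<8xf32> into f32
/// ```
/// becomes two reductions of vector<4xf32> combined with arith.addf.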
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

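/// Unrolls vector.transpose by transposing slices of the target shape. For
/// example (illustrative), with a native shape of [3, 2]:
/// ```
/// %t = vector.transpose %v, [1, 0] : vector<4x6xf32> to vector<6x4xf32>
/// ```
/// becomes four vector.transpose ops on vector<2x3xf32> slices extracted at
/// permuted offsets and inserted into the vector<6x4xf32> result.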
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}