//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>

#define DEBUG_TYPE "vector-unrolling"

using namespace mlir;
using namespace mlir::vector;

/// During unrolling from `originalShape` to `targetShape`, return the element
/// offsets for slice `index`.
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  SmallVector<int64_t, 4> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t, 4> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}
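
// Illustrative example (not exercised in this file; the values follow from
// the row-major slice traversal implied by `computeStrides`): unrolling an
// 8x8 vector into 4x4 slices yields four slices, and
//
//   getVectorOffset({8, 8}, {4, 4}, 0) == {0, 0}
//   getVectorOffset({8, 8}, {4, 4}, 1) == {0, 4}
//   getVectorOffset({8, 8}, {4, 4}, 2) == {4, 0}
//   getVectorOffset({8, 8}, {4, 4}, 3) == {4, 4}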
/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
  SmallVector<int64_t, 4> vectorShape;
  SmallVector<int64_t> loopOrder;
  SmallVector<int64_t> sliceStrides;
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() == targetShape.size());
    assert(loopOrder.size() == targetShape.size());

    // Compute the number of slices in each dimension.
    SmallVector<int64_t> sliceDimCounts(originalShape.size());
    for (unsigned r = 0; r < originalShape.size(); ++r) {
      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
      maxIndexVal *= sliceDimCounts[r];
    }

    // Reversing "loop order" gives dimensions from fastest varying to slowest
    // varying (smallest stride to largest stride).
    int64_t accum = 1;
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
  }

  // Turn the linear index into a d-tuple based on units of vectors of size
  // `vectorShape`. The linear index is assumed to represent traversal of the
  // dimensions based on `loopOrder`.
  SmallVector<int64_t> delinearize(int64_t index) const {
    // Traverse in for loop order (largest stride to smallest stride).
    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
    return vectorOffsets;
  }

  int64_t maxIndex() const { return maxIndexVal; }

  /// Return the element offsets for slice `index`, following the traversal
  /// order given by `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> vectorOffsets = delinearize(index);
    SmallVector<int64_t> elementOffsets =
        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
    return elementOffsets;
  }
};
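
// Worked example for DecomposeShapeIterator (not exercised in this file):
// for originalShape = {8, 8}, targetShape = {4, 4}, and loopOrder = {1, 0},
// dimension 0 varies fastest, so the iterator visits the slices column by
// column:
//
//   getVectorOffset(0) == {0, 0}
//   getVectorOffset(1) == {4, 0}
//   getVectorOffset(2) == {0, 4}
//   getVectorOffset(3) == {4, 4}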
} // namespace

/// Compute the indices of the slice with the given `elementOffsets` for a
/// transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    auto expr = getAffineDimExpr(0, ctx) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

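/// Unrolls a vector.transfer_read into reads of the native vector size and
/// reassembles the result with insert_strided_slice ops. For instance (a
/// sketch assuming a native shape of [2, 2]):
/// ```
/// %r = vector.transfer_read %A[%c0, %c0], %cf0
///   : memref<4x4xf32>, vector<4x4xf32>
/// ```
/// becomes four transfer_read ops of vector<2x2xf32> at element offsets
/// [0, 0], [0, 2], [2, 0], and [2, 2], inserted into a zero-initialized
/// vector<4x4xf32>.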
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                          readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
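
/// Unrolls a vector.transfer_write by extracting native-sized slices of the
/// written vector and emitting one transfer_write per slice. For tensor
/// destinations each new write is threaded through the result of the
/// previous one, e.g. (a sketch assuming a native shape of [2, 2]):
/// ```
/// %t1 = vector.transfer_write %v, %t0[%c0, %c0]
///   : vector<4x4xf32>, tensor<4x4xf32>
/// ```
/// becomes four transfer_write ops of vector<2x2xf32> chained through their
/// intermediate tensor results.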
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                          writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      SmallVector<int64_t, 4> elementOffsets =
          indexToOffsets.getVectorOffset(i);
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value, 4> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
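
/// Unrolls vector.contract by decomposing its iteration space into
/// native-sized tiles. Slices that land on the same accumulator offsets are
/// chained through `accCache` so partial products along the reduction
/// dimensions accumulate into a single value. For instance (a sketch
/// assuming a native shape of [2, 2, 2]), a 4x4x4 matmul-style contraction
/// producing vector<4x4xf32> is split into eight vector<2x2xf32>
/// contractions; the two slices that share accumulator offsets [0, 0]
/// (k = 0 and k = 2) reuse the same accumulator.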
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = contractOp.getResultType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the
      // slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t, 4> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);
      // If there is a mask associated with the lhs, extract it as well.
      if (slicesOperands.size() > 3)
        extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
                       lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);
      // If there is a mask associated with the rhs, extract it as well.
      if (slicesOperands.size() > 4)
        extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
                       rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
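
/// Unrolls vector.multi_reduction into reductions over native-sized slices.
/// Sketch (illustrative values, not taken from a test): with a
/// vector<4x4xf32> source reduced over dimension 1 and a native shape of
/// [2, 2], the four vector<2x2xf32> slices each reduce to a vector<2xf32>
/// partial result; the two slices that map to the same destination offsets
/// are combined with makeArithReduction before being inserted into the final
/// vector<4xf32>.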
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // The number of slices is the product of the shape ratio.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);

      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape,
          operandStrides);

      // Compute the destination shape and offsets by dropping the reduced
      // dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value result = newOp->getResult(0);
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                    result, accIt->second);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
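
/// Unrolls any elementwise-mappable vector operation into ops on
/// native-sized slices. For instance (a sketch assuming a native shape of
/// [2, 2]):
/// ```
/// %0 = arith.addf %a, %b : vector<4x4xf32>
/// ```
/// becomes four arith.addf ops on vector<2x2xf32> slices whose results are
/// reassembled with insert_strided_slice.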
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<Value, 4> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();
    Location loc = extract.getLoc();
    SmallVector<Value, 4> extractOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      auto vecType = operand.get().getType().template dyn_cast<VectorType>();
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc,
          VectorType::get(extract.getResultType().getShape(),
                          vecType.getElementType()),
          operand.get(), extract.getIds()));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, extractOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Canonicalize an extract_map using the result of a contract operation.
/// This propagates the extract_map to the operands.
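/// For instance (a sketch with the usual matmul indexing maps, where only
/// the parallel dimension `i` is distributed):
/// ```
/// %c = vector.contract ... %a, %b, %acc
///   : vector<32x16xf32>, vector<16x32xf32> into vector<32x32xf32>
/// %e = vector.extract_map %c[%id] : vector<32x32xf32> to vector<1x32xf32>
/// ```
/// distributes %a (indexed by (i, k)) to vector<1x16xf32> and %acc (indexed
/// by (i, j)) to vector<1x32xf32>, feeding a smaller vector.contract; the
/// reduction dimension k is left whole.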
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();
    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
    // Create a map of the dimensions distributed based on the acc affine map.
    // Only parallel dimensions are distributed; reduction dimensions are left
    // untouched.
    DenseMap<int64_t, int64_t> map;
    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
      map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
    SmallVector<Value, 4> extractOperands;
    for (const auto &it : llvm::enumerate(contract.getIndexingMaps())) {
      // For each operand, calculate the new vector type after distribution.
      Value operand = contract->getOperand(it.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
        unsigned dim = it.value().getDimPosition(i);
        auto distributedDim = map.find(dim);
        // If the dimension is not in the map, it is a reduction dimension and
        // doesn't get distributed.
        if (distributedDim == map.end())
          continue;
        operandShape[i] = distributedDim->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }
    Operation *newOp =
        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};

/// Converts a TransferRead op used by an ExtractMap op into a TransferRead of
/// a smaller dimension.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
///   memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
///   memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
struct TransferReadExtractPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadExtractPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferReadOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    if (!read.getResult().hasOneUse())
      return failure();
    auto extract =
        dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
    if (!extract)
      return failure();
    if (read.getMask())
      return failure();

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap indexMap = extract.map().compose(read.getPermutationMap());
    unsigned idCount = 0;
    ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
    for (auto it :
         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          extract.getResultType().getDimSize(vectorPos), read.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], extract.getIds()[idCount++]});
    }
    Value newRead = lb.create<vector::TransferReadOp>(
        extract.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    Value dest = lb.create<arith::ConstantOp>(
        read.getType(), rewriter.getZeroAttr(read.getType()));
    newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
    rewriter.replaceOp(read, newRead);
    return success();
  }
};
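
/// Converts a TransferWrite op whose written vector comes from an InsertMap
/// op into a TransferWrite of the smaller distributed vector, offsetting the
/// indices by the distribution ids. A sketch mirroring the read example
/// above, in one dimension:
/// ```
/// %w = vector.insert_map %v, %dst[%id] : vector<2xf32> into vector<64xf32>
/// vector.transfer_write %w, %A[%c0] : vector<64xf32>, memref<64xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// vector.transfer_write %v, %A[%id1] : vector<2xf32>, memref<64xf32>
/// ```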
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}
  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();
    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    unsigned idCount = 0;
    Location loc = write.getLoc();
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      AffineExpr d0, d1;
      bindDims(write.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }
    rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    rewriter.eraseOp(write);
    return success();
  }
};

struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < ratio; ++i) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
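
/// Unrolls vector.transpose by extracting slices of the source at permuted
/// offsets, transposing each slice, and inserting it at the destination
/// offsets. For instance (a sketch assuming a native result shape of
/// [2, 4]):
/// ```
/// %t = vector.transpose %v, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
/// ```
/// becomes four transposes of vector<4x2xf32> slices; e.g. the result slice
/// at offsets [0, 4] is read from the source at the permuted offsets [4, 0].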
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultType();
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      SmallVector<int64_t, 4> elementOffsets =
          getVectorOffset(originalSize, *targetShape, i);
      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto &indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      patterns.getContext());
}