199ef9eebSMatthias Springer //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements patterns to do vector unrolling and vector distribution.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer 
1399ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1499ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1699ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
1799ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
189f122152SChristopher Bate #include "mlir/Support/MathExtras.h"
1999ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h"
209f122152SChristopher Bate #include "llvm/ADT/STLExtras.h"
219f122152SChristopher Bate #include <numeric>
2299ef9eebSMatthias Springer 
2399ef9eebSMatthias Springer #define DEBUG_TYPE "vector-unrolling"
2499ef9eebSMatthias Springer 
2599ef9eebSMatthias Springer using namespace mlir;
2699ef9eebSMatthias Springer using namespace mlir::vector;
2799ef9eebSMatthias Springer 
2899ef9eebSMatthias Springer /// During unrolling from `originalShape` to `targetShape` return the offset for
2999ef9eebSMatthias Springer /// the slice `index`.
getVectorOffset(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,int64_t index)3099ef9eebSMatthias Springer static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
3199ef9eebSMatthias Springer                                                ArrayRef<int64_t> targetShape,
3299ef9eebSMatthias Springer                                                int64_t index) {
3399ef9eebSMatthias Springer   SmallVector<int64_t, 4> dstSliceStrides =
3499ef9eebSMatthias Springer       computeStrides(originalShape, targetShape);
3599ef9eebSMatthias Springer   SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
3699ef9eebSMatthias Springer   SmallVector<int64_t, 4> elementOffsets =
3799ef9eebSMatthias Springer       computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
3899ef9eebSMatthias Springer   return elementOffsets;
3999ef9eebSMatthias Springer }
4099ef9eebSMatthias Springer 
419f122152SChristopher Bate /// A functor that accomplishes the same thing as `getVectorOffset` but allows
429f122152SChristopher Bate /// for reordering the traversal of the dimensions. The order of traversal is
439f122152SChristopher Bate /// given in "for loop order" (outer to inner).
449f122152SChristopher Bate namespace {
459f122152SChristopher Bate class DecomposeShapeIterator {
469f122152SChristopher Bate private:
479f122152SChristopher Bate   SmallVector<int64_t, 4> vectorShape;
489f122152SChristopher Bate   SmallVector<int64_t> loopOrder;
499f122152SChristopher Bate   SmallVector<int64_t> sliceStrides;
509f122152SChristopher Bate   int64_t maxIndexVal{1};
519f122152SChristopher Bate 
529f122152SChristopher Bate public:
DecomposeShapeIterator(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,ArrayRef<int64_t> loopOrder)539f122152SChristopher Bate   DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
549f122152SChristopher Bate                          ArrayRef<int64_t> targetShape,
559f122152SChristopher Bate                          ArrayRef<int64_t> loopOrder)
569f122152SChristopher Bate       : vectorShape(targetShape.begin(), targetShape.end()),
579f122152SChristopher Bate         loopOrder(loopOrder.begin(), loopOrder.end()),
589f122152SChristopher Bate         sliceStrides(originalShape.size()) {
599f122152SChristopher Bate     assert(originalShape.size() == targetShape.size());
609f122152SChristopher Bate     assert(loopOrder.size() == targetShape.size());
619f122152SChristopher Bate 
629f122152SChristopher Bate     // Compute the count for each dimension.
639f122152SChristopher Bate     SmallVector<int64_t> sliceDimCounts(originalShape.size());
649f122152SChristopher Bate     for (unsigned r = 0; r < originalShape.size(); ++r) {
659f122152SChristopher Bate       sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
669f122152SChristopher Bate       maxIndexVal *= sliceDimCounts[r];
679f122152SChristopher Bate     }
689f122152SChristopher Bate 
699f122152SChristopher Bate     // Reversing "loop order" gives dimensions from fastest varying to slowest
709f122152SChristopher Bate     // varying (smallest stride to largest stride).
719f122152SChristopher Bate     int64_t accum = 1;
729f122152SChristopher Bate     for (auto idx : llvm::reverse(loopOrder)) {
739f122152SChristopher Bate       sliceStrides[idx] = accum;
749f122152SChristopher Bate       accum *= sliceDimCounts[idx];
759f122152SChristopher Bate     }
769f122152SChristopher Bate   }
779f122152SChristopher Bate 
789f122152SChristopher Bate   // Turn the linear index into a d-tuple based on units of vectors of size
799f122152SChristopher Bate   // `vectorShape`. The linear index is assumed to represent traversal of the
809f122152SChristopher Bate   // dimensions based on `order`.
delinearize(int64_t index) const819f122152SChristopher Bate   SmallVector<int64_t> delinearize(int64_t index) const {
829f122152SChristopher Bate     // Traverse in for loop order (largest stride to smallest stride).
839f122152SChristopher Bate     SmallVector<int64_t> vectorOffsets(sliceStrides.size());
849f122152SChristopher Bate     for (auto idx : loopOrder) {
859f122152SChristopher Bate       vectorOffsets[idx] = index / sliceStrides[idx];
869f122152SChristopher Bate       index %= sliceStrides[idx];
879f122152SChristopher Bate     }
889f122152SChristopher Bate     return vectorOffsets;
899f122152SChristopher Bate   }
909f122152SChristopher Bate 
maxIndex() const919f122152SChristopher Bate   int64_t maxIndex() const { return maxIndexVal; }
929f122152SChristopher Bate 
939f122152SChristopher Bate   /// Return the offset within d-tuple based on the ordering given by
949f122152SChristopher Bate   /// `loopOrder`.
getVectorOffset(int64_t index) const959f122152SChristopher Bate   SmallVector<int64_t> getVectorOffset(int64_t index) const {
969f122152SChristopher Bate     SmallVector<int64_t> vectorOffsets = delinearize(index);
979f122152SChristopher Bate     SmallVector<int64_t> elementOffsets =
989f122152SChristopher Bate         computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
999f122152SChristopher Bate     return elementOffsets;
1009f122152SChristopher Bate   }
1019f122152SChristopher Bate };
1029f122152SChristopher Bate } // namespace
1039f122152SChristopher Bate 
10499ef9eebSMatthias Springer /// Compute the indices of the slice `index` for a tranfer op.
sliceTransferIndices(ArrayRef<int64_t> elementOffsets,ArrayRef<Value> indices,AffineMap permutationMap,Location loc,OpBuilder & builder)1059f122152SChristopher Bate static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
1069f122152SChristopher Bate                                                ArrayRef<Value> indices,
1079f122152SChristopher Bate                                                AffineMap permutationMap,
1089f122152SChristopher Bate                                                Location loc,
10999ef9eebSMatthias Springer                                                OpBuilder &builder) {
11099ef9eebSMatthias Springer   MLIRContext *ctx = builder.getContext();
11199ef9eebSMatthias Springer   auto isBroadcast = [](AffineExpr expr) {
11299ef9eebSMatthias Springer     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
11399ef9eebSMatthias Springer       return constExpr.getValue() == 0;
11499ef9eebSMatthias Springer     return false;
11599ef9eebSMatthias Springer   };
11699ef9eebSMatthias Springer   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
11799ef9eebSMatthias Springer   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
11899ef9eebSMatthias Springer   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
11999ef9eebSMatthias Springer     if (isBroadcast(dim.value()))
12099ef9eebSMatthias Springer       continue;
12199ef9eebSMatthias Springer     unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
12299ef9eebSMatthias Springer     auto expr = getAffineDimExpr(0, builder.getContext()) +
12399ef9eebSMatthias Springer                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
12499ef9eebSMatthias Springer     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
12599ef9eebSMatthias Springer     slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
12699ef9eebSMatthias Springer   }
12799ef9eebSMatthias Springer   return slicedIndices;
12899ef9eebSMatthias Springer }
12999ef9eebSMatthias Springer 
13099ef9eebSMatthias Springer // Clones `op` into a new operations that takes `operands` and returns
13199ef9eebSMatthias Springer // `resultTypes`.
cloneOpWithOperandsAndTypes(OpBuilder & builder,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)13299ef9eebSMatthias Springer static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
13399ef9eebSMatthias Springer                                               Operation *op,
13499ef9eebSMatthias Springer                                               ArrayRef<Value> operands,
13599ef9eebSMatthias Springer                                               ArrayRef<Type> resultTypes) {
13614ecafd0SChia-hung Duan   return builder.create(loc, op->getName().getIdentifier(), operands,
13714ecafd0SChia-hung Duan                         resultTypes, op->getAttrs());
13899ef9eebSMatthias Springer }
13999ef9eebSMatthias Springer 
14099ef9eebSMatthias Springer /// Return the target shape for unrolling for the given `op`. Return llvm::None
14199ef9eebSMatthias Springer /// if the op shouldn't be or cannot be unrolled.
14299ef9eebSMatthias Springer static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions & options,Operation * op)14399ef9eebSMatthias Springer getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
14499ef9eebSMatthias Springer   if (options.filterConstraint && failed(options.filterConstraint(op)))
14599ef9eebSMatthias Springer     return llvm::None;
14699ef9eebSMatthias Springer   assert(options.nativeShape &&
14799ef9eebSMatthias Springer          "vector unrolling expects the native shape or native"
14899ef9eebSMatthias Springer          "shape call back function to be set");
14999ef9eebSMatthias Springer   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
15099ef9eebSMatthias Springer   if (!unrollableVectorOp)
15199ef9eebSMatthias Springer     return llvm::None;
15299ef9eebSMatthias Springer   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
15399ef9eebSMatthias Springer   if (!maybeUnrollShape)
15499ef9eebSMatthias Springer     return llvm::None;
15599ef9eebSMatthias Springer   Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
15699ef9eebSMatthias Springer   if (!targetShape)
15799ef9eebSMatthias Springer     return llvm::None;
15899ef9eebSMatthias Springer   auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
15999ef9eebSMatthias Springer   if (!maybeShapeRatio ||
16099ef9eebSMatthias Springer       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
16199ef9eebSMatthias Springer     return llvm::None;
16299ef9eebSMatthias Springer   return targetShape;
16399ef9eebSMatthias Springer }
16499ef9eebSMatthias Springer 
1659f122152SChristopher Bate static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops,Operation * op,const vector::UnrollVectorOptions & options)1669f122152SChristopher Bate getUnrollOrder(unsigned numLoops, Operation *op,
1679f122152SChristopher Bate                const vector::UnrollVectorOptions &options) {
1689f122152SChristopher Bate   SmallVector<int64_t> loopOrder =
1699f122152SChristopher Bate       llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
1709f122152SChristopher Bate   if (options.traversalOrderCallback != nullptr) {
1719f122152SChristopher Bate     Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
172037f0995SKazu Hirata     if (order) {
1739f122152SChristopher Bate       loopOrder = std::move(*order);
1749f122152SChristopher Bate     }
1759f122152SChristopher Bate   }
1769f122152SChristopher Bate   return loopOrder;
1779f122152SChristopher Bate }
1789f122152SChristopher Bate 
17999ef9eebSMatthias Springer namespace {
18099ef9eebSMatthias Springer 
18199ef9eebSMatthias Springer struct UnrollTransferReadPattern
18299ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
UnrollTransferReadPattern__anon77a0e1ec0411::UnrollTransferReadPattern18399ef9eebSMatthias Springer   UnrollTransferReadPattern(MLIRContext *context,
18499ef9eebSMatthias Springer                             const vector::UnrollVectorOptions &options)
18599ef9eebSMatthias Springer       : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
18699ef9eebSMatthias Springer         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferReadPattern18799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
18899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
18999ef9eebSMatthias Springer     // TODO: support 0-d corner case.
19099ef9eebSMatthias Springer     if (readOp.getTransferRank() == 0)
19199ef9eebSMatthias Springer       return failure();
1927c38fd60SJacques Pienaar     if (readOp.getMask())
19399ef9eebSMatthias Springer       return failure();
19499ef9eebSMatthias Springer     auto targetShape = getTargetShape(options, readOp);
19599ef9eebSMatthias Springer     if (!targetShape)
19699ef9eebSMatthias Springer       return failure();
19799ef9eebSMatthias Springer     auto sourceVectorType = readOp.getVectorType();
19899ef9eebSMatthias Springer     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
19999ef9eebSMatthias Springer     Location loc = readOp.getLoc();
20099ef9eebSMatthias Springer     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
2019f122152SChristopher Bate 
20299ef9eebSMatthias Springer     // Prepare the result vector;
20399ef9eebSMatthias Springer     Value result = rewriter.create<arith::ConstantOp>(
20499ef9eebSMatthias Springer         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
20599ef9eebSMatthias Springer     auto targetType =
20699ef9eebSMatthias Springer         VectorType::get(*targetShape, sourceVectorType.getElementType());
2077c38fd60SJacques Pienaar     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
2087c38fd60SJacques Pienaar                                           readOp.getIndices().end());
2099f122152SChristopher Bate 
2109f122152SChristopher Bate     SmallVector<int64_t> loopOrder =
2119f122152SChristopher Bate         getUnrollOrder(originalSize.size(), readOp, options);
2129f122152SChristopher Bate     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
2139f122152SChristopher Bate                                           loopOrder);
2149f122152SChristopher Bate     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
2159f122152SChristopher Bate       SmallVector<int64_t, 4> elementOffsets =
2169f122152SChristopher Bate           indexToOffsets.getVectorOffset(i);
21799ef9eebSMatthias Springer       SmallVector<Value, 4> indices =
2189f122152SChristopher Bate           sliceTransferIndices(elementOffsets, originalIndices,
2197c38fd60SJacques Pienaar                                readOp.getPermutationMap(), loc, rewriter);
22099ef9eebSMatthias Springer       auto slicedRead = rewriter.create<vector::TransferReadOp>(
2217c38fd60SJacques Pienaar           loc, targetType, readOp.getSource(), indices,
2227c38fd60SJacques Pienaar           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
2237c38fd60SJacques Pienaar           readOp.getInBoundsAttr());
22499ef9eebSMatthias Springer 
22599ef9eebSMatthias Springer       result = rewriter.create<vector::InsertStridedSliceOp>(
22699ef9eebSMatthias Springer           loc, slicedRead, result, elementOffsets, strides);
22799ef9eebSMatthias Springer     }
22899ef9eebSMatthias Springer     rewriter.replaceOp(readOp, result);
22999ef9eebSMatthias Springer     return success();
23099ef9eebSMatthias Springer   }
23199ef9eebSMatthias Springer 
23299ef9eebSMatthias Springer private:
23399ef9eebSMatthias Springer   vector::UnrollVectorOptions options;
23499ef9eebSMatthias Springer };
23599ef9eebSMatthias Springer 
23699ef9eebSMatthias Springer struct UnrollTransferWritePattern
23799ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
UnrollTransferWritePattern__anon77a0e1ec0411::UnrollTransferWritePattern23899ef9eebSMatthias Springer   UnrollTransferWritePattern(MLIRContext *context,
23999ef9eebSMatthias Springer                              const vector::UnrollVectorOptions &options)
24099ef9eebSMatthias Springer       : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
24199ef9eebSMatthias Springer         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferWritePattern24299ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
24399ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
24499ef9eebSMatthias Springer     // TODO: support 0-d corner case.
24599ef9eebSMatthias Springer     if (writeOp.getTransferRank() == 0)
24699ef9eebSMatthias Springer       return failure();
24799ef9eebSMatthias Springer 
2487c38fd60SJacques Pienaar     if (writeOp.getMask())
24999ef9eebSMatthias Springer       return failure();
25099ef9eebSMatthias Springer     auto targetShape = getTargetShape(options, writeOp);
25199ef9eebSMatthias Springer     if (!targetShape)
25299ef9eebSMatthias Springer       return failure();
25399ef9eebSMatthias Springer     auto sourceVectorType = writeOp.getVectorType();
25499ef9eebSMatthias Springer     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
25599ef9eebSMatthias Springer     Location loc = writeOp.getLoc();
25699ef9eebSMatthias Springer     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
2577c38fd60SJacques Pienaar     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
2587c38fd60SJacques Pienaar                                           writeOp.getIndices().end());
2599f122152SChristopher Bate 
2609f122152SChristopher Bate     SmallVector<int64_t> loopOrder =
2619f122152SChristopher Bate         getUnrollOrder(originalSize.size(), writeOp, options);
2629f122152SChristopher Bate     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
2639f122152SChristopher Bate                                           loopOrder);
26499ef9eebSMatthias Springer     Value resultTensor;
2659f122152SChristopher Bate     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
26699ef9eebSMatthias Springer       SmallVector<int64_t, 4> elementOffsets =
2679f122152SChristopher Bate           indexToOffsets.getVectorOffset(i);
26899ef9eebSMatthias Springer       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
2697c38fd60SJacques Pienaar           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
27099ef9eebSMatthias Springer       SmallVector<Value, 4> indices =
2719f122152SChristopher Bate           sliceTransferIndices(elementOffsets, originalIndices,
2727c38fd60SJacques Pienaar                                writeOp.getPermutationMap(), loc, rewriter);
27399ef9eebSMatthias Springer       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
2747c38fd60SJacques Pienaar           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
2757c38fd60SJacques Pienaar           indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
27699ef9eebSMatthias Springer       // For the tensor case update the destination for the next transfer write.
27799ef9eebSMatthias Springer       if (!slicedWrite->getResults().empty())
27899ef9eebSMatthias Springer         resultTensor = slicedWrite->getResult(0);
27999ef9eebSMatthias Springer     }
28099ef9eebSMatthias Springer     if (resultTensor)
28199ef9eebSMatthias Springer       rewriter.replaceOp(writeOp, resultTensor);
28299ef9eebSMatthias Springer     else
28399ef9eebSMatthias Springer       rewriter.eraseOp(writeOp);
28499ef9eebSMatthias Springer     return success();
28599ef9eebSMatthias Springer   }
28699ef9eebSMatthias Springer 
28799ef9eebSMatthias Springer private:
28899ef9eebSMatthias Springer   vector::UnrollVectorOptions options;
28999ef9eebSMatthias Springer };
29099ef9eebSMatthias Springer 
29199ef9eebSMatthias Springer struct OffsetMapInfo {
getEmptyKey__anon77a0e1ec0411::OffsetMapInfo29299ef9eebSMatthias Springer   static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
29399ef9eebSMatthias Springer 
getTombstoneKey__anon77a0e1ec0411::OffsetMapInfo29499ef9eebSMatthias Springer   static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
29599ef9eebSMatthias Springer 
getHashValue__anon77a0e1ec0411::OffsetMapInfo29699ef9eebSMatthias Springer   static unsigned getHashValue(const SmallVector<int64_t> &v) {
297f69175b1SThomas Raoux     return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
29899ef9eebSMatthias Springer   }
29999ef9eebSMatthias Springer 
isEqual__anon77a0e1ec0411::OffsetMapInfo30099ef9eebSMatthias Springer   static bool isEqual(const SmallVector<int64_t> &lhs,
30199ef9eebSMatthias Springer                       const SmallVector<int64_t> &rhs) {
30299ef9eebSMatthias Springer     return lhs == rhs;
30399ef9eebSMatthias Springer   }
30499ef9eebSMatthias Springer };
305f69175b1SThomas Raoux 
306f69175b1SThomas Raoux struct UnrollContractionPattern
307f69175b1SThomas Raoux     : public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern__anon77a0e1ec0411::UnrollContractionPattern30899ef9eebSMatthias Springer   UnrollContractionPattern(MLIRContext *context,
30999ef9eebSMatthias Springer                            const vector::UnrollVectorOptions &options)
31099ef9eebSMatthias Springer       : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
31199ef9eebSMatthias Springer         options(options) {}
31299ef9eebSMatthias Springer 
matchAndRewrite__anon77a0e1ec0411::UnrollContractionPattern31399ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
31499ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
31599ef9eebSMatthias Springer     auto targetShape = getTargetShape(options, contractOp);
31699ef9eebSMatthias Springer     if (!targetShape)
31799ef9eebSMatthias Springer       return failure();
31899ef9eebSMatthias Springer     auto dstVecType = contractOp.getResultType().cast<VectorType>();
31999ef9eebSMatthias Springer     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
32099ef9eebSMatthias Springer 
32199ef9eebSMatthias Springer     Location loc = contractOp.getLoc();
32299ef9eebSMatthias Springer     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323*d2c0572bSJacques Pienaar     AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
32499ef9eebSMatthias Springer     llvm::MapVector<
32599ef9eebSMatthias Springer         SmallVector<int64_t>, Value,
32699ef9eebSMatthias Springer         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
32799ef9eebSMatthias Springer         accCache;
3289f122152SChristopher Bate 
3299f122152SChristopher Bate     SmallVector<int64_t> loopOrder = getUnrollOrder(
3309f122152SChristopher Bate         contractOp.getIteratorTypes().size(), contractOp, options);
3319f122152SChristopher Bate     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
3329f122152SChristopher Bate                                           loopOrder);
3339f122152SChristopher Bate     const int64_t sliceCount = indexToOffsets.maxIndex();
33499ef9eebSMatthias Springer     for (int64_t i = 0; i < sliceCount; i++) {
3359f122152SChristopher Bate       SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
33699ef9eebSMatthias Springer       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
33799ef9eebSMatthias Springer 
33899ef9eebSMatthias Springer       // Helper to coompute the new shape of each operand and extract the slice.
33999ef9eebSMatthias Springer       auto extractOperand = [&](unsigned index, Value operand,
34099ef9eebSMatthias Springer                                 AffineMap permutationMap,
34199ef9eebSMatthias Springer                                 ArrayRef<int64_t> operandOffets) {
34299ef9eebSMatthias Springer         SmallVector<int64_t> operandShape = applyPermutationMap(
34399ef9eebSMatthias Springer             permutationMap, ArrayRef<int64_t>(*targetShape));
34499ef9eebSMatthias Springer         SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
34599ef9eebSMatthias Springer         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
34699ef9eebSMatthias Springer             loc, operand, operandOffets, operandShape, operandStrides);
34799ef9eebSMatthias Springer       };
34899ef9eebSMatthias Springer 
34999ef9eebSMatthias Springer       // Extract the new lhs operand.
350*d2c0572bSJacques Pienaar       AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
35199ef9eebSMatthias Springer       SmallVector<int64_t> lhsOffets =
35299ef9eebSMatthias Springer           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
3537c38fd60SJacques Pienaar       extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
35499ef9eebSMatthias Springer       // If there is a mask associated to lhs, extract it as well.
35599ef9eebSMatthias Springer       if (slicesOperands.size() > 3)
3567c38fd60SJacques Pienaar         extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
3577c38fd60SJacques Pienaar                        lhsOffets);
35899ef9eebSMatthias Springer 
35999ef9eebSMatthias Springer       // Extract the new rhs operand.
360*d2c0572bSJacques Pienaar       AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
36199ef9eebSMatthias Springer       SmallVector<int64_t> rhsOffets =
36299ef9eebSMatthias Springer           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
3637c38fd60SJacques Pienaar       extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
36499ef9eebSMatthias Springer       // If there is a mask associated to rhs, extract it as well.
36599ef9eebSMatthias Springer       if (slicesOperands.size() > 4)
3667c38fd60SJacques Pienaar         extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
3677c38fd60SJacques Pienaar                        rhsOffets);
36899ef9eebSMatthias Springer 
369*d2c0572bSJacques Pienaar       AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
37099ef9eebSMatthias Springer       SmallVector<int64_t> accOffets =
37199ef9eebSMatthias Springer           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
37299ef9eebSMatthias Springer       // If a version of the accumulator has already been computed, use it
37399ef9eebSMatthias Springer       // otherwise extract the first version from the original operand.
37499ef9eebSMatthias Springer       auto accIt = accCache.find(accOffets);
37599ef9eebSMatthias Springer       if (accIt != accCache.end())
37699ef9eebSMatthias Springer         slicesOperands[2] = accIt->second;
37799ef9eebSMatthias Springer       else
3787c38fd60SJacques Pienaar         extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
37999ef9eebSMatthias Springer 
38099ef9eebSMatthias Springer       SmallVector<int64_t> dstShape =
38199ef9eebSMatthias Springer           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
38299ef9eebSMatthias Springer       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
38399ef9eebSMatthias Springer       Operation *newOp = cloneOpWithOperandsAndTypes(
38499ef9eebSMatthias Springer           rewriter, loc, contractOp, slicesOperands, targetType);
38599ef9eebSMatthias Springer 
38699ef9eebSMatthias Springer       SmallVector<int64_t> dstOffets =
38799ef9eebSMatthias Springer           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
38899ef9eebSMatthias Springer       // Save the accumulated value untill all the loops are unrolled since
38999ef9eebSMatthias Springer       // reduction loop keep updating the accumulator.
39099ef9eebSMatthias Springer       accCache[dstOffets] = newOp->getResult(0);
39199ef9eebSMatthias Springer     }
39299ef9eebSMatthias Springer     // Assemble back the accumulator into a single vector.
39399ef9eebSMatthias Springer     Value result = rewriter.create<arith::ConstantOp>(
39499ef9eebSMatthias Springer         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
39599ef9eebSMatthias Springer     for (const auto &it : accCache) {
39699ef9eebSMatthias Springer       SmallVector<int64_t> dstStrides(it.first.size(), 1);
39799ef9eebSMatthias Springer       result = rewriter.create<vector::InsertStridedSliceOp>(
39899ef9eebSMatthias Springer           loc, it.second, result, it.first, dstStrides);
39999ef9eebSMatthias Springer     }
40099ef9eebSMatthias Springer     rewriter.replaceOp(contractOp, result);
40199ef9eebSMatthias Springer     return success();
40299ef9eebSMatthias Springer   }
40399ef9eebSMatthias Springer 
40499ef9eebSMatthias Springer private:
40599ef9eebSMatthias Springer   vector::UnrollVectorOptions options;
40699ef9eebSMatthias Springer };
40799ef9eebSMatthias Springer 
408f69175b1SThomas Raoux struct UnrollMultiReductionPattern
409f69175b1SThomas Raoux     : public OpRewritePattern<vector::MultiDimReductionOp> {
UnrollMultiReductionPattern__anon77a0e1ec0411::UnrollMultiReductionPattern410f69175b1SThomas Raoux   UnrollMultiReductionPattern(MLIRContext *context,
411f69175b1SThomas Raoux                               const vector::UnrollVectorOptions &options)
412f69175b1SThomas Raoux       : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
413f69175b1SThomas Raoux         options(options) {}
414f69175b1SThomas Raoux 
matchAndRewrite__anon77a0e1ec0411::UnrollMultiReductionPattern415f69175b1SThomas Raoux   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
416f69175b1SThomas Raoux                                 PatternRewriter &rewriter) const override {
417f69175b1SThomas Raoux     Optional<SmallVector<int64_t, 4>> targetShape =
418f69175b1SThomas Raoux         getTargetShape(options, reductionOp);
419f69175b1SThomas Raoux     if (!targetShape)
420f69175b1SThomas Raoux       return failure();
421f69175b1SThomas Raoux     SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
422f69175b1SThomas Raoux     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
423f69175b1SThomas Raoux     llvm::MapVector<
424f69175b1SThomas Raoux         SmallVector<int64_t>, Value,
425f69175b1SThomas Raoux         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
426f69175b1SThomas Raoux         accCache;
427f69175b1SThomas Raoux     // Compute shape ratio of 'shape' and 'sizes'.
428f69175b1SThomas Raoux     int64_t sliceCount = computeMaxLinearIndex(ratio);
429f69175b1SThomas Raoux     Location loc = reductionOp.getLoc();
430f69175b1SThomas Raoux     for (int64_t i = 0; i < sliceCount; i++) {
431f69175b1SThomas Raoux       SmallVector<int64_t, 4> offsets =
432f69175b1SThomas Raoux           getVectorOffset(originalSize, *targetShape, i);
433f69175b1SThomas Raoux 
434051b36baSThomas Raoux       SmallVector<Value> operands;
435f69175b1SThomas Raoux       SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
436f69175b1SThomas Raoux       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
437051b36baSThomas Raoux           loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
438051b36baSThomas Raoux       operands.push_back(slicedOperand);
439f69175b1SThomas Raoux       SmallVector<int64_t> dstShape;
440f69175b1SThomas Raoux       SmallVector<int64_t> destOffset;
441f69175b1SThomas Raoux       for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
442f69175b1SThomas Raoux         if (!reductionOp.isReducedDim(i)) {
443f69175b1SThomas Raoux           destOffset.push_back(offsets[i]);
444f69175b1SThomas Raoux           dstShape.push_back((*targetShape)[i]);
445f69175b1SThomas Raoux         }
446f69175b1SThomas Raoux       }
447051b36baSThomas Raoux       Value acc;
448051b36baSThomas Raoux       SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
449051b36baSThomas Raoux       // If a version of the accumulator has already been computed, use it
450051b36baSThomas Raoux       // otherwise extract the first version from the original operand.
451051b36baSThomas Raoux       auto accIt = accCache.find(destOffset);
452051b36baSThomas Raoux       if (accIt != accCache.end())
453051b36baSThomas Raoux         acc = accIt->second;
454051b36baSThomas Raoux       else
455051b36baSThomas Raoux         acc = rewriter.create<vector::ExtractStridedSliceOp>(
456051b36baSThomas Raoux             loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
457051b36baSThomas Raoux       operands.push_back(acc);
458f69175b1SThomas Raoux       auto targetType = VectorType::get(
459f69175b1SThomas Raoux           dstShape, reductionOp.getSourceVectorType().getElementType());
460f69175b1SThomas Raoux       Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
461051b36baSThomas Raoux                                                      operands, targetType);
462f69175b1SThomas Raoux       Value result = newOp->getResult(0);
463f69175b1SThomas Raoux       accCache[destOffset] = result;
464f69175b1SThomas Raoux     }
465f69175b1SThomas Raoux     // Assemble back the accumulator into a single vector.
466f69175b1SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
467f69175b1SThomas Raoux         loc, reductionOp.getDestType(),
468f69175b1SThomas Raoux         rewriter.getZeroAttr(reductionOp.getDestType()));
469f69175b1SThomas Raoux     for (const auto &it : accCache) {
470f69175b1SThomas Raoux       SmallVector<int64_t> dstStrides(it.first.size(), 1);
471f69175b1SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
472f69175b1SThomas Raoux           loc, it.second, result, it.first, dstStrides);
473f69175b1SThomas Raoux     }
474f69175b1SThomas Raoux     rewriter.replaceOp(reductionOp, result);
475f69175b1SThomas Raoux     return success();
476f69175b1SThomas Raoux   }
477f69175b1SThomas Raoux 
478f69175b1SThomas Raoux private:
479f69175b1SThomas Raoux   vector::UnrollVectorOptions options;
480f69175b1SThomas Raoux };
481f69175b1SThomas Raoux 
48299ef9eebSMatthias Springer struct UnrollElementwisePattern : public RewritePattern {
UnrollElementwisePattern__anon77a0e1ec0411::UnrollElementwisePattern48399ef9eebSMatthias Springer   UnrollElementwisePattern(MLIRContext *context,
48499ef9eebSMatthias Springer                            const vector::UnrollVectorOptions &options)
48599ef9eebSMatthias Springer       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
48699ef9eebSMatthias Springer         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollElementwisePattern48799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(Operation *op,
48899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
48999ef9eebSMatthias Springer     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
49099ef9eebSMatthias Springer       return failure();
49199ef9eebSMatthias Springer     auto targetShape = getTargetShape(options, op);
49299ef9eebSMatthias Springer     if (!targetShape)
49399ef9eebSMatthias Springer       return failure();
49499ef9eebSMatthias Springer     auto dstVecType = op->getResult(0).getType().cast<VectorType>();
49599ef9eebSMatthias Springer     SmallVector<int64_t, 4> originalSize =
49699ef9eebSMatthias Springer         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
49799ef9eebSMatthias Springer     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
49899ef9eebSMatthias Springer     int64_t sliceCount = computeMaxLinearIndex(ratio);
49999ef9eebSMatthias Springer     Location loc = op->getLoc();
50099ef9eebSMatthias Springer     // Prepare the result vector.
50199ef9eebSMatthias Springer     Value result = rewriter.create<arith::ConstantOp>(
50299ef9eebSMatthias Springer         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
50399ef9eebSMatthias Springer     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
50499ef9eebSMatthias Springer     VectorType newVecType =
50599ef9eebSMatthias Springer         VectorType::get(*targetShape, dstVecType.getElementType());
50699ef9eebSMatthias Springer     for (int64_t i = 0; i < sliceCount; i++) {
50799ef9eebSMatthias Springer       SmallVector<int64_t, 4> offsets =
50899ef9eebSMatthias Springer           getVectorOffset(originalSize, *targetShape, i);
50999ef9eebSMatthias Springer       SmallVector<Value, 4> extractOperands;
51099ef9eebSMatthias Springer       for (OpOperand &operand : op->getOpOperands()) {
51199ef9eebSMatthias Springer         auto vecType = operand.get().getType().template dyn_cast<VectorType>();
51299ef9eebSMatthias Springer         if (!vecType) {
51399ef9eebSMatthias Springer           extractOperands.push_back(operand.get());
51499ef9eebSMatthias Springer           continue;
51599ef9eebSMatthias Springer         }
51699ef9eebSMatthias Springer         extractOperands.push_back(
51799ef9eebSMatthias Springer             rewriter.create<vector::ExtractStridedSliceOp>(
51899ef9eebSMatthias Springer                 loc, operand.get(), offsets, *targetShape, strides));
51999ef9eebSMatthias Springer       }
52099ef9eebSMatthias Springer       Operation *newOp = cloneOpWithOperandsAndTypes(
52199ef9eebSMatthias Springer           rewriter, loc, op, extractOperands, newVecType);
52299ef9eebSMatthias Springer       result = rewriter.create<vector::InsertStridedSliceOp>(
52399ef9eebSMatthias Springer           loc, newOp->getResult(0), result, offsets, strides);
52499ef9eebSMatthias Springer     }
52599ef9eebSMatthias Springer     rewriter.replaceOp(op, result);
52699ef9eebSMatthias Springer     return success();
52799ef9eebSMatthias Springer   }
52899ef9eebSMatthias Springer 
52999ef9eebSMatthias Springer private:
53099ef9eebSMatthias Springer   vector::UnrollVectorOptions options;
53199ef9eebSMatthias Springer };
53299ef9eebSMatthias Springer 
53399ef9eebSMatthias Springer /// Canonicalize an extract_map using the result of a pointwise operation.
53499ef9eebSMatthias Springer /// Transforms:
53599ef9eebSMatthias Springer /// %v = arith.addf %a, %b : vector32xf32>
53699ef9eebSMatthias Springer /// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
53799ef9eebSMatthias Springer /// to:
53899ef9eebSMatthias Springer /// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
53999ef9eebSMatthias Springer /// %db = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
54099ef9eebSMatthias Springer /// %dv = arith.addf %da, %db : vector<1xf32>
54199ef9eebSMatthias Springer struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
54299ef9eebSMatthias Springer   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::PointwiseExtractPattern54399ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
54499ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
5457c38fd60SJacques Pienaar     Operation *definedOp = extract.getVector().getDefiningOp();
54699ef9eebSMatthias Springer     if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
54799ef9eebSMatthias Springer         definedOp->getNumResults() != 1)
54899ef9eebSMatthias Springer       return failure();
54999ef9eebSMatthias Springer     Location loc = extract.getLoc();
55099ef9eebSMatthias Springer     SmallVector<Value, 4> extractOperands;
55199ef9eebSMatthias Springer     for (OpOperand &operand : definedOp->getOpOperands()) {
55299ef9eebSMatthias Springer       auto vecType = operand.get().getType().template dyn_cast<VectorType>();
55399ef9eebSMatthias Springer       if (!vecType) {
55499ef9eebSMatthias Springer         extractOperands.push_back(operand.get());
55599ef9eebSMatthias Springer         continue;
55699ef9eebSMatthias Springer       }
55799ef9eebSMatthias Springer       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
55899ef9eebSMatthias Springer           loc,
55999ef9eebSMatthias Springer           VectorType::get(extract.getResultType().getShape(),
56099ef9eebSMatthias Springer                           vecType.getElementType()),
5617c38fd60SJacques Pienaar           operand.get(), extract.getIds()));
56299ef9eebSMatthias Springer     }
56399ef9eebSMatthias Springer     Operation *newOp = cloneOpWithOperandsAndTypes(
56499ef9eebSMatthias Springer         rewriter, loc, definedOp, extractOperands, extract.getResultType());
56599ef9eebSMatthias Springer     rewriter.replaceOp(extract, newOp->getResult(0));
56699ef9eebSMatthias Springer     return success();
56799ef9eebSMatthias Springer   }
56899ef9eebSMatthias Springer };
56999ef9eebSMatthias Springer 
57099ef9eebSMatthias Springer /// Canonicalize an extract_map using the result of a contract operation.
57199ef9eebSMatthias Springer /// This propagate the extract_map to operands.
57299ef9eebSMatthias Springer struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
57399ef9eebSMatthias Springer   using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::ContractExtractPattern57499ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
57599ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
5767c38fd60SJacques Pienaar     Operation *definedOp = extract.getVector().getDefiningOp();
57799ef9eebSMatthias Springer     auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
57899ef9eebSMatthias Springer     if (!contract)
57999ef9eebSMatthias Springer       return failure();
58099ef9eebSMatthias Springer     Location loc = contract.getLoc();
58199ef9eebSMatthias Springer     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
582*d2c0572bSJacques Pienaar     AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
58399ef9eebSMatthias Springer     // Create a map of the dimensions distributed based on the acc affine map.
58499ef9eebSMatthias Springer     // Only parallel dimensions are being distributed, reduction dimensions are
58599ef9eebSMatthias Springer     // untouched.
58699ef9eebSMatthias Springer     DenseMap<int64_t, int64_t> map;
58799ef9eebSMatthias Springer     for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
58899ef9eebSMatthias Springer       map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
58999ef9eebSMatthias Springer     SmallVector<Value, 4> extractOperands;
590*d2c0572bSJacques Pienaar     for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
59199ef9eebSMatthias Springer       // For each operands calculate the new vector type after distribution.
59299ef9eebSMatthias Springer       Value operand = contract->getOperand(it.index());
59399ef9eebSMatthias Springer       auto vecType = operand.getType().cast<VectorType>();
59499ef9eebSMatthias Springer       SmallVector<int64_t> operandShape(vecType.getShape().begin(),
59599ef9eebSMatthias Springer                                         vecType.getShape().end());
59699ef9eebSMatthias Springer       for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
59799ef9eebSMatthias Springer         unsigned dim = it.value().getDimPosition(i);
59899ef9eebSMatthias Springer         auto distributedDim = map.find(dim);
59999ef9eebSMatthias Springer         // If the dimension is not in the map it means it is a reduction and
60099ef9eebSMatthias Springer         // doesn't get distributed.
60199ef9eebSMatthias Springer         if (distributedDim == map.end())
60299ef9eebSMatthias Springer           continue;
60399ef9eebSMatthias Springer         operandShape[i] = distributedDim->second;
60499ef9eebSMatthias Springer       }
60599ef9eebSMatthias Springer       VectorType newVecType =
60699ef9eebSMatthias Springer           VectorType::get(operandShape, vecType.getElementType());
60799ef9eebSMatthias Springer       extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
6087c38fd60SJacques Pienaar           loc, newVecType, operand, extract.getIds()));
60999ef9eebSMatthias Springer     }
61099ef9eebSMatthias Springer     Operation *newOp =
61199ef9eebSMatthias Springer         cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
61299ef9eebSMatthias Springer                                     extract.getResult().getType());
61399ef9eebSMatthias Springer     rewriter.replaceOp(extract, newOp->getResult(0));
61499ef9eebSMatthias Springer     return success();
61599ef9eebSMatthias Springer   }
61699ef9eebSMatthias Springer };
61799ef9eebSMatthias Springer 
61899ef9eebSMatthias Springer /// Converts TransferRead op used by ExtractMap op into a smaller dimension
61999ef9eebSMatthias Springer /// TransferRead.
62099ef9eebSMatthias Springer /// Example:
62199ef9eebSMatthias Springer /// ```
62299ef9eebSMatthias Springer /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
62399ef9eebSMatthias Springer ///   memref<64x64x64xf32>, vector<64x4x32xf32>
62499ef9eebSMatthias Springer /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
62599ef9eebSMatthias Springer /// ```
62699ef9eebSMatthias Springer /// to:
62799ef9eebSMatthias Springer /// ```
62899ef9eebSMatthias Springer /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
62999ef9eebSMatthias Springer /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
63099ef9eebSMatthias Springer ///   memref<64x64x64xf32>, vector<2x4x1xf32>
63199ef9eebSMatthias Springer /// ```
63299ef9eebSMatthias Springer struct TransferReadExtractPattern
63399ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPattern__anon77a0e1ec0411::TransferReadExtractPattern63499ef9eebSMatthias Springer   TransferReadExtractPattern(MLIRContext *context)
63599ef9eebSMatthias Springer       : OpRewritePattern<vector::TransferReadOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferReadExtractPattern63699ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferReadOp read,
63799ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
63899ef9eebSMatthias Springer     // TODO: support 0-d corner case.
63999ef9eebSMatthias Springer     if (read.getTransferRank() == 0)
64099ef9eebSMatthias Springer       return failure();
64199ef9eebSMatthias Springer 
64299ef9eebSMatthias Springer     if (!read.getResult().hasOneUse())
64399ef9eebSMatthias Springer       return failure();
64499ef9eebSMatthias Springer     auto extract =
64599ef9eebSMatthias Springer         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
64699ef9eebSMatthias Springer     if (!extract)
64799ef9eebSMatthias Springer       return failure();
6487c38fd60SJacques Pienaar     if (read.getMask())
64999ef9eebSMatthias Springer       return failure();
65099ef9eebSMatthias Springer 
6517c38fd60SJacques Pienaar     SmallVector<Value, 4> indices(read.getIndices().begin(),
6527c38fd60SJacques Pienaar                                   read.getIndices().end());
6537c38fd60SJacques Pienaar     AffineMap indexMap = extract.map().compose(read.getPermutationMap());
65499ef9eebSMatthias Springer     unsigned idCount = 0;
65599ef9eebSMatthias Springer     ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
65699ef9eebSMatthias Springer     for (auto it :
65799ef9eebSMatthias Springer          llvm::zip(indexMap.getResults(), extract.map().getResults())) {
65899ef9eebSMatthias Springer       AffineExpr d0, d1;
65999ef9eebSMatthias Springer       bindDims(read.getContext(), d0, d1);
66099ef9eebSMatthias Springer       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
66199ef9eebSMatthias Springer       if (!indexExpr)
66299ef9eebSMatthias Springer         continue;
66399ef9eebSMatthias Springer       unsigned indexPos = indexExpr.getPosition();
66499ef9eebSMatthias Springer       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
66599ef9eebSMatthias Springer       auto scale = getAffineConstantExpr(
66699ef9eebSMatthias Springer           extract.getResultType().getDimSize(vectorPos), read.getContext());
66799ef9eebSMatthias Springer       indices[indexPos] = makeComposedAffineApply(
66899ef9eebSMatthias Springer           rewriter, read.getLoc(), d0 + scale * d1,
6697c38fd60SJacques Pienaar           {indices[indexPos], extract.getIds()[idCount++]});
67099ef9eebSMatthias Springer     }
67199ef9eebSMatthias Springer     Value newRead = lb.create<vector::TransferReadOp>(
6727c38fd60SJacques Pienaar         extract.getType(), read.getSource(), indices,
6737c38fd60SJacques Pienaar         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
6747c38fd60SJacques Pienaar         read.getInBoundsAttr());
67599ef9eebSMatthias Springer     Value dest = lb.create<arith::ConstantOp>(
67699ef9eebSMatthias Springer         read.getType(), rewriter.getZeroAttr(read.getType()));
6777c38fd60SJacques Pienaar     newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
67899ef9eebSMatthias Springer     rewriter.replaceOp(read, newRead);
67999ef9eebSMatthias Springer     return success();
68099ef9eebSMatthias Springer   }
68199ef9eebSMatthias Springer };
68299ef9eebSMatthias Springer 
68399ef9eebSMatthias Springer struct TransferWriteInsertPattern
68499ef9eebSMatthias Springer     : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteInsertPattern__anon77a0e1ec0411::TransferWriteInsertPattern68599ef9eebSMatthias Springer   TransferWriteInsertPattern(MLIRContext *context)
68699ef9eebSMatthias Springer       : OpRewritePattern<vector::TransferWriteOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferWriteInsertPattern68799ef9eebSMatthias Springer   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
68899ef9eebSMatthias Springer                                 PatternRewriter &rewriter) const override {
68999ef9eebSMatthias Springer     // TODO: support 0-d corner case.
69099ef9eebSMatthias Springer     if (write.getTransferRank() == 0)
69199ef9eebSMatthias Springer       return failure();
69299ef9eebSMatthias Springer 
6937c38fd60SJacques Pienaar     auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
69499ef9eebSMatthias Springer     if (!insert)
69599ef9eebSMatthias Springer       return failure();
6967c38fd60SJacques Pienaar     if (write.getMask())
69799ef9eebSMatthias Springer       return failure();
6987c38fd60SJacques Pienaar     SmallVector<Value, 4> indices(write.getIndices().begin(),
6997c38fd60SJacques Pienaar                                   write.getIndices().end());
7007c38fd60SJacques Pienaar     AffineMap indexMap = insert.map().compose(write.getPermutationMap());
70199ef9eebSMatthias Springer     unsigned idCount = 0;
70299ef9eebSMatthias Springer     Location loc = write.getLoc();
70399ef9eebSMatthias Springer     for (auto it :
70499ef9eebSMatthias Springer          llvm::zip(indexMap.getResults(), insert.map().getResults())) {
70599ef9eebSMatthias Springer       AffineExpr d0, d1;
70699ef9eebSMatthias Springer       bindDims(write.getContext(), d0, d1);
70799ef9eebSMatthias Springer       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
70899ef9eebSMatthias Springer       if (!indexExpr)
70999ef9eebSMatthias Springer         continue;
71099ef9eebSMatthias Springer       unsigned indexPos = indexExpr.getPosition();
71199ef9eebSMatthias Springer       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
71299ef9eebSMatthias Springer       auto scale = getAffineConstantExpr(
71399ef9eebSMatthias Springer           insert.getSourceVectorType().getDimSize(vectorPos),
71499ef9eebSMatthias Springer           write.getContext());
7157c38fd60SJacques Pienaar       indices[indexPos] = makeComposedAffineApply(
7167c38fd60SJacques Pienaar           rewriter, loc, d0 + scale * d1,
7177c38fd60SJacques Pienaar           {indices[indexPos], insert.getIds()[idCount++]});
71899ef9eebSMatthias Springer     }
71999ef9eebSMatthias Springer     rewriter.create<vector::TransferWriteOp>(
7207c38fd60SJacques Pienaar         loc, insert.getVector(), write.getSource(), indices,
7217c38fd60SJacques Pienaar         write.getPermutationMapAttr(), write.getInBoundsAttr());
72299ef9eebSMatthias Springer     rewriter.eraseOp(write);
72399ef9eebSMatthias Springer     return success();
72499ef9eebSMatthias Springer   }
72599ef9eebSMatthias Springer };
72699ef9eebSMatthias Springer 
727de5022c7SMatthias Springer struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
UnrollReductionPattern__anon77a0e1ec0411::UnrollReductionPattern728de5022c7SMatthias Springer   UnrollReductionPattern(MLIRContext *context,
729de5022c7SMatthias Springer                          const vector::UnrollVectorOptions &options)
730de5022c7SMatthias Springer       : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
731de5022c7SMatthias Springer         options(options) {}
732de5022c7SMatthias Springer 
matchAndRewrite__anon77a0e1ec0411::UnrollReductionPattern733de5022c7SMatthias Springer   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
734de5022c7SMatthias Springer                                 PatternRewriter &rewriter) const override {
735de5022c7SMatthias Springer     Optional<SmallVector<int64_t, 4>> targetShape =
736de5022c7SMatthias Springer         getTargetShape(options, reductionOp);
737de5022c7SMatthias Springer     if (!targetShape)
738de5022c7SMatthias Springer       return failure();
739de5022c7SMatthias Springer     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
740de5022c7SMatthias Springer     int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
741de5022c7SMatthias Springer 
742de5022c7SMatthias Springer     // Create unrolled vector reduction.
743de5022c7SMatthias Springer     Location loc = reductionOp.getLoc();
744de5022c7SMatthias Springer     Value accumulator = nullptr;
745de5022c7SMatthias Springer     for (int64_t i = 0; i < ratio; ++i) {
746de5022c7SMatthias Springer       SmallVector<int64_t> offsets =
747de5022c7SMatthias Springer           getVectorOffset(originalSize, *targetShape, i);
748de5022c7SMatthias Springer       SmallVector<int64_t> strides(offsets.size(), 1);
749de5022c7SMatthias Springer       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
7507c38fd60SJacques Pienaar           loc, reductionOp.getVector(), offsets, *targetShape, strides);
751de5022c7SMatthias Springer       Operation *newOp = cloneOpWithOperandsAndTypes(
752de5022c7SMatthias Springer           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
753de5022c7SMatthias Springer       Value result = newOp->getResult(0);
754de5022c7SMatthias Springer 
755de5022c7SMatthias Springer       if (!accumulator) {
756de5022c7SMatthias Springer         // This is the first reduction.
757de5022c7SMatthias Springer         accumulator = result;
758de5022c7SMatthias Springer       } else {
759de5022c7SMatthias Springer         // On subsequent reduction, combine with the accumulator.
7607c38fd60SJacques Pienaar         accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
761de5022c7SMatthias Springer                                          accumulator, result);
762de5022c7SMatthias Springer       }
763de5022c7SMatthias Springer     }
764de5022c7SMatthias Springer 
765de5022c7SMatthias Springer     rewriter.replaceOp(reductionOp, accumulator);
766de5022c7SMatthias Springer     return success();
767de5022c7SMatthias Springer   }
768de5022c7SMatthias Springer 
769de5022c7SMatthias Springer private:
770de5022c7SMatthias Springer   const vector::UnrollVectorOptions options;
771de5022c7SMatthias Springer };
772de5022c7SMatthias Springer 
7735b1b7108SThomas Raoux struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern__anon77a0e1ec0411::UnrollTranposePattern7745b1b7108SThomas Raoux   UnrollTranposePattern(MLIRContext *context,
7755b1b7108SThomas Raoux                         const vector::UnrollVectorOptions &options)
7765b1b7108SThomas Raoux       : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
7775b1b7108SThomas Raoux         options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTranposePattern7785b1b7108SThomas Raoux   LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
7795b1b7108SThomas Raoux                                 PatternRewriter &rewriter) const override {
7805b1b7108SThomas Raoux     if (tranposeOp.getResultType().getRank() == 0)
7815b1b7108SThomas Raoux       return failure();
7825b1b7108SThomas Raoux     auto targetShape = getTargetShape(options, tranposeOp);
7835b1b7108SThomas Raoux     if (!targetShape)
7845b1b7108SThomas Raoux       return failure();
7855b1b7108SThomas Raoux     auto originalVectorType = tranposeOp.getResultType();
7865b1b7108SThomas Raoux     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
7875b1b7108SThomas Raoux     Location loc = tranposeOp.getLoc();
7885b1b7108SThomas Raoux     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
7895b1b7108SThomas Raoux     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
7905b1b7108SThomas Raoux     int64_t sliceCount = computeMaxLinearIndex(ratio);
7915b1b7108SThomas Raoux     // Prepare the result vector;
7925b1b7108SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
7935b1b7108SThomas Raoux         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
7945b1b7108SThomas Raoux     SmallVector<int64_t> permutation;
7955b1b7108SThomas Raoux     tranposeOp.getTransp(permutation);
7965b1b7108SThomas Raoux     for (int64_t i = 0; i < sliceCount; i++) {
7975b1b7108SThomas Raoux       SmallVector<int64_t, 4> elementOffsets =
7985b1b7108SThomas Raoux           getVectorOffset(originalSize, *targetShape, i);
7995b1b7108SThomas Raoux       SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
8005b1b7108SThomas Raoux       SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
8015b1b7108SThomas Raoux       // Compute the source offsets and shape.
8025b1b7108SThomas Raoux       for (auto &indices : llvm::enumerate(permutation)) {
8035b1b7108SThomas Raoux         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
8045b1b7108SThomas Raoux         permutedShape[indices.value()] = (*targetShape)[indices.index()];
8055b1b7108SThomas Raoux       }
8065b1b7108SThomas Raoux       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
8075b1b7108SThomas Raoux           loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
8085b1b7108SThomas Raoux       Value tranposedSlice =
8095b1b7108SThomas Raoux           rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
8105b1b7108SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
8115b1b7108SThomas Raoux           loc, tranposedSlice, result, elementOffsets, strides);
8125b1b7108SThomas Raoux     }
8135b1b7108SThomas Raoux     rewriter.replaceOp(tranposeOp, result);
8145b1b7108SThomas Raoux     return success();
8155b1b7108SThomas Raoux   }
8165b1b7108SThomas Raoux 
8175b1b7108SThomas Raoux private:
8185b1b7108SThomas Raoux   vector::UnrollVectorOptions options;
8195b1b7108SThomas Raoux };
8205b1b7108SThomas Raoux 
82199ef9eebSMatthias Springer } // namespace
82299ef9eebSMatthias Springer 
populateVectorUnrollPatterns(RewritePatternSet & patterns,const UnrollVectorOptions & options)82399ef9eebSMatthias Springer void mlir::vector::populateVectorUnrollPatterns(
82499ef9eebSMatthias Springer     RewritePatternSet &patterns, const UnrollVectorOptions &options) {
82599ef9eebSMatthias Springer   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
826f69175b1SThomas Raoux                UnrollContractionPattern, UnrollElementwisePattern,
8275b1b7108SThomas Raoux                UnrollReductionPattern, UnrollMultiReductionPattern,
8285b1b7108SThomas Raoux                UnrollTranposePattern>(patterns.getContext(), options);
82999ef9eebSMatthias Springer }
83099ef9eebSMatthias Springer 
populatePropagateVectorDistributionPatterns(RewritePatternSet & patterns)83199ef9eebSMatthias Springer void mlir::vector::populatePropagateVectorDistributionPatterns(
83299ef9eebSMatthias Springer     RewritePatternSet &patterns) {
83399ef9eebSMatthias Springer   patterns.add<PointwiseExtractPattern, ContractExtractPattern,
83499ef9eebSMatthias Springer                TransferReadExtractPattern, TransferWriteInsertPattern>(
83599ef9eebSMatthias Springer       patterns.getContext());
83699ef9eebSMatthias Springer }
837