199ef9eebSMatthias Springer //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements patterns to do vector unrolling and vector distribution.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer
1399ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1499ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1699ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
1799ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
189f122152SChristopher Bate #include "mlir/Support/MathExtras.h"
1999ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h"
209f122152SChristopher Bate #include "llvm/ADT/STLExtras.h"
219f122152SChristopher Bate #include <numeric>
2299ef9eebSMatthias Springer
2399ef9eebSMatthias Springer #define DEBUG_TYPE "vector-unrolling"
2499ef9eebSMatthias Springer
2599ef9eebSMatthias Springer using namespace mlir;
2699ef9eebSMatthias Springer using namespace mlir::vector;
2799ef9eebSMatthias Springer
2899ef9eebSMatthias Springer /// During unrolling from `originalShape` to `targetShape` return the offset for
2999ef9eebSMatthias Springer /// the slice `index`.
getVectorOffset(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,int64_t index)3099ef9eebSMatthias Springer static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
3199ef9eebSMatthias Springer ArrayRef<int64_t> targetShape,
3299ef9eebSMatthias Springer int64_t index) {
3399ef9eebSMatthias Springer SmallVector<int64_t, 4> dstSliceStrides =
3499ef9eebSMatthias Springer computeStrides(originalShape, targetShape);
3599ef9eebSMatthias Springer SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
3699ef9eebSMatthias Springer SmallVector<int64_t, 4> elementOffsets =
3799ef9eebSMatthias Springer computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
3899ef9eebSMatthias Springer return elementOffsets;
3999ef9eebSMatthias Springer }
4099ef9eebSMatthias Springer
419f122152SChristopher Bate /// A functor that accomplishes the same thing as `getVectorOffset` but allows
429f122152SChristopher Bate /// for reordering the traversal of the dimensions. The order of traversal is
439f122152SChristopher Bate /// given in "for loop order" (outer to inner).
449f122152SChristopher Bate namespace {
459f122152SChristopher Bate class DecomposeShapeIterator {
469f122152SChristopher Bate private:
479f122152SChristopher Bate SmallVector<int64_t, 4> vectorShape;
489f122152SChristopher Bate SmallVector<int64_t> loopOrder;
499f122152SChristopher Bate SmallVector<int64_t> sliceStrides;
509f122152SChristopher Bate int64_t maxIndexVal{1};
519f122152SChristopher Bate
529f122152SChristopher Bate public:
DecomposeShapeIterator(ArrayRef<int64_t> originalShape,ArrayRef<int64_t> targetShape,ArrayRef<int64_t> loopOrder)539f122152SChristopher Bate DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
549f122152SChristopher Bate ArrayRef<int64_t> targetShape,
559f122152SChristopher Bate ArrayRef<int64_t> loopOrder)
569f122152SChristopher Bate : vectorShape(targetShape.begin(), targetShape.end()),
579f122152SChristopher Bate loopOrder(loopOrder.begin(), loopOrder.end()),
589f122152SChristopher Bate sliceStrides(originalShape.size()) {
599f122152SChristopher Bate assert(originalShape.size() == targetShape.size());
609f122152SChristopher Bate assert(loopOrder.size() == targetShape.size());
619f122152SChristopher Bate
629f122152SChristopher Bate // Compute the count for each dimension.
639f122152SChristopher Bate SmallVector<int64_t> sliceDimCounts(originalShape.size());
649f122152SChristopher Bate for (unsigned r = 0; r < originalShape.size(); ++r) {
659f122152SChristopher Bate sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
669f122152SChristopher Bate maxIndexVal *= sliceDimCounts[r];
679f122152SChristopher Bate }
689f122152SChristopher Bate
699f122152SChristopher Bate // Reversing "loop order" gives dimensions from fastest varying to slowest
709f122152SChristopher Bate // varying (smallest stride to largest stride).
719f122152SChristopher Bate int64_t accum = 1;
729f122152SChristopher Bate for (auto idx : llvm::reverse(loopOrder)) {
739f122152SChristopher Bate sliceStrides[idx] = accum;
749f122152SChristopher Bate accum *= sliceDimCounts[idx];
759f122152SChristopher Bate }
769f122152SChristopher Bate }
779f122152SChristopher Bate
789f122152SChristopher Bate // Turn the linear index into a d-tuple based on units of vectors of size
799f122152SChristopher Bate // `vectorShape`. The linear index is assumed to represent traversal of the
809f122152SChristopher Bate // dimensions based on `order`.
delinearize(int64_t index) const819f122152SChristopher Bate SmallVector<int64_t> delinearize(int64_t index) const {
829f122152SChristopher Bate // Traverse in for loop order (largest stride to smallest stride).
839f122152SChristopher Bate SmallVector<int64_t> vectorOffsets(sliceStrides.size());
849f122152SChristopher Bate for (auto idx : loopOrder) {
859f122152SChristopher Bate vectorOffsets[idx] = index / sliceStrides[idx];
869f122152SChristopher Bate index %= sliceStrides[idx];
879f122152SChristopher Bate }
889f122152SChristopher Bate return vectorOffsets;
899f122152SChristopher Bate }
909f122152SChristopher Bate
maxIndex() const919f122152SChristopher Bate int64_t maxIndex() const { return maxIndexVal; }
929f122152SChristopher Bate
939f122152SChristopher Bate /// Return the offset within d-tuple based on the ordering given by
949f122152SChristopher Bate /// `loopOrder`.
getVectorOffset(int64_t index) const959f122152SChristopher Bate SmallVector<int64_t> getVectorOffset(int64_t index) const {
969f122152SChristopher Bate SmallVector<int64_t> vectorOffsets = delinearize(index);
979f122152SChristopher Bate SmallVector<int64_t> elementOffsets =
989f122152SChristopher Bate computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
999f122152SChristopher Bate return elementOffsets;
1009f122152SChristopher Bate }
1019f122152SChristopher Bate };
1029f122152SChristopher Bate } // namespace
1039f122152SChristopher Bate
10499ef9eebSMatthias Springer /// Compute the indices of the slice `index` for a tranfer op.
sliceTransferIndices(ArrayRef<int64_t> elementOffsets,ArrayRef<Value> indices,AffineMap permutationMap,Location loc,OpBuilder & builder)1059f122152SChristopher Bate static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
1069f122152SChristopher Bate ArrayRef<Value> indices,
1079f122152SChristopher Bate AffineMap permutationMap,
1089f122152SChristopher Bate Location loc,
10999ef9eebSMatthias Springer OpBuilder &builder) {
11099ef9eebSMatthias Springer MLIRContext *ctx = builder.getContext();
11199ef9eebSMatthias Springer auto isBroadcast = [](AffineExpr expr) {
11299ef9eebSMatthias Springer if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
11399ef9eebSMatthias Springer return constExpr.getValue() == 0;
11499ef9eebSMatthias Springer return false;
11599ef9eebSMatthias Springer };
11699ef9eebSMatthias Springer // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
11799ef9eebSMatthias Springer SmallVector<Value> slicedIndices(indices.begin(), indices.end());
11899ef9eebSMatthias Springer for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
11999ef9eebSMatthias Springer if (isBroadcast(dim.value()))
12099ef9eebSMatthias Springer continue;
12199ef9eebSMatthias Springer unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
12299ef9eebSMatthias Springer auto expr = getAffineDimExpr(0, builder.getContext()) +
12399ef9eebSMatthias Springer getAffineConstantExpr(elementOffsets[dim.index()], ctx);
12499ef9eebSMatthias Springer auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
12599ef9eebSMatthias Springer slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
12699ef9eebSMatthias Springer }
12799ef9eebSMatthias Springer return slicedIndices;
12899ef9eebSMatthias Springer }
12999ef9eebSMatthias Springer
13099ef9eebSMatthias Springer // Clones `op` into a new operations that takes `operands` and returns
13199ef9eebSMatthias Springer // `resultTypes`.
cloneOpWithOperandsAndTypes(OpBuilder & builder,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)13299ef9eebSMatthias Springer static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
13399ef9eebSMatthias Springer Operation *op,
13499ef9eebSMatthias Springer ArrayRef<Value> operands,
13599ef9eebSMatthias Springer ArrayRef<Type> resultTypes) {
13614ecafd0SChia-hung Duan return builder.create(loc, op->getName().getIdentifier(), operands,
13714ecafd0SChia-hung Duan resultTypes, op->getAttrs());
13899ef9eebSMatthias Springer }
13999ef9eebSMatthias Springer
14099ef9eebSMatthias Springer /// Return the target shape for unrolling for the given `op`. Return llvm::None
14199ef9eebSMatthias Springer /// if the op shouldn't be or cannot be unrolled.
14299ef9eebSMatthias Springer static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions & options,Operation * op)14399ef9eebSMatthias Springer getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
14499ef9eebSMatthias Springer if (options.filterConstraint && failed(options.filterConstraint(op)))
14599ef9eebSMatthias Springer return llvm::None;
14699ef9eebSMatthias Springer assert(options.nativeShape &&
14799ef9eebSMatthias Springer "vector unrolling expects the native shape or native"
14899ef9eebSMatthias Springer "shape call back function to be set");
14999ef9eebSMatthias Springer auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
15099ef9eebSMatthias Springer if (!unrollableVectorOp)
15199ef9eebSMatthias Springer return llvm::None;
15299ef9eebSMatthias Springer auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
15399ef9eebSMatthias Springer if (!maybeUnrollShape)
15499ef9eebSMatthias Springer return llvm::None;
15599ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
15699ef9eebSMatthias Springer if (!targetShape)
15799ef9eebSMatthias Springer return llvm::None;
15899ef9eebSMatthias Springer auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
15999ef9eebSMatthias Springer if (!maybeShapeRatio ||
16099ef9eebSMatthias Springer llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
16199ef9eebSMatthias Springer return llvm::None;
16299ef9eebSMatthias Springer return targetShape;
16399ef9eebSMatthias Springer }
16499ef9eebSMatthias Springer
1659f122152SChristopher Bate static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops,Operation * op,const vector::UnrollVectorOptions & options)1669f122152SChristopher Bate getUnrollOrder(unsigned numLoops, Operation *op,
1679f122152SChristopher Bate const vector::UnrollVectorOptions &options) {
1689f122152SChristopher Bate SmallVector<int64_t> loopOrder =
1699f122152SChristopher Bate llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
1709f122152SChristopher Bate if (options.traversalOrderCallback != nullptr) {
1719f122152SChristopher Bate Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
172037f0995SKazu Hirata if (order) {
1739f122152SChristopher Bate loopOrder = std::move(*order);
1749f122152SChristopher Bate }
1759f122152SChristopher Bate }
1769f122152SChristopher Bate return loopOrder;
1779f122152SChristopher Bate }
1789f122152SChristopher Bate
17999ef9eebSMatthias Springer namespace {
18099ef9eebSMatthias Springer
18199ef9eebSMatthias Springer struct UnrollTransferReadPattern
18299ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> {
UnrollTransferReadPattern__anon77a0e1ec0411::UnrollTransferReadPattern18399ef9eebSMatthias Springer UnrollTransferReadPattern(MLIRContext *context,
18499ef9eebSMatthias Springer const vector::UnrollVectorOptions &options)
18599ef9eebSMatthias Springer : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
18699ef9eebSMatthias Springer options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferReadPattern18799ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
18899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
18999ef9eebSMatthias Springer // TODO: support 0-d corner case.
19099ef9eebSMatthias Springer if (readOp.getTransferRank() == 0)
19199ef9eebSMatthias Springer return failure();
1927c38fd60SJacques Pienaar if (readOp.getMask())
19399ef9eebSMatthias Springer return failure();
19499ef9eebSMatthias Springer auto targetShape = getTargetShape(options, readOp);
19599ef9eebSMatthias Springer if (!targetShape)
19699ef9eebSMatthias Springer return failure();
19799ef9eebSMatthias Springer auto sourceVectorType = readOp.getVectorType();
19899ef9eebSMatthias Springer SmallVector<int64_t, 4> strides(targetShape->size(), 1);
19999ef9eebSMatthias Springer Location loc = readOp.getLoc();
20099ef9eebSMatthias Springer ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
2019f122152SChristopher Bate
20299ef9eebSMatthias Springer // Prepare the result vector;
20399ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
20499ef9eebSMatthias Springer loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
20599ef9eebSMatthias Springer auto targetType =
20699ef9eebSMatthias Springer VectorType::get(*targetShape, sourceVectorType.getElementType());
2077c38fd60SJacques Pienaar SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
2087c38fd60SJacques Pienaar readOp.getIndices().end());
2099f122152SChristopher Bate
2109f122152SChristopher Bate SmallVector<int64_t> loopOrder =
2119f122152SChristopher Bate getUnrollOrder(originalSize.size(), readOp, options);
2129f122152SChristopher Bate DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
2139f122152SChristopher Bate loopOrder);
2149f122152SChristopher Bate for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
2159f122152SChristopher Bate SmallVector<int64_t, 4> elementOffsets =
2169f122152SChristopher Bate indexToOffsets.getVectorOffset(i);
21799ef9eebSMatthias Springer SmallVector<Value, 4> indices =
2189f122152SChristopher Bate sliceTransferIndices(elementOffsets, originalIndices,
2197c38fd60SJacques Pienaar readOp.getPermutationMap(), loc, rewriter);
22099ef9eebSMatthias Springer auto slicedRead = rewriter.create<vector::TransferReadOp>(
2217c38fd60SJacques Pienaar loc, targetType, readOp.getSource(), indices,
2227c38fd60SJacques Pienaar readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
2237c38fd60SJacques Pienaar readOp.getInBoundsAttr());
22499ef9eebSMatthias Springer
22599ef9eebSMatthias Springer result = rewriter.create<vector::InsertStridedSliceOp>(
22699ef9eebSMatthias Springer loc, slicedRead, result, elementOffsets, strides);
22799ef9eebSMatthias Springer }
22899ef9eebSMatthias Springer rewriter.replaceOp(readOp, result);
22999ef9eebSMatthias Springer return success();
23099ef9eebSMatthias Springer }
23199ef9eebSMatthias Springer
23299ef9eebSMatthias Springer private:
23399ef9eebSMatthias Springer vector::UnrollVectorOptions options;
23499ef9eebSMatthias Springer };
23599ef9eebSMatthias Springer
23699ef9eebSMatthias Springer struct UnrollTransferWritePattern
23799ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
UnrollTransferWritePattern__anon77a0e1ec0411::UnrollTransferWritePattern23899ef9eebSMatthias Springer UnrollTransferWritePattern(MLIRContext *context,
23999ef9eebSMatthias Springer const vector::UnrollVectorOptions &options)
24099ef9eebSMatthias Springer : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
24199ef9eebSMatthias Springer options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTransferWritePattern24299ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
24399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
24499ef9eebSMatthias Springer // TODO: support 0-d corner case.
24599ef9eebSMatthias Springer if (writeOp.getTransferRank() == 0)
24699ef9eebSMatthias Springer return failure();
24799ef9eebSMatthias Springer
2487c38fd60SJacques Pienaar if (writeOp.getMask())
24999ef9eebSMatthias Springer return failure();
25099ef9eebSMatthias Springer auto targetShape = getTargetShape(options, writeOp);
25199ef9eebSMatthias Springer if (!targetShape)
25299ef9eebSMatthias Springer return failure();
25399ef9eebSMatthias Springer auto sourceVectorType = writeOp.getVectorType();
25499ef9eebSMatthias Springer SmallVector<int64_t, 4> strides(targetShape->size(), 1);
25599ef9eebSMatthias Springer Location loc = writeOp.getLoc();
25699ef9eebSMatthias Springer ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
2577c38fd60SJacques Pienaar SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
2587c38fd60SJacques Pienaar writeOp.getIndices().end());
2599f122152SChristopher Bate
2609f122152SChristopher Bate SmallVector<int64_t> loopOrder =
2619f122152SChristopher Bate getUnrollOrder(originalSize.size(), writeOp, options);
2629f122152SChristopher Bate DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
2639f122152SChristopher Bate loopOrder);
26499ef9eebSMatthias Springer Value resultTensor;
2659f122152SChristopher Bate for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
26699ef9eebSMatthias Springer SmallVector<int64_t, 4> elementOffsets =
2679f122152SChristopher Bate indexToOffsets.getVectorOffset(i);
26899ef9eebSMatthias Springer Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
2697c38fd60SJacques Pienaar loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
27099ef9eebSMatthias Springer SmallVector<Value, 4> indices =
2719f122152SChristopher Bate sliceTransferIndices(elementOffsets, originalIndices,
2727c38fd60SJacques Pienaar writeOp.getPermutationMap(), loc, rewriter);
27399ef9eebSMatthias Springer Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
2747c38fd60SJacques Pienaar loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
2757c38fd60SJacques Pienaar indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
27699ef9eebSMatthias Springer // For the tensor case update the destination for the next transfer write.
27799ef9eebSMatthias Springer if (!slicedWrite->getResults().empty())
27899ef9eebSMatthias Springer resultTensor = slicedWrite->getResult(0);
27999ef9eebSMatthias Springer }
28099ef9eebSMatthias Springer if (resultTensor)
28199ef9eebSMatthias Springer rewriter.replaceOp(writeOp, resultTensor);
28299ef9eebSMatthias Springer else
28399ef9eebSMatthias Springer rewriter.eraseOp(writeOp);
28499ef9eebSMatthias Springer return success();
28599ef9eebSMatthias Springer }
28699ef9eebSMatthias Springer
28799ef9eebSMatthias Springer private:
28899ef9eebSMatthias Springer vector::UnrollVectorOptions options;
28999ef9eebSMatthias Springer };
29099ef9eebSMatthias Springer
29199ef9eebSMatthias Springer struct OffsetMapInfo {
getEmptyKey__anon77a0e1ec0411::OffsetMapInfo29299ef9eebSMatthias Springer static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
29399ef9eebSMatthias Springer
getTombstoneKey__anon77a0e1ec0411::OffsetMapInfo29499ef9eebSMatthias Springer static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
29599ef9eebSMatthias Springer
getHashValue__anon77a0e1ec0411::OffsetMapInfo29699ef9eebSMatthias Springer static unsigned getHashValue(const SmallVector<int64_t> &v) {
297f69175b1SThomas Raoux return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
29899ef9eebSMatthias Springer }
29999ef9eebSMatthias Springer
isEqual__anon77a0e1ec0411::OffsetMapInfo30099ef9eebSMatthias Springer static bool isEqual(const SmallVector<int64_t> &lhs,
30199ef9eebSMatthias Springer const SmallVector<int64_t> &rhs) {
30299ef9eebSMatthias Springer return lhs == rhs;
30399ef9eebSMatthias Springer }
30499ef9eebSMatthias Springer };
305f69175b1SThomas Raoux
306f69175b1SThomas Raoux struct UnrollContractionPattern
307f69175b1SThomas Raoux : public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern__anon77a0e1ec0411::UnrollContractionPattern30899ef9eebSMatthias Springer UnrollContractionPattern(MLIRContext *context,
30999ef9eebSMatthias Springer const vector::UnrollVectorOptions &options)
31099ef9eebSMatthias Springer : OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
31199ef9eebSMatthias Springer options(options) {}
31299ef9eebSMatthias Springer
matchAndRewrite__anon77a0e1ec0411::UnrollContractionPattern31399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
31499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
31599ef9eebSMatthias Springer auto targetShape = getTargetShape(options, contractOp);
31699ef9eebSMatthias Springer if (!targetShape)
31799ef9eebSMatthias Springer return failure();
31899ef9eebSMatthias Springer auto dstVecType = contractOp.getResultType().cast<VectorType>();
31999ef9eebSMatthias Springer SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
32099ef9eebSMatthias Springer
32199ef9eebSMatthias Springer Location loc = contractOp.getLoc();
32299ef9eebSMatthias Springer unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323*d2c0572bSJacques Pienaar AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
32499ef9eebSMatthias Springer llvm::MapVector<
32599ef9eebSMatthias Springer SmallVector<int64_t>, Value,
32699ef9eebSMatthias Springer llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
32799ef9eebSMatthias Springer accCache;
3289f122152SChristopher Bate
3299f122152SChristopher Bate SmallVector<int64_t> loopOrder = getUnrollOrder(
3309f122152SChristopher Bate contractOp.getIteratorTypes().size(), contractOp, options);
3319f122152SChristopher Bate DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
3329f122152SChristopher Bate loopOrder);
3339f122152SChristopher Bate const int64_t sliceCount = indexToOffsets.maxIndex();
33499ef9eebSMatthias Springer for (int64_t i = 0; i < sliceCount; i++) {
3359f122152SChristopher Bate SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
33699ef9eebSMatthias Springer SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
33799ef9eebSMatthias Springer
33899ef9eebSMatthias Springer // Helper to coompute the new shape of each operand and extract the slice.
33999ef9eebSMatthias Springer auto extractOperand = [&](unsigned index, Value operand,
34099ef9eebSMatthias Springer AffineMap permutationMap,
34199ef9eebSMatthias Springer ArrayRef<int64_t> operandOffets) {
34299ef9eebSMatthias Springer SmallVector<int64_t> operandShape = applyPermutationMap(
34399ef9eebSMatthias Springer permutationMap, ArrayRef<int64_t>(*targetShape));
34499ef9eebSMatthias Springer SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
34599ef9eebSMatthias Springer slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
34699ef9eebSMatthias Springer loc, operand, operandOffets, operandShape, operandStrides);
34799ef9eebSMatthias Springer };
34899ef9eebSMatthias Springer
34999ef9eebSMatthias Springer // Extract the new lhs operand.
350*d2c0572bSJacques Pienaar AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
35199ef9eebSMatthias Springer SmallVector<int64_t> lhsOffets =
35299ef9eebSMatthias Springer applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
3537c38fd60SJacques Pienaar extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
35499ef9eebSMatthias Springer // If there is a mask associated to lhs, extract it as well.
35599ef9eebSMatthias Springer if (slicesOperands.size() > 3)
3567c38fd60SJacques Pienaar extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
3577c38fd60SJacques Pienaar lhsOffets);
35899ef9eebSMatthias Springer
35999ef9eebSMatthias Springer // Extract the new rhs operand.
360*d2c0572bSJacques Pienaar AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
36199ef9eebSMatthias Springer SmallVector<int64_t> rhsOffets =
36299ef9eebSMatthias Springer applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
3637c38fd60SJacques Pienaar extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
36499ef9eebSMatthias Springer // If there is a mask associated to rhs, extract it as well.
36599ef9eebSMatthias Springer if (slicesOperands.size() > 4)
3667c38fd60SJacques Pienaar extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
3677c38fd60SJacques Pienaar rhsOffets);
36899ef9eebSMatthias Springer
369*d2c0572bSJacques Pienaar AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
37099ef9eebSMatthias Springer SmallVector<int64_t> accOffets =
37199ef9eebSMatthias Springer applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
37299ef9eebSMatthias Springer // If a version of the accumulator has already been computed, use it
37399ef9eebSMatthias Springer // otherwise extract the first version from the original operand.
37499ef9eebSMatthias Springer auto accIt = accCache.find(accOffets);
37599ef9eebSMatthias Springer if (accIt != accCache.end())
37699ef9eebSMatthias Springer slicesOperands[2] = accIt->second;
37799ef9eebSMatthias Springer else
3787c38fd60SJacques Pienaar extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
37999ef9eebSMatthias Springer
38099ef9eebSMatthias Springer SmallVector<int64_t> dstShape =
38199ef9eebSMatthias Springer applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
38299ef9eebSMatthias Springer auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
38399ef9eebSMatthias Springer Operation *newOp = cloneOpWithOperandsAndTypes(
38499ef9eebSMatthias Springer rewriter, loc, contractOp, slicesOperands, targetType);
38599ef9eebSMatthias Springer
38699ef9eebSMatthias Springer SmallVector<int64_t> dstOffets =
38799ef9eebSMatthias Springer applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
38899ef9eebSMatthias Springer // Save the accumulated value untill all the loops are unrolled since
38999ef9eebSMatthias Springer // reduction loop keep updating the accumulator.
39099ef9eebSMatthias Springer accCache[dstOffets] = newOp->getResult(0);
39199ef9eebSMatthias Springer }
39299ef9eebSMatthias Springer // Assemble back the accumulator into a single vector.
39399ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
39499ef9eebSMatthias Springer loc, dstVecType, rewriter.getZeroAttr(dstVecType));
39599ef9eebSMatthias Springer for (const auto &it : accCache) {
39699ef9eebSMatthias Springer SmallVector<int64_t> dstStrides(it.first.size(), 1);
39799ef9eebSMatthias Springer result = rewriter.create<vector::InsertStridedSliceOp>(
39899ef9eebSMatthias Springer loc, it.second, result, it.first, dstStrides);
39999ef9eebSMatthias Springer }
40099ef9eebSMatthias Springer rewriter.replaceOp(contractOp, result);
40199ef9eebSMatthias Springer return success();
40299ef9eebSMatthias Springer }
40399ef9eebSMatthias Springer
40499ef9eebSMatthias Springer private:
40599ef9eebSMatthias Springer vector::UnrollVectorOptions options;
40699ef9eebSMatthias Springer };
40799ef9eebSMatthias Springer
408f69175b1SThomas Raoux struct UnrollMultiReductionPattern
409f69175b1SThomas Raoux : public OpRewritePattern<vector::MultiDimReductionOp> {
UnrollMultiReductionPattern__anon77a0e1ec0411::UnrollMultiReductionPattern410f69175b1SThomas Raoux UnrollMultiReductionPattern(MLIRContext *context,
411f69175b1SThomas Raoux const vector::UnrollVectorOptions &options)
412f69175b1SThomas Raoux : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
413f69175b1SThomas Raoux options(options) {}
414f69175b1SThomas Raoux
matchAndRewrite__anon77a0e1ec0411::UnrollMultiReductionPattern415f69175b1SThomas Raoux LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
416f69175b1SThomas Raoux PatternRewriter &rewriter) const override {
417f69175b1SThomas Raoux Optional<SmallVector<int64_t, 4>> targetShape =
418f69175b1SThomas Raoux getTargetShape(options, reductionOp);
419f69175b1SThomas Raoux if (!targetShape)
420f69175b1SThomas Raoux return failure();
421f69175b1SThomas Raoux SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
422f69175b1SThomas Raoux SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
423f69175b1SThomas Raoux llvm::MapVector<
424f69175b1SThomas Raoux SmallVector<int64_t>, Value,
425f69175b1SThomas Raoux llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
426f69175b1SThomas Raoux accCache;
427f69175b1SThomas Raoux // Compute shape ratio of 'shape' and 'sizes'.
428f69175b1SThomas Raoux int64_t sliceCount = computeMaxLinearIndex(ratio);
429f69175b1SThomas Raoux Location loc = reductionOp.getLoc();
430f69175b1SThomas Raoux for (int64_t i = 0; i < sliceCount; i++) {
431f69175b1SThomas Raoux SmallVector<int64_t, 4> offsets =
432f69175b1SThomas Raoux getVectorOffset(originalSize, *targetShape, i);
433f69175b1SThomas Raoux
434051b36baSThomas Raoux SmallVector<Value> operands;
435f69175b1SThomas Raoux SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
436f69175b1SThomas Raoux Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
437051b36baSThomas Raoux loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
438051b36baSThomas Raoux operands.push_back(slicedOperand);
439f69175b1SThomas Raoux SmallVector<int64_t> dstShape;
440f69175b1SThomas Raoux SmallVector<int64_t> destOffset;
441f69175b1SThomas Raoux for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
442f69175b1SThomas Raoux if (!reductionOp.isReducedDim(i)) {
443f69175b1SThomas Raoux destOffset.push_back(offsets[i]);
444f69175b1SThomas Raoux dstShape.push_back((*targetShape)[i]);
445f69175b1SThomas Raoux }
446f69175b1SThomas Raoux }
447051b36baSThomas Raoux Value acc;
448051b36baSThomas Raoux SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
449051b36baSThomas Raoux // If a version of the accumulator has already been computed, use it
450051b36baSThomas Raoux // otherwise extract the first version from the original operand.
451051b36baSThomas Raoux auto accIt = accCache.find(destOffset);
452051b36baSThomas Raoux if (accIt != accCache.end())
453051b36baSThomas Raoux acc = accIt->second;
454051b36baSThomas Raoux else
455051b36baSThomas Raoux acc = rewriter.create<vector::ExtractStridedSliceOp>(
456051b36baSThomas Raoux loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
457051b36baSThomas Raoux operands.push_back(acc);
458f69175b1SThomas Raoux auto targetType = VectorType::get(
459f69175b1SThomas Raoux dstShape, reductionOp.getSourceVectorType().getElementType());
460f69175b1SThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
461051b36baSThomas Raoux operands, targetType);
462f69175b1SThomas Raoux Value result = newOp->getResult(0);
463f69175b1SThomas Raoux accCache[destOffset] = result;
464f69175b1SThomas Raoux }
465f69175b1SThomas Raoux // Assemble back the accumulator into a single vector.
466f69175b1SThomas Raoux Value result = rewriter.create<arith::ConstantOp>(
467f69175b1SThomas Raoux loc, reductionOp.getDestType(),
468f69175b1SThomas Raoux rewriter.getZeroAttr(reductionOp.getDestType()));
469f69175b1SThomas Raoux for (const auto &it : accCache) {
470f69175b1SThomas Raoux SmallVector<int64_t> dstStrides(it.first.size(), 1);
471f69175b1SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>(
472f69175b1SThomas Raoux loc, it.second, result, it.first, dstStrides);
473f69175b1SThomas Raoux }
474f69175b1SThomas Raoux rewriter.replaceOp(reductionOp, result);
475f69175b1SThomas Raoux return success();
476f69175b1SThomas Raoux }
477f69175b1SThomas Raoux
478f69175b1SThomas Raoux private:
479f69175b1SThomas Raoux vector::UnrollVectorOptions options;
480f69175b1SThomas Raoux };
481f69175b1SThomas Raoux
48299ef9eebSMatthias Springer struct UnrollElementwisePattern : public RewritePattern {
UnrollElementwisePattern__anon77a0e1ec0411::UnrollElementwisePattern48399ef9eebSMatthias Springer UnrollElementwisePattern(MLIRContext *context,
48499ef9eebSMatthias Springer const vector::UnrollVectorOptions &options)
48599ef9eebSMatthias Springer : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
48699ef9eebSMatthias Springer options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollElementwisePattern48799ef9eebSMatthias Springer LogicalResult matchAndRewrite(Operation *op,
48899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
48999ef9eebSMatthias Springer if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
49099ef9eebSMatthias Springer return failure();
49199ef9eebSMatthias Springer auto targetShape = getTargetShape(options, op);
49299ef9eebSMatthias Springer if (!targetShape)
49399ef9eebSMatthias Springer return failure();
49499ef9eebSMatthias Springer auto dstVecType = op->getResult(0).getType().cast<VectorType>();
49599ef9eebSMatthias Springer SmallVector<int64_t, 4> originalSize =
49699ef9eebSMatthias Springer *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
49799ef9eebSMatthias Springer SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
49899ef9eebSMatthias Springer int64_t sliceCount = computeMaxLinearIndex(ratio);
49999ef9eebSMatthias Springer Location loc = op->getLoc();
50099ef9eebSMatthias Springer // Prepare the result vector.
50199ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
50299ef9eebSMatthias Springer loc, dstVecType, rewriter.getZeroAttr(dstVecType));
50399ef9eebSMatthias Springer SmallVector<int64_t, 4> strides(targetShape->size(), 1);
50499ef9eebSMatthias Springer VectorType newVecType =
50599ef9eebSMatthias Springer VectorType::get(*targetShape, dstVecType.getElementType());
50699ef9eebSMatthias Springer for (int64_t i = 0; i < sliceCount; i++) {
50799ef9eebSMatthias Springer SmallVector<int64_t, 4> offsets =
50899ef9eebSMatthias Springer getVectorOffset(originalSize, *targetShape, i);
50999ef9eebSMatthias Springer SmallVector<Value, 4> extractOperands;
51099ef9eebSMatthias Springer for (OpOperand &operand : op->getOpOperands()) {
51199ef9eebSMatthias Springer auto vecType = operand.get().getType().template dyn_cast<VectorType>();
51299ef9eebSMatthias Springer if (!vecType) {
51399ef9eebSMatthias Springer extractOperands.push_back(operand.get());
51499ef9eebSMatthias Springer continue;
51599ef9eebSMatthias Springer }
51699ef9eebSMatthias Springer extractOperands.push_back(
51799ef9eebSMatthias Springer rewriter.create<vector::ExtractStridedSliceOp>(
51899ef9eebSMatthias Springer loc, operand.get(), offsets, *targetShape, strides));
51999ef9eebSMatthias Springer }
52099ef9eebSMatthias Springer Operation *newOp = cloneOpWithOperandsAndTypes(
52199ef9eebSMatthias Springer rewriter, loc, op, extractOperands, newVecType);
52299ef9eebSMatthias Springer result = rewriter.create<vector::InsertStridedSliceOp>(
52399ef9eebSMatthias Springer loc, newOp->getResult(0), result, offsets, strides);
52499ef9eebSMatthias Springer }
52599ef9eebSMatthias Springer rewriter.replaceOp(op, result);
52699ef9eebSMatthias Springer return success();
52799ef9eebSMatthias Springer }
52899ef9eebSMatthias Springer
52999ef9eebSMatthias Springer private:
53099ef9eebSMatthias Springer vector::UnrollVectorOptions options;
53199ef9eebSMatthias Springer };
53299ef9eebSMatthias Springer
/// Canonicalize an extract_map using the result of a pointwise operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
54199ef9eebSMatthias Springer struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
54299ef9eebSMatthias Springer using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::PointwiseExtractPattern54399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
54499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
5457c38fd60SJacques Pienaar Operation *definedOp = extract.getVector().getDefiningOp();
54699ef9eebSMatthias Springer if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
54799ef9eebSMatthias Springer definedOp->getNumResults() != 1)
54899ef9eebSMatthias Springer return failure();
54999ef9eebSMatthias Springer Location loc = extract.getLoc();
55099ef9eebSMatthias Springer SmallVector<Value, 4> extractOperands;
55199ef9eebSMatthias Springer for (OpOperand &operand : definedOp->getOpOperands()) {
55299ef9eebSMatthias Springer auto vecType = operand.get().getType().template dyn_cast<VectorType>();
55399ef9eebSMatthias Springer if (!vecType) {
55499ef9eebSMatthias Springer extractOperands.push_back(operand.get());
55599ef9eebSMatthias Springer continue;
55699ef9eebSMatthias Springer }
55799ef9eebSMatthias Springer extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
55899ef9eebSMatthias Springer loc,
55999ef9eebSMatthias Springer VectorType::get(extract.getResultType().getShape(),
56099ef9eebSMatthias Springer vecType.getElementType()),
5617c38fd60SJacques Pienaar operand.get(), extract.getIds()));
56299ef9eebSMatthias Springer }
56399ef9eebSMatthias Springer Operation *newOp = cloneOpWithOperandsAndTypes(
56499ef9eebSMatthias Springer rewriter, loc, definedOp, extractOperands, extract.getResultType());
56599ef9eebSMatthias Springer rewriter.replaceOp(extract, newOp->getResult(0));
56699ef9eebSMatthias Springer return success();
56799ef9eebSMatthias Springer }
56899ef9eebSMatthias Springer };
56999ef9eebSMatthias Springer
57099ef9eebSMatthias Springer /// Canonicalize an extract_map using the result of a contract operation.
57199ef9eebSMatthias Springer /// This propagate the extract_map to operands.
57299ef9eebSMatthias Springer struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
57399ef9eebSMatthias Springer using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
matchAndRewrite__anon77a0e1ec0411::ContractExtractPattern57499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
57599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
5767c38fd60SJacques Pienaar Operation *definedOp = extract.getVector().getDefiningOp();
57799ef9eebSMatthias Springer auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
57899ef9eebSMatthias Springer if (!contract)
57999ef9eebSMatthias Springer return failure();
58099ef9eebSMatthias Springer Location loc = contract.getLoc();
58199ef9eebSMatthias Springer unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
582*d2c0572bSJacques Pienaar AffineMap affineMap = contract.getIndexingMapsArray()[accIndex];
58399ef9eebSMatthias Springer // Create a map of the dimensions distributed based on the acc affine map.
58499ef9eebSMatthias Springer // Only parallel dimensions are being distributed, reduction dimensions are
58599ef9eebSMatthias Springer // untouched.
58699ef9eebSMatthias Springer DenseMap<int64_t, int64_t> map;
58799ef9eebSMatthias Springer for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
58899ef9eebSMatthias Springer map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
58999ef9eebSMatthias Springer SmallVector<Value, 4> extractOperands;
590*d2c0572bSJacques Pienaar for (const auto &it : llvm::enumerate(contract.getIndexingMapsArray())) {
59199ef9eebSMatthias Springer // For each operands calculate the new vector type after distribution.
59299ef9eebSMatthias Springer Value operand = contract->getOperand(it.index());
59399ef9eebSMatthias Springer auto vecType = operand.getType().cast<VectorType>();
59499ef9eebSMatthias Springer SmallVector<int64_t> operandShape(vecType.getShape().begin(),
59599ef9eebSMatthias Springer vecType.getShape().end());
59699ef9eebSMatthias Springer for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
59799ef9eebSMatthias Springer unsigned dim = it.value().getDimPosition(i);
59899ef9eebSMatthias Springer auto distributedDim = map.find(dim);
59999ef9eebSMatthias Springer // If the dimension is not in the map it means it is a reduction and
60099ef9eebSMatthias Springer // doesn't get distributed.
60199ef9eebSMatthias Springer if (distributedDim == map.end())
60299ef9eebSMatthias Springer continue;
60399ef9eebSMatthias Springer operandShape[i] = distributedDim->second;
60499ef9eebSMatthias Springer }
60599ef9eebSMatthias Springer VectorType newVecType =
60699ef9eebSMatthias Springer VectorType::get(operandShape, vecType.getElementType());
60799ef9eebSMatthias Springer extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
6087c38fd60SJacques Pienaar loc, newVecType, operand, extract.getIds()));
60999ef9eebSMatthias Springer }
61099ef9eebSMatthias Springer Operation *newOp =
61199ef9eebSMatthias Springer cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
61299ef9eebSMatthias Springer extract.getResult().getType());
61399ef9eebSMatthias Springer rewriter.replaceOp(extract, newOp->getResult(0));
61499ef9eebSMatthias Springer return success();
61599ef9eebSMatthias Springer }
61699ef9eebSMatthias Springer };
61799ef9eebSMatthias Springer
61899ef9eebSMatthias Springer /// Converts TransferRead op used by ExtractMap op into a smaller dimension
61999ef9eebSMatthias Springer /// TransferRead.
62099ef9eebSMatthias Springer /// Example:
62199ef9eebSMatthias Springer /// ```
62299ef9eebSMatthias Springer /// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
62399ef9eebSMatthias Springer /// memref<64x64x64xf32>, vector<64x4x32xf32>
62499ef9eebSMatthias Springer /// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
62599ef9eebSMatthias Springer /// ```
62699ef9eebSMatthias Springer /// to:
62799ef9eebSMatthias Springer /// ```
62899ef9eebSMatthias Springer /// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
62999ef9eebSMatthias Springer /// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
63099ef9eebSMatthias Springer /// memref<64x64x64xf32>, vector<2x4x1xf32>
63199ef9eebSMatthias Springer /// ```
63299ef9eebSMatthias Springer struct TransferReadExtractPattern
63399ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPattern__anon77a0e1ec0411::TransferReadExtractPattern63499ef9eebSMatthias Springer TransferReadExtractPattern(MLIRContext *context)
63599ef9eebSMatthias Springer : OpRewritePattern<vector::TransferReadOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferReadExtractPattern63699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp read,
63799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
63899ef9eebSMatthias Springer // TODO: support 0-d corner case.
63999ef9eebSMatthias Springer if (read.getTransferRank() == 0)
64099ef9eebSMatthias Springer return failure();
64199ef9eebSMatthias Springer
64299ef9eebSMatthias Springer if (!read.getResult().hasOneUse())
64399ef9eebSMatthias Springer return failure();
64499ef9eebSMatthias Springer auto extract =
64599ef9eebSMatthias Springer dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
64699ef9eebSMatthias Springer if (!extract)
64799ef9eebSMatthias Springer return failure();
6487c38fd60SJacques Pienaar if (read.getMask())
64999ef9eebSMatthias Springer return failure();
65099ef9eebSMatthias Springer
6517c38fd60SJacques Pienaar SmallVector<Value, 4> indices(read.getIndices().begin(),
6527c38fd60SJacques Pienaar read.getIndices().end());
6537c38fd60SJacques Pienaar AffineMap indexMap = extract.map().compose(read.getPermutationMap());
65499ef9eebSMatthias Springer unsigned idCount = 0;
65599ef9eebSMatthias Springer ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
65699ef9eebSMatthias Springer for (auto it :
65799ef9eebSMatthias Springer llvm::zip(indexMap.getResults(), extract.map().getResults())) {
65899ef9eebSMatthias Springer AffineExpr d0, d1;
65999ef9eebSMatthias Springer bindDims(read.getContext(), d0, d1);
66099ef9eebSMatthias Springer auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
66199ef9eebSMatthias Springer if (!indexExpr)
66299ef9eebSMatthias Springer continue;
66399ef9eebSMatthias Springer unsigned indexPos = indexExpr.getPosition();
66499ef9eebSMatthias Springer unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
66599ef9eebSMatthias Springer auto scale = getAffineConstantExpr(
66699ef9eebSMatthias Springer extract.getResultType().getDimSize(vectorPos), read.getContext());
66799ef9eebSMatthias Springer indices[indexPos] = makeComposedAffineApply(
66899ef9eebSMatthias Springer rewriter, read.getLoc(), d0 + scale * d1,
6697c38fd60SJacques Pienaar {indices[indexPos], extract.getIds()[idCount++]});
67099ef9eebSMatthias Springer }
67199ef9eebSMatthias Springer Value newRead = lb.create<vector::TransferReadOp>(
6727c38fd60SJacques Pienaar extract.getType(), read.getSource(), indices,
6737c38fd60SJacques Pienaar read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
6747c38fd60SJacques Pienaar read.getInBoundsAttr());
67599ef9eebSMatthias Springer Value dest = lb.create<arith::ConstantOp>(
67699ef9eebSMatthias Springer read.getType(), rewriter.getZeroAttr(read.getType()));
6777c38fd60SJacques Pienaar newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
67899ef9eebSMatthias Springer rewriter.replaceOp(read, newRead);
67999ef9eebSMatthias Springer return success();
68099ef9eebSMatthias Springer }
68199ef9eebSMatthias Springer };
68299ef9eebSMatthias Springer
68399ef9eebSMatthias Springer struct TransferWriteInsertPattern
68499ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteInsertPattern__anon77a0e1ec0411::TransferWriteInsertPattern68599ef9eebSMatthias Springer TransferWriteInsertPattern(MLIRContext *context)
68699ef9eebSMatthias Springer : OpRewritePattern<vector::TransferWriteOp>(context) {}
matchAndRewrite__anon77a0e1ec0411::TransferWriteInsertPattern68799ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp write,
68899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
68999ef9eebSMatthias Springer // TODO: support 0-d corner case.
69099ef9eebSMatthias Springer if (write.getTransferRank() == 0)
69199ef9eebSMatthias Springer return failure();
69299ef9eebSMatthias Springer
6937c38fd60SJacques Pienaar auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
69499ef9eebSMatthias Springer if (!insert)
69599ef9eebSMatthias Springer return failure();
6967c38fd60SJacques Pienaar if (write.getMask())
69799ef9eebSMatthias Springer return failure();
6987c38fd60SJacques Pienaar SmallVector<Value, 4> indices(write.getIndices().begin(),
6997c38fd60SJacques Pienaar write.getIndices().end());
7007c38fd60SJacques Pienaar AffineMap indexMap = insert.map().compose(write.getPermutationMap());
70199ef9eebSMatthias Springer unsigned idCount = 0;
70299ef9eebSMatthias Springer Location loc = write.getLoc();
70399ef9eebSMatthias Springer for (auto it :
70499ef9eebSMatthias Springer llvm::zip(indexMap.getResults(), insert.map().getResults())) {
70599ef9eebSMatthias Springer AffineExpr d0, d1;
70699ef9eebSMatthias Springer bindDims(write.getContext(), d0, d1);
70799ef9eebSMatthias Springer auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
70899ef9eebSMatthias Springer if (!indexExpr)
70999ef9eebSMatthias Springer continue;
71099ef9eebSMatthias Springer unsigned indexPos = indexExpr.getPosition();
71199ef9eebSMatthias Springer unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
71299ef9eebSMatthias Springer auto scale = getAffineConstantExpr(
71399ef9eebSMatthias Springer insert.getSourceVectorType().getDimSize(vectorPos),
71499ef9eebSMatthias Springer write.getContext());
7157c38fd60SJacques Pienaar indices[indexPos] = makeComposedAffineApply(
7167c38fd60SJacques Pienaar rewriter, loc, d0 + scale * d1,
7177c38fd60SJacques Pienaar {indices[indexPos], insert.getIds()[idCount++]});
71899ef9eebSMatthias Springer }
71999ef9eebSMatthias Springer rewriter.create<vector::TransferWriteOp>(
7207c38fd60SJacques Pienaar loc, insert.getVector(), write.getSource(), indices,
7217c38fd60SJacques Pienaar write.getPermutationMapAttr(), write.getInBoundsAttr());
72299ef9eebSMatthias Springer rewriter.eraseOp(write);
72399ef9eebSMatthias Springer return success();
72499ef9eebSMatthias Springer }
72599ef9eebSMatthias Springer };
72699ef9eebSMatthias Springer
727de5022c7SMatthias Springer struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
UnrollReductionPattern__anon77a0e1ec0411::UnrollReductionPattern728de5022c7SMatthias Springer UnrollReductionPattern(MLIRContext *context,
729de5022c7SMatthias Springer const vector::UnrollVectorOptions &options)
730de5022c7SMatthias Springer : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
731de5022c7SMatthias Springer options(options) {}
732de5022c7SMatthias Springer
matchAndRewrite__anon77a0e1ec0411::UnrollReductionPattern733de5022c7SMatthias Springer LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
734de5022c7SMatthias Springer PatternRewriter &rewriter) const override {
735de5022c7SMatthias Springer Optional<SmallVector<int64_t, 4>> targetShape =
736de5022c7SMatthias Springer getTargetShape(options, reductionOp);
737de5022c7SMatthias Springer if (!targetShape)
738de5022c7SMatthias Springer return failure();
739de5022c7SMatthias Springer SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
740de5022c7SMatthias Springer int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
741de5022c7SMatthias Springer
742de5022c7SMatthias Springer // Create unrolled vector reduction.
743de5022c7SMatthias Springer Location loc = reductionOp.getLoc();
744de5022c7SMatthias Springer Value accumulator = nullptr;
745de5022c7SMatthias Springer for (int64_t i = 0; i < ratio; ++i) {
746de5022c7SMatthias Springer SmallVector<int64_t> offsets =
747de5022c7SMatthias Springer getVectorOffset(originalSize, *targetShape, i);
748de5022c7SMatthias Springer SmallVector<int64_t> strides(offsets.size(), 1);
749de5022c7SMatthias Springer Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
7507c38fd60SJacques Pienaar loc, reductionOp.getVector(), offsets, *targetShape, strides);
751de5022c7SMatthias Springer Operation *newOp = cloneOpWithOperandsAndTypes(
752de5022c7SMatthias Springer rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
753de5022c7SMatthias Springer Value result = newOp->getResult(0);
754de5022c7SMatthias Springer
755de5022c7SMatthias Springer if (!accumulator) {
756de5022c7SMatthias Springer // This is the first reduction.
757de5022c7SMatthias Springer accumulator = result;
758de5022c7SMatthias Springer } else {
759de5022c7SMatthias Springer // On subsequent reduction, combine with the accumulator.
7607c38fd60SJacques Pienaar accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
761de5022c7SMatthias Springer accumulator, result);
762de5022c7SMatthias Springer }
763de5022c7SMatthias Springer }
764de5022c7SMatthias Springer
765de5022c7SMatthias Springer rewriter.replaceOp(reductionOp, accumulator);
766de5022c7SMatthias Springer return success();
767de5022c7SMatthias Springer }
768de5022c7SMatthias Springer
769de5022c7SMatthias Springer private:
770de5022c7SMatthias Springer const vector::UnrollVectorOptions options;
771de5022c7SMatthias Springer };
772de5022c7SMatthias Springer
7735b1b7108SThomas Raoux struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern__anon77a0e1ec0411::UnrollTranposePattern7745b1b7108SThomas Raoux UnrollTranposePattern(MLIRContext *context,
7755b1b7108SThomas Raoux const vector::UnrollVectorOptions &options)
7765b1b7108SThomas Raoux : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
7775b1b7108SThomas Raoux options(options) {}
matchAndRewrite__anon77a0e1ec0411::UnrollTranposePattern7785b1b7108SThomas Raoux LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
7795b1b7108SThomas Raoux PatternRewriter &rewriter) const override {
7805b1b7108SThomas Raoux if (tranposeOp.getResultType().getRank() == 0)
7815b1b7108SThomas Raoux return failure();
7825b1b7108SThomas Raoux auto targetShape = getTargetShape(options, tranposeOp);
7835b1b7108SThomas Raoux if (!targetShape)
7845b1b7108SThomas Raoux return failure();
7855b1b7108SThomas Raoux auto originalVectorType = tranposeOp.getResultType();
7865b1b7108SThomas Raoux SmallVector<int64_t, 4> strides(targetShape->size(), 1);
7875b1b7108SThomas Raoux Location loc = tranposeOp.getLoc();
7885b1b7108SThomas Raoux ArrayRef<int64_t> originalSize = originalVectorType.getShape();
7895b1b7108SThomas Raoux SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
7905b1b7108SThomas Raoux int64_t sliceCount = computeMaxLinearIndex(ratio);
7915b1b7108SThomas Raoux // Prepare the result vector;
7925b1b7108SThomas Raoux Value result = rewriter.create<arith::ConstantOp>(
7935b1b7108SThomas Raoux loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
7945b1b7108SThomas Raoux SmallVector<int64_t> permutation;
7955b1b7108SThomas Raoux tranposeOp.getTransp(permutation);
7965b1b7108SThomas Raoux for (int64_t i = 0; i < sliceCount; i++) {
7975b1b7108SThomas Raoux SmallVector<int64_t, 4> elementOffsets =
7985b1b7108SThomas Raoux getVectorOffset(originalSize, *targetShape, i);
7995b1b7108SThomas Raoux SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
8005b1b7108SThomas Raoux SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
8015b1b7108SThomas Raoux // Compute the source offsets and shape.
8025b1b7108SThomas Raoux for (auto &indices : llvm::enumerate(permutation)) {
8035b1b7108SThomas Raoux permutedOffsets[indices.value()] = elementOffsets[indices.index()];
8045b1b7108SThomas Raoux permutedShape[indices.value()] = (*targetShape)[indices.index()];
8055b1b7108SThomas Raoux }
8065b1b7108SThomas Raoux Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
8075b1b7108SThomas Raoux loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
8085b1b7108SThomas Raoux Value tranposedSlice =
8095b1b7108SThomas Raoux rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
8105b1b7108SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>(
8115b1b7108SThomas Raoux loc, tranposedSlice, result, elementOffsets, strides);
8125b1b7108SThomas Raoux }
8135b1b7108SThomas Raoux rewriter.replaceOp(tranposeOp, result);
8145b1b7108SThomas Raoux return success();
8155b1b7108SThomas Raoux }
8165b1b7108SThomas Raoux
8175b1b7108SThomas Raoux private:
8185b1b7108SThomas Raoux vector::UnrollVectorOptions options;
8195b1b7108SThomas Raoux };
8205b1b7108SThomas Raoux
82199ef9eebSMatthias Springer } // namespace
82299ef9eebSMatthias Springer
populateVectorUnrollPatterns(RewritePatternSet & patterns,const UnrollVectorOptions & options)82399ef9eebSMatthias Springer void mlir::vector::populateVectorUnrollPatterns(
82499ef9eebSMatthias Springer RewritePatternSet &patterns, const UnrollVectorOptions &options) {
82599ef9eebSMatthias Springer patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
826f69175b1SThomas Raoux UnrollContractionPattern, UnrollElementwisePattern,
8275b1b7108SThomas Raoux UnrollReductionPattern, UnrollMultiReductionPattern,
8285b1b7108SThomas Raoux UnrollTranposePattern>(patterns.getContext(), options);
82999ef9eebSMatthias Springer }
83099ef9eebSMatthias Springer
populatePropagateVectorDistributionPatterns(RewritePatternSet & patterns)83199ef9eebSMatthias Springer void mlir::vector::populatePropagateVectorDistributionPatterns(
83299ef9eebSMatthias Springer RewritePatternSet &patterns) {
83399ef9eebSMatthias Springer patterns.add<PointwiseExtractPattern, ContractExtractPattern,
83499ef9eebSMatthias Springer TransferReadExtractPattern, TransferWriteInsertPattern>(
83599ef9eebSMatthias Springer patterns.getContext());
83699ef9eebSMatthias Springer }
837