//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op>, reducing the
// amount of computation performed by the linalg op.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... arg0, arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This reduces the amount of computation performed by the linalg op.
///
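/// The pattern only applies when the linalg op has a single use and a single
/// tensor output, the extract_slice has unit strides and does not reduce
/// rank, and the output indexing map is a projected permutation; these
/// preconditions are checked in matchAndRewrite below.
///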
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }

    // TODO: we might relax this if we want heuristics to detect that all uses
    // only access a small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }

    if (linalgOp.getNumOutputs() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    if (!linalgOp.hasTensorSemantics()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected tensor of linalg op");
    }

    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }

    OpOperand *outOperand = linalgOp.getOutputOperand(0);
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }

    auto linalgLoc = linalgOp.getLoc();
    auto allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }
    auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
                                       shapeSizesToLoopsMap, allShapeSizes);

    auto sliceLoc = sliceOp.getLoc();
    auto offsetVals = getValueOrCreateConstantIndexOp(
        rewriter, sliceLoc, sliceOp.getMixedOffsets());
    auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
                                                    sliceOp.getMixedSizes());

    // The offsets and sizes from the slice operation only describe the tile
    // of the output. Use them to compute the tile offsets and sizes of the
    // loops. For loops that are not used to access the output, set the tile
    // size to the loop bound and the offset to 0.
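    // For example, for a matmul whose output indexing map is
    // (d0, d1, d2) -> (d0, d1), the slice offsets/sizes land on the parallel
    // loops d0 and d1, while the reduction loop d2 keeps offset 0 and its
    // full loop bound as the tile size.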
    Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
    SmallVector<Value, 4> tileOffsets(sizeBounds.size(), zero);
    SmallVector<Value, 4> tileSizes = sizeBounds;
    for (auto const &result : enumerate(indexingMap.getResults())) {
      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
      tileOffsets[position] = offsetVals[result.index()];
      tileSizes[position] = sizeVals[result.index()];
    }

    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();

    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
        sizeBounds, /*omitPartialTileCheck=*/true);

    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    Operation *newOp =
        linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace

void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<BubbleUpExtractSliceOpPattern>(context);
}
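
// A minimal usage sketch (illustrative, not part of this file; `op` is
// assumed to be the operation whose body should be rewritten): clients can
// apply these patterns with the greedy rewrite driver included above, e.g.
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateBubbleUpExtractSliceOpPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));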