//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation for the linalg op.
//
//===----------------------------------------------------------------------===//
14 
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;
25 
26 namespace {
27 /// Bubble up extract_slice above Linalg operation.
28 ///
29 /// A sequence of operations
30 ///
31 /// ```mlir
32 /// %0 = linalg.<op> ... arg0, arg1, ...
33 /// %1 = tensor.extract_slice %0 ...
34 /// ```
35 ///
36 /// can be replaced with
37 ///
38 /// ```mlir
39 /// %0 = tensor.extract_slice %arg0
40 /// %1 = tensor.extract_slice %arg1
41 /// %2 = linalg.<op> ... %0, %1, ...
42 /// ```
43 ///
44 /// This results in the reduce computation of the linalg operation.
45 ///
46 struct BubbleUpExtractSliceOpPattern
47     : OpRewritePattern<tensor::ExtractSliceOp> {
48   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
49 
matchAndRewrite__anonbb81f59e0111::BubbleUpExtractSliceOpPattern50   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
51                                 PatternRewriter &rewriter) const final {
52     Value source = sliceOp.getSource();
53     auto linalgOp = source.getDefiningOp<LinalgOp>();
54     if (!linalgOp) {
55       return rewriter.notifyMatchFailure(sliceOp,
56                                          "expected source to be linalg op");
57     }
58 
59     // TODO: we might relax this if we want heuristics to detect that all uses
60     // are small portion of the output.
61     if (!linalgOp->hasOneUse()) {
62       return rewriter.notifyMatchFailure(sliceOp,
63                                          "expected single use of linalg op");
64     }
65 
66     if (linalgOp.getNumOutputs() != 1) {
67       return rewriter.notifyMatchFailure(sliceOp,
68                                          "expected single output of linalg op");
69     }
70 
71     if (!linalgOp.hasTensorSemantics()) {
72       return rewriter.notifyMatchFailure(sliceOp,
73                                          "expected tensor of linalg op");
74     }
75 
76     if (!sliceOp.hasUnitStride())
77       return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
78 
79     if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
80       return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
81     }
82 
83     OpOperand *outOperand = linalgOp.getOutputOperand(0);
84     AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
85     if (!indexingMap.isProjectedPermutation()) {
86       return rewriter.notifyMatchFailure(
87           sliceOp, "expected a projected permutation for output");
88     }
89 
90     auto linalgLoc = linalgOp.getLoc();
91     auto allShapeSizes =
92         linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
93     AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
94     if (!shapeSizesToLoopsMap) {
95       return rewriter.notifyMatchFailure(
96           linalgOp, "failed to get loops map from shape sizes");
97     }
98     auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
99                                        shapeSizesToLoopsMap, allShapeSizes);
100 
101     auto sliceLoc = sliceOp.getLoc();
102     auto offsetVals = getValueOrCreateConstantIndexOp(
103         rewriter, sliceLoc, sliceOp.getMixedOffsets());
104     auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
105                                                     sliceOp.getMixedSizes());
106 
107     // The offsets and sizes from the slice operation only give you the tile
108     // size of the output. Use that compute the tile sizes and offsets of the
109     // loops. For loops not used to access the output, set the tile sizes to
110     // loop bounds and set the offset to 0.
111     Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
112     SmallVector<Value, 4> tileOffsets(sizeBounds.size(), zero);
113     SmallVector<Value, 4> tileSizes = sizeBounds;
114     for (auto const &result : enumerate(indexingMap.getResults())) {
115       unsigned position = result.value().cast<AffineDimExpr>().getPosition();
116       tileOffsets[position] = offsetVals[result.index()];
117       tileSizes[position] = sizeVals[result.index()];
118     }
119 
120     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
121 
122     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
123         rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
124         sizeBounds, /*omitPartialTileCheck=*/true);
125 
126     SmallVector<Type, 4> resultTensorTypes;
127     for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
128       resultTensorTypes.push_back(
129           tiledOperands[opOperand->getOperandNumber()].getType());
130 
131     Operation *newOp =
132         linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
133     rewriter.replaceOp(sliceOp, newOp->getResults());
134     return success();
135   }
136 };
137 } // namespace
138 
populateBubbleUpExtractSliceOpPatterns(RewritePatternSet & patterns)139 void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
140     RewritePatternSet &patterns) {
141   auto *context = patterns.getContext();
142   patterns.add<BubbleUpExtractSliceOpPattern>(context);
143 }
144