//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}
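
// A minimal usage sketch (marker names and `ctx` hypothetical): a filter that
// only matches ops tagged "tile" and re-tags them "tiled" once a pattern
// fires, so the same pattern cannot apply to the result again:
//
//   LinalgTransformationFilter filter(
//       /*matchDisjunction=*/{StringAttr::get(ctx, "tile")},
//       /*replacement=*/StringAttr::get(ctx, "tiled"));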

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no filter attribute and the matchDisjunction is empty:
    // match by default.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no filter attribute but one was expected: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
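
// A minimal usage sketch (sizes hypothetical): tile the first two loops by
// 8 and 16 and leave the third loop untiled (a tile size of 0 means "do not
// tile this loop"). The sizes are materialized as index constants at the top
// of the enclosing function when the pattern runs:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0});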

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
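
// For example (a sketch): a linalg.matmul whose loop ranges are (?, ?, 4)
// gets the tile sizes (1, 1, 0), i.e. both dynamic loops are scalarized to a
// single iteration and the static reduction loop is left untiled.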

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, try to pad the operand even if it
/// already has a static shape. Set `result` to the result of the created
/// PadTensorOp and return success if the operand either has been padded to a
/// static shape or already had a static shape; return failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}
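
// For example (an IR sketch, values hypothetical): a tile produced by
//
//   %slice = tensor.extract_slice %t[0, 0] [%sz0, %sz1] [1, 1]
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// whose sizes %sz0 and %sz1 have constant upper bounds 8 and 16 is padded
// high (with the value %cst returned by `paddingFunc`) to the static
// bounding box:
//
//   %padded = linalg.pad_tensor %slice low[0, 0] high[%h0, %h1] {
//     ^bb0(%i: index, %j: index):
//       linalg.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<8x16xf32>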

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands, replacing any that get padded.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
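
// A minimal usage sketch (callback body hypothetical): pad every operand
// with a zero of its element type and capture the padded clone:
//
//   auto paddingFunc =
//       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
//     Type t = getElementTypeOrSelf(opOperand.get());
//     Value zero = b.create<arith::ConstantOp>(
//         opOperand.getOwner()->getLoc(), b.getZeroAttr(t));
//     return zero;
//   };
//   LinalgOp paddedOp;
//   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
//       b, linalgOp, paddingFunc, /*nofoldFunc=*/nullptr, paddedOp);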

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
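
// A minimal configuration sketch (indices hypothetical, direct field
// assignment assumed to be available): tile by 8 and 16 and peel the partial
// iterations of both generated loops:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});
//   options.peeledLoops = {0, 1};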

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}
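
// A minimal configuration sketch (callbacks hypothetical; `paddingFunc` as
// sketched after rewriteAsPaddedOp above): pad every operand, mark all pad
// ops nofold, and hoist each pad op out of two enclosing loops:
//
//   LinalgPaddingOptions options;
//   options.paddingValueComputationFunction = paddingFunc;
//   options.paddingNoFoldComputationFunction =
//       [](OpOperand &) { return true; };
//   options.paddingHoistComputationFunction =
//       [](OpOperand &) -> int64_t { return 2; };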

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}
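
// A minimal configuration sketch (sizes hypothetical): tile the three loops
// of a matmul root op by (8, 16, 4) and interchange the tile loops into
// (j, i, k) order:
//
//   LinalgTilingAndFusionOptions options;
//   options.tileSizes = {8, 16, 4};
//   options.tileInterchange = {1, 0, 2};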

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Apply the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
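
// For example (a sketch): with interchangeVector = {1, 0}, the two loops of
// a generic op are swapped; the iterator types and indexing maps are
// permuted to match, so the op computes the same result.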

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
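
// A minimal usage sketch (pattern sets and `funcOp` hypothetical): run each
// stage-1 pattern set to convergence, interleaved with a fixed stage-2
// canonicalization set, then a final callback:
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(
//           funcOp, stage1, stage2,
//           [](Operation *op) { return success(); })))
//     signalPassFailure();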

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the padding value) and GenericOp (to copy the contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
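
// For example (an IR sketch, names hypothetical): a statically shaped pad
//
//   %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
//     ^bb0(%i: index, %j: index):
//       linalg.yield %cst : f32
//   } : tensor<2x3xf32> to tensor<5x8xf32>
//
// becomes an init_tensor/fill pair plus a generic op whose output indexing
// map shifts the indices by the low padding (1, 2):
//
//   %init = linalg.init_tensor [5, 8] : tensor<5x8xf32>
//   %fill = linalg.fill(%cst, %init) : f32, tensor<5x8xf32> -> tensor<5x8xf32>
//   %0 = linalg.generic ... ins(%arg0 : tensor<2x3xf32>)
//       outs(%fill : tensor<5x8xf32>) ...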

/// Fill `dest` with the constant padding value of `padOp` using a FillOp, if
/// possible. Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with a region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1 && "expected single yield value");
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
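
// Illustrative outcomes of createFillOrGenerateOp (sketched IR): a constant
// padding value lowers to
//   %0 = linalg.fill(%cst, %dest) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// while a non-constant padding region lowers to
//   %0 = tensor.generate %dynSizes { ... tensor.yield %pad : f32 }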

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the sizes of the InitTensorOp. Any combination of static/dynamic
  // sizes is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Create the init tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized: generate an InsertSliceOp instead to
  // copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the sizes of the PadTensorOp source.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
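
// Illustrative effect of the generalization above when the copy cannot be
// optimized (sketched IR):
//
//   %0 = linalg.pad_tensor %src low[%l0, %l1] high[%h0, %h1] { yield %cst }
//
// becomes, roughly:
//
//   %init = linalg.init_tensor [...]   // padded result sizes
//   %fill = linalg.fill(%cst, %init)
//   %0 = tensor.insert_slice %src into %fill[%l0, %l1] [srcSizes] [1, 1]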

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit strides are supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp =
      padOp
          .getTiledImplementation(
              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
          .front();
  // Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)),
  // using the pad op's tiled implementation.
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
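
// Illustrative effect of the swap (sketched IR):
//
//   %0 = linalg.pad_tensor %src ... : tensor<4x4xf32> to tensor<10x10xf32>
//   %1 = tensor.extract_slice %0[%o0, %o1] [%s0, %s1] [1, 1]
//
// is rewritten so that the slice is taken from %src first and only the
// padding that overlaps the slice is materialized around it.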

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformationFilter handles special attribute manipulations.
  LinalgTransformationFilter filter;
};
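
// Illustrative effect (sketched IR): with a 1x3 kernel and size-1 output
// height,
//
//   linalg.conv_2d_nhwc_hwcf
//       ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x3x3x8xf32>) ...
//
// is rewritten to tensor.extract_slice + linalg.conv_1d_nwc_wcf +
// tensor.insert_slice, with the size-1 h/kh dimensions dropped.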

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(filter) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformationFilter handles special attribute manipulations.
  LinalgTransformationFilter filter;
};
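
// Analogous illustrative effect: when the (h, kh) or (w, kw) pair has size 1,
// linalg.depthwise_conv_2d_nhwc_hwc collapses to
// linalg.depthwise_conv_1d_nwc_wc acting on rank-reduced slices of the
// operands.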

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, LinalgTransformationFilter filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
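
// Minimal usage sketch (illustrative; `funcOp` and `ctx` are assumed to be
// provided by the caller, and the filter/benefit arguments take their
// declared defaults):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));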