//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
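
// As an illustration (a sketch, not taken from a test), a filter constructed
// with match marker "tile" and replacement marker "tiled" only fires on:
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"} ...
//
// and rewrites the attribute to "tiled" afterwards, so staged pipelines can
// chain transformations without re-matching already-processed ops.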

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker and the matchDisjunction is empty (or matching by
    // default is enabled).
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no marker but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
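
// Illustrative use of the above (a sketch; the concrete sizes are
// assumptions, not taken from this file):
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 4});
//
// The constants are materialized at the start of the enclosing func::FuncOp,
// so the same SSA values can be reused across all ops tiled in that function.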

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
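
// For example (a sketch with hypothetical shapes), a linalg.matmul with
// operands tensor<?x32xf32> and tensor<32x?xf32> has loop ranges (?, ?, 32);
// the callback above then produces tile sizes (1, 1, 0): both dynamic loops
// are scalarized to size 1 and the static reduction loop is left untiled.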

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of the
  // `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    paddedShape[shapeIdx++] = upperBound.getValue();
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}
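
// Illustrative result (a sketch; names, shapes, and the bound are
// hypothetical): for an operand sliced out of a tensor<10x12xf32> with a
// dynamic size %sz whose constant upper bound is computed to be 4, the helper
// produces roughly:
//
//   %s = tensor.extract_slice %t[%i, 0] [%sz, 12] [1, 1]
//       : tensor<10x12xf32> to tensor<?x12xf32>
//   %p = tensor.pad %s low[0, 0] high[%h, 0] { tensor.yield %pad_value }
//       : tensor<?x12xf32> to tensor<4x12xf32>
//
// so the consuming op sees a static 4x12 bounding box instead of ?x12.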

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
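
// As an illustrative example (sketch), peeling
//
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
//
// splits it into a main loop from 0 to 8 that only runs full 4-wide
// iterations plus a separate partial iteration covering [8, 10), so the main
// loop body never has to deal with a partial tile.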

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}
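
// Typical registration (an illustrative sketch; the marker names and tile
// sizes are assumptions):
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTilingPattern>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 8, 4}),
//       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
//                                  StringAttr::get(ctx, "tiled")));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// The filter update above is what keeps the greedy driver from tiling the
// freshly tiled op again.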

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original, unpadded operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Update the filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}
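
// For example (a sketch), generalization rewrites a named op into an
// equivalent linalg.generic carrying the same indexing maps and iterator
// types:
//
//   %0 = linalg.matmul ins(%a, %b : ...) outs(%c : ...)
//
// becomes a linalg.generic with iterator_types = ["parallel", "parallel",
// "reduction"] and a body of arith.mulf / arith.addf, which downstream
// patterns can then manipulate uniformly.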

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern is creating other
  // ops, so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
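
// Illustrative effect (a sketch; shapes and values are hypothetical):
//
//   %0 = tensor.pad %src low[1, 1] high[1, 1] { tensor.yield %cst }
//       : tensor<2x2xf32> to tensor<4x4xf32>
//
// becomes roughly an init_tensor of the 4x4 result shape, a linalg.fill of
// %cst into it, and a linalg.generic that copies %src into the interior at
// offset (1, 1) via the shifted output indexing map built above.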

/// Fill `dest` using a FillOp with the constant padding value, if one exists.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized. Generate an InsertSliceOp instead for
  // copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
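
// In the simplest case (a sketch with hypothetical shapes), this swaps
//
//   %p = tensor.pad %t ... : tensor<?x?xf32> to tensor<8x8xf32>
//   %s = tensor.extract_slice %p[...] : tensor<8x8xf32> to tensor<4x4xf32>
//
// into pad(extract_slice(%t)), so tiling slices the unpadded data and pads
// each small tile instead; slices known to read only the padding region can
// be replaced by just the padding value when `zeroSliceGuard` permits.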

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
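
// Typical use (an illustrative sketch, assuming the declaration provides
// default filter and benefit arguments):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// After this runs, e.g. a linalg.conv_2d_nhwc_hwcf with a 1xKW kernel and
// unit output height is rewritten into a linalg.conv_1d_nwc_wcf on
// rank-reduced operands.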