//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker and the match disjunction is empty: match by
    // default.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no marker but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. The op's marker matches one of the expected filters.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The marker matches none of the expected filters.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}
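
// Example usage (a minimal sketch; `ctx` and `patterns` are assumed to be an
// MLIRContext and a RewritePatternSet set up by the caller): a filter that
// only matches ops tagged "tile" and re-tags the transformed op "tiled", so
// the pattern does not fire on the same op again:
//
//   LinalgTransformationFilter filter(
//       /*matchDisjunction=*/{StringAttr::get(ctx, "tile")},
//       /*replacement=*/StringAttr::get(ctx, "tiled"));
//   patterns.add<LinalgTilingPattern>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 8}), filter);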

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is static, do not tile it (tile size 0). Otherwise,
    // tile the dynamic dimension by 1.
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
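
// Example (a minimal sketch of configuring tiling options): tile the two
// outermost loops by 8 and 16 and leave any remaining loop untiled (tile
// size 0); `scalarizeDynamicDims` is the alternative policy that instead
// tiles every dynamic dimension by 1:
//
//   LinalgTilingOptions tileOptions;
//   tileOptions.setTileSizes({8, 16, 0});
//
//   LinalgTilingOptions scalarizeOptions;
//   scalarizeOptions.scalarizeDynamicDims();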

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds, or failure otherwise.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    paddedShape[shapeIdx++] = upperBound.getValue();
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
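
// Example (an illustrative sketch, not verbatim compiler output; the names
// and shapes are made up): padding a matmul whose LHS rows are dynamic but
// bounded by 4 yields a pad to the static bounding box, a matmul on the
// static shapes, and an extract_slice recovering the dynamic result:
//
//   %lhs_p = tensor.pad %lhs low[0, 0] high[%h, 0] {
//     ^bb0(%i: index, %j: index):
//       tensor.yield %cst : f32
//   } : tensor<?x8xf32> to tensor<4x8xf32>
//   %res = linalg.matmul ins(%lhs_p, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
//                        outs(%out_p : tensor<4x16xf32>) -> tensor<4x16xf32>
//   %dyn = tensor.extract_slice %res[0, 0] [%d, 16] [1, 1]
//       : tensor<4x16xf32> to tensor<?x16xf32>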

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize `loops`.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    (void)peelLoop(rewriter, loopOp);
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    Operation *loopOp = res.loops[loop];
    SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into an op, the indexing op view is always
    // an OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear the filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}
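
// Example usage (a minimal sketch; `funcOp` and `ctx` are assumed to come
// from the enclosing pass): tile every matching linalg op by 8x16 and mark
// the tiled ops so the pattern does not reapply:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTilingPattern>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 16}),
//       LinalgTransformationFilter(ArrayRef<StringAttr>{},
//                                  StringAttr::get(ctx, "tiled")));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));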

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Set the new filter on the transformed op, if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}
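
// Example (a sketch): with interchangeVector = {1, 0}, a 2-D linalg.generic
// has its two loops swapped; the indexing maps and iterator types are
// permuted accordingly so the computation is unchanged, e.g. an identity
// read affine_map<(d0, d1) -> (d0, d1)> becomes
// affine_map<(d0, d1) -> (d1, d0)>.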

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}
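
// Example (a sketch): generalization rewrites a named op such as
//
//   linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//                 outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
//
// into an equivalent linalg.generic with explicit indexing maps,
// parallel/reduction iterator types, and a scalar multiply-add body.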

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPeelOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Advance the transformation marker even if no loop is peeled for this op.
  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

  if (!options.loopsToPeelComputationFunction)
    return failure();

  SmallVector<scf::ForOp, 4> loopsToPeel;
  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
  peelLoops(rewriter, loopsToPeel);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
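
// Example usage (a minimal sketch; `funcOp`, `tilingPatterns`, and
// `canonicalizationPatterns` are assumed to be set up by the caller): run one
// stage-1 pattern set, interleave stage-2 canonicalizations, and verify the
// IR as the stage-3 action:
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   (void)applyStagedPatterns(funcOp, stage1, stage2,
//                             [](Operation *op) { return verify(op); });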

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy the contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
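
// Example (an illustrative sketch of the rewrite above; the names are made
// up): a static pad of a tensor<2x3xf32> by one element on every side
// becomes:
//
//   %init = linalg.init_tensor [4, 5] : tensor<4x5xf32>
//   %fill = linalg.fill ins(%pad_value : f32)
//               outs(%init : tensor<4x5xf32>) -> tensor<4x5xf32>
//   %copy = linalg.generic ... ins(%src : tensor<2x3xf32>)
//               outs(%fill : tensor<4x5xf32>) ...
//
// where the output indexing map offsets each dimension by the low padding,
// i.e. affine_map<(d0, d1) -> (d0 + 1, d1 + 1)>.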

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
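
// Example (a sketch): the swap pattern rewrites
//
//   %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<8x8xf32>
//   %1 = tensor.extract_slice %0[...] [...] [1, 1]
//
// into an extract_slice of %src followed by a (smaller) tensor.pad of that
// slice. When `zeroSliceGuard` is set, the rewrite additionally guards the
// case where the requested slice reads no source data at all.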

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

FailureOr<Conv1DNwcWcfOp>
DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}
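
// Example (a sketch): a conv_2d_nhwc_hwcf whose kernel height and output
// height are both 1, e.g.
//
//   linalg.conv_2d_nhwc_hwcf
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x2x3x8xf32>)
//     outs(%out : tensor<1x1x9x8xf32>) -> tensor<1x1x9x8xf32>
//
// becomes a linalg.conv_1d_nwc_wcf on rank-reduced extract_slice'd operands,
// whose result is insert_slice'd back into the original output shape.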

FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
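
// Example usage (a minimal sketch; `funcOp` and `ctx` are assumed to come
// from the enclosing pass):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(
//       patterns, LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));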