1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
25 #include "mlir/Dialect/Utils/StaticValueUtils.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Matchers.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Support/LLVM.h"
32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
38 #include <utility>
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 using namespace mlir;
43 using namespace mlir::linalg;
44 
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 
47 //===----------------------------------------------------------------------===//
48 // Transformations exposed as rewrite patterns.
49 //===----------------------------------------------------------------------===//
50 // Marker used as attribute name in generated Linalg rewriting transformations.
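// For illustration only (not part of this file's logic), a pass might attach
// the marker to an op and key later rewrites on it, e.g.:
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
//          ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// A LinalgTransformationFilter built with match value "tile" and replacement
// "tiled" accepts this op and, after a successful rewrite, swaps the marker so
// the same pattern does not fire on the op again.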
51 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
52     "__internal_linalg_transform__";
53 
54 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
55     ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
56     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
57       replacement(replacement), matchByDefault(false) {}
58 
59 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
60     const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
61     Optional<StringAttr> replacement)
62     : filters(),
63       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
64       replacement(replacement), matchByDefault(false) {
65   if (f)
66     filters.push_back(f);
67 }
68 
69 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
70     PatternRewriter &rewriter, Operation *op) const {
71   if (llvm::any_of(filters,
72                    [&](const FilterFunction &f) { return failed(f(op)); }))
73     return failure();
74 
75   auto attr = op->template getAttrOfType<StringAttr>(
76       LinalgTransforms::kLinalgTransformMarker);
77 
78   if (!attr) {
79     // 1. The op has no marker and none is required (or we match by default).
80     if (matchDisjunction.empty() || matchByDefault)
81       return success();
82 
83     // 2. The op has no marker but one was expected: fail.
84     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
85       diag << " does not have any filter from list: ";
86       interleaveComma(matchDisjunction, diag);
87     });
88   }
89 
90   // 3. The marker matches one of the expected values: match.
91   for (auto filter : matchDisjunction)
92     if (attr.getValue() == filter)
93       return success();
94 
95   // 4. The marker matches none of the expected values: fail.
96   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
97     diag << " does not have any filter from list: ";
98     interleaveComma(matchDisjunction, diag);
99   });
100 }
101 
102 void mlir::linalg::LinalgTransformationFilter::
103     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
104                                       Operation *op) const {
105   if (replacement)
106     op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
107   else
108     op->removeAttr(
109         rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
110 }
111 
112 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
113     Operation *op) const {
114   if (!replacement)
115     return false;
116   auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
117                   .dyn_cast<StringAttr>();
118   return attr && attr == *replacement;
119 }
120 
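// Illustrative usage (a sketch, not taken from this file): fixed tile sizes
// can be requested with
//   LinalgTilingOptions().setTileSizes({8, 16, 32});
// The sizes are materialized lazily as arith::ConstantIndexOp values at the
// start of the enclosing func::FuncOp when the tiling is applied.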
121 LinalgTilingOptions &
122 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
123   assert(!tileSizeComputationFunction && "tile sizes already set");
124   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
125   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
126     OpBuilder::InsertionGuard guard(b);
127     b.setInsertionPointToStart(
128         &op->getParentOfType<func::FuncOp>().getBody().front());
129     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
130       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
131       return v;
132     }));
133   };
134   return *this;
135 }
136 
137 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
138   assert(!tileSizeComputationFunction && "tile sizes already set");
139   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
140     SmallVector<Value, 4> tileSizes;
141     auto linalgOp = dyn_cast<LinalgOp>(op);
142     if (!linalgOp)
143       return tileSizes;
144     Location loc = linalgOp.getLoc();
145     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
146     AffineMap map = linalgOp.getShapesToLoopsMap();
147     if (!map)
148       return tileSizes;
149     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
150     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
151     // size 0).
152     for (Value shapeSize : shapeSizes)
153       tileSizes.push_back(getConstantIntValue(shapeSize)
154                               ? b.create<arith::ConstantIndexOp>(loc, 0)
155                               : b.create<arith::ConstantIndexOp>(loc, 1));
156     return tileSizes;
157   };
158   return *this;
159 }
160 
161 /// Pad the `opOperand` in the `paddingDimensions` using the padding value and
162 /// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
163 /// Exit early and return the `opOperand` value if the shape dimensions that
164 /// match `paddingDimensions` have a static size and the nofold flag is not set.
165 /// Otherwise, try to pad the shape dimensions that match the iterator
166 /// dimensions `paddingDimensions` and return the tensor::PadOp result if
167 /// padding succeeds, or failure otherwise.
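///
/// Illustrative sketch (assumed shapes, not from a test): if the operand is
/// produced by a tensor.extract_slice with a partially dynamic size, e.g.
///   %s = tensor.extract_slice %t[0, 0] [%sz, 16] [1, 1]
///        : tensor<64x16xf32> to tensor<?x16xf32>
/// and a constant upper bound of 64 can be derived for %sz, the operand is
/// padded with tensor.pad to the static bounding box tensor<64x16xf32> using
/// the value supplied in `paddingValues`.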
168 static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
169     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
170     ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
171     ArrayRef<bool> packPaddings) {
172   AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
173   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
174 
175   // Collect the shape dimensions that are a function of the `paddingDimensions`.
176   llvm::SmallDenseSet<int64_t> shapeDimsToPad;
177   for (int64_t dim : paddingDimensions)
178     for (const auto &en : enumerate(indexingMap.getResults()))
179       if (en.value().isFunctionOfDim(dim))
180         shapeDimsToPad.insert(en.index());
181 
182   // Return the unpadded operand if padding to a static shape is not needed and
183   // if the nofold flag is not set.
184   bool nofold = opOperand->getOperandNumber() < packPaddings.size()
185                     ? packPaddings[opOperand->getOperandNumber()]
186                     : false;
187   bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
188     return ShapedType::isDynamic(shape[dim]);
189   });
190   if (!nofold && hasStaticShape)
191     return opOperand->get();
192 
193   // Fail if `paddingValues` specifies no padding value.
194   if (opOperand->getOperandNumber() >= paddingValues.size())
195     return failure();
196   Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
197   Value paddingValue = b.create<arith::ConstantOp>(
198       opToPad.getLoc(), paddingAttr.getType(), paddingAttr);
199 
200   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
201   OpOperand *currOpOperand = opOperand;
202   while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
203     OpResult result = currOpOperand->get().cast<OpResult>();
204     currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
205   }
206 
207   // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
208   auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
209   if (!sliceOp)
210     return failure();
211 
212   // Compute the dropped dimensions if `sliceOp` is rank-reducing.
213   llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
214   OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
215 
216   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
217   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
218   int64_t shapeIdx = 0;
219   for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
220     // Skip dropped dimensions.
221     if (droppedDims.test(en.index()))
222       continue;
223     // Skip dimensions that do not require padding.
224     if (!shapeDimsToPad.contains(shapeIdx)) {
225       shapeIdx++;
226       continue;
227     }
228     // If the size is an attribute, add it directly to `paddedShape`.
229     if (en.value().is<Attribute>()) {
230       paddedShape[shapeIdx++] =
231           en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
232       continue;
233     }
234     // Otherwise, try to compute a constant upper bound for the size value.
235     FailureOr<int64_t> upperBound =
236         getConstantUpperBoundForIndex(en.value().get<Value>());
237     if (failed(upperBound)) {
238       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
239       return failure();
240     }
241     paddedShape[shapeIdx++] = *upperBound;
242   }
243   assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
244          "expect the dynamic and static ranks to match");
245 
246   // Pad the operand to the bounding box defined by `paddedShape`.
247   auto paddedTensorType = RankedTensorType::get(
248       paddedShape, getElementTypeOrSelf(opOperand->get()));
249   return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
250                                opOperand->get(), paddingValue, nofold);
251 }
252 
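// Note (illustrative, assuming a single tensor<?x?xf32> result padded to
// tensor<64x128xf32>): the values returned below are tensor.extract_slice ops
// that carve the original result sizes back out of the padded results, e.g.
//   %padded = linalg.matmul ... -> tensor<64x128xf32>
//   %res = tensor.extract_slice %padded[0, 0] [%d0, %d1] [1, 1]
//          : tensor<64x128xf32> to tensor<?x?xf32>
// so all uses of the original op can be replaced without changing their types.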
253 FailureOr<SmallVector<Value>>
254 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
255                           ArrayRef<int64_t> paddingDimensions,
256                           ArrayRef<Attribute> paddingValues,
257                           ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
258   Location loc = opToPad->getLoc();
259 
260   // TODO: there are cases where we may still want to pad to larger sizes.
261   assert(opToPad.hasTensorSemantics() &&
262          "expected operation to have tensor semantics");
263 
264   OpBuilder::InsertionGuard g(b);
265   // Set IP after op because we also take the dims of the original output.
266   b.setInsertionPointAfter(opToPad);
267   // Make a copy of the shaped operands and update it.
268   SmallVector<Value> newOperands;
269   newOperands.reserve(opToPad.getNumInputsAndOutputs());
270   for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
271     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
272         b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
273     // Exit if `paddingDimensions` cannot be bounded statically.
274     if (failed(paddedOperand))
275       return failure();
276     newOperands.push_back(*paddedOperand);
277   }
278 
279   SmallVector<SmallVector<Value>> reifiedResultShapes;
280   if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
281                  .reifyResultShapes(b, reifiedResultShapes)))
282     return failure();
283   assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
284          "expected same number of results");
285 
286   // Clone `opToPad` to operate on the statically padded shapes.
287   auto resultTensorTypes =
288       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
289   paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
290 
291   // Recover the slice out of the new static results. This keeps the original
292   // linalg op around because it uses the dims of the original results.
293   SmallVector<Value> paddedSubviewResults;
294   paddedSubviewResults.reserve(opToPad->getNumResults());
295   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
296     Value paddedResult = en.value();
297     int64_t resultNumber = en.index();
298     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
299     SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
300     SmallVector<OpFoldResult> sizes;
301     for (Value v : reifiedResultShapes[resultNumber])
302       sizes.push_back(getAsOpFoldResult(v));
303     SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
304     paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
305         loc, paddedResult, offsets, sizes, strides));
306   }
307   return paddedSubviewResults;
308 }
309 
310 /// Try to peel a loop `op` and return the new result.
311 // TODO: Add support for scf.parallel and affine.for loops.
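// Illustrative effect (assumed bounds): peeling
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
// splits it into a main loop covering iterations [0, 8) with full steps and a
// separate partial iteration covering [8, 10), so the main body no longer has
// to guard against an incomplete last iteration.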
312 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
313   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
314       .Case<scf::ForOp>([&](scf::ForOp forOp) {
315         scf::ForOp partialIteration;
316         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
317                                                       partialIteration)))
318           return partialIteration->getResults();
319         assert(!partialIteration && "expected that loop was not peeled");
320         return forOp->getResults();
321       })
322       .Default([&](Operation *op) { return op->getResults(); });
323 }
324 
325 /// Peel and canonicalize `loops`.
326 void mlir::linalg::peelLoops(RewriterBase &rewriter,
327                              ArrayRef<scf::ForOp> loops) {
328   for (auto loopOp : loops)
329     (void)peelLoop(rewriter, loopOp);
332 }
333 
334 /// Peel loops after tiling.
335 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
336                                      ArrayRef<int64_t> peeledLoops,
337                                      LinalgTilingLoopType loopType) {
338   for (int64_t loop : peeledLoops) {
339     assert(loop < static_cast<int64_t>(res.loops.size()) &&
340            "requested peeling of non-existing loop");
341     Operation *loopOp = res.loops[loop];
342     SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);
344 
345     // The result of the loop nest may change with peeling.
346     if (res.tensorResults.size() == loopOp->getNumResults() &&
347         std::equal(res.tensorResults.begin(), res.tensorResults.end(),
348                    loopOp->getResults().begin()))
349       res.tensorResults = loopResults;
350   }
351 }
352 
353 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
354   if (tiledOp.loops.empty())
355     return tiledOp.op.getOperation()->getResults();
356   return tiledOp.loops.front()->getResults();
357 }
358 
359 static ValueRange
360 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
361   if (tiledAndFusedOp.fusedLoops.empty())
362     return tiledAndFusedOp.op.getOperation()->getResults();
363   return tiledAndFusedOp.fusedLoops.front()->getResults();
364 }
365 
366 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
367     StringRef opName, MLIRContext *context,
368     const LinalgDependenceGraph &dependenceGraph,
369     LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
370     LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
371     LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
372     : RewritePattern(opName, benefit, context, {}),
373       dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
374       fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
375       fusedOpMarker(std::move(fusedOpMarker)),
376       originalOpMarker(std::move(originalOpMarker)) {}
377 
378 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
379     Operation *op, PatternRewriter &rewriter) const {
380   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
381   // TODO: remove hasIndexSemantics check once index ops are supported.
382   if (!linalgOp || linalgOp.hasIndexSemantics())
383     return failure();
384   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
385     return failure();
386 
387   DenseSet<Operation *> producers;
388   producers.insert(linalgOp);
389   for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
390     Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
391     // For dependences into an op, the indexing view is always an OpOperand.
392     // We could assert, but conservatively continue if this is not the case.
393     if (!operandNumber)
394       continue;
395     if (!fusionOptions.indicesToFuse.count(*operandNumber))
396       continue;
397     if (isa<LinalgOp>(dependence.getDependentOp()))
398       producers.insert(dependence.getDependentOp());
399   }
400 
401   SmallVector<LinalgOp, 1> fusionOps;
402   for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
403        ++it) {
404     auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
405     if (producerLinalgOp && producers.count(producerLinalgOp))
406       fusionOps.push_back(producerLinalgOp);
407   }
408   fusionOps.push_back(linalgOp);
409 
410   SmallVector<Value, 4> tileSizes =
411       tilingOptions.tileSizeComputationFunction(rewriter, op);
412   LinalgTilingOptions instanceTilingOptions = tilingOptions;
413   instanceTilingOptions.setTileSizes(tileSizes);
414   Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
415       rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
416   if (!tiledAndFusedOps)
417     return failure();
418 
419   // Tile the unfused loops.
420   SmallVector<Value, 4> unfusedLoopTileSizes;
421   Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
422   for (const auto &tileSize : enumerate(tileSizes)) {
423     if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
424       unfusedLoopTileSizes.push_back(zero);
425     else
426       unfusedLoopTileSizes.push_back(tileSize.value());
427   }
428   if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
429     unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
430   // Tile the remaining loops only if at least one tile size is non-zero.
431   if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
432         if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
433           return cst.value() != 0;
434         return true;
435       })) {
436     LinalgTilingOptions unfusedTilingOptions = tilingOptions;
437     unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
438     FailureOr<TiledLinalgOp> unfusedTiledOp =
439         tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
440     if (failed(unfusedTiledOp))
441       return failure();
442     rewriter.replaceOp(tiledAndFusedOps->op,
443                        getTiledOpResult(unfusedTiledOp.value()));
444     tiledAndFusedOps->op = unfusedTiledOp->op;
445   }
446   op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.value()));
447 
448   filter.replaceLinalgTransformationFilter(rewriter,
449                                            tiledAndFusedOps->op.getOperation());
450   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
451     fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
452                                                     fusedOp.getOperation());
453   }
454   for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
455     originalOpMarker.replaceLinalgTransformationFilter(
456         rewriter, origProducerOp.getOperation());
457   }
458   rewriter.updateRootInPlace(op, [&]() {
459     originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
460   });
461   return success();
462 }
463 
464 /// Linalg tiling pattern.
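/// Illustrative registration (a sketch, not from this file; `ctx` and
/// `funcOp` are assumed to be in scope):
///   RewritePatternSet patterns(ctx);
///   patterns.add<LinalgTilingPattern>(
///       ctx, LinalgTilingOptions().setTileSizes({8, 16}),
///       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
///                                  StringAttr::get(ctx, "tiled")));
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));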
465 mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
466     MLIRContext *context, LinalgTilingOptions options,
467     LinalgTransformationFilter f, PatternBenefit benefit)
468     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
469       filter(std::move(f)), options(std::move(options)) {}
470 
471 mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
472     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
473     LinalgTransformationFilter f, PatternBenefit benefit)
474     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
475       filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
476 
477 FailureOr<TiledLinalgOp>
478 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
479     LinalgOp op, PatternRewriter &rewriter) const {
480   if (failed(filter.checkAndNotify(rewriter, op)))
481     return failure();
482 
483   FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
484   if (failed(res))
485     return failure();
486 
487   // Clear filter to stop recursive pattern application.
488   // This must be done here to properly propagate to peeling branches.
489   filter.replaceLinalgTransformationFilter(rewriter, res->op);
490 
491   // Peel the loops of the TiledLinalgOp.
492   peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
493 
494   if (res->tensorResults.empty())
495     rewriter.eraseOp(op);
496   else
497     rewriter.replaceOp(op, res->tensorResults);
498 
499   return res;
500 }
501 
502 /// Linalg padding pattern.
503 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
504     MLIRContext *context, LinalgPaddingOptions options,
505     LinalgTransformationFilter f, PatternBenefit benefit)
506     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
507       filter(std::move(f)), options(std::move(options)) {}
508 
509 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
510     StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
511     LinalgTransformationFilter f, PatternBenefit benefit)
512     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
513       filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
514 
515 FailureOr<LinalgOp>
516 mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
517     LinalgOp linalgOp, PatternRewriter &rewriter) const {
518   if (!linalgOp.hasTensorSemantics())
519     return failure();
520   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
521     return failure();
522 
523   // Pad the operation.
524   LinalgOp paddedOp;
525   FailureOr<SmallVector<Value>> newResults =
526       rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
527                         options.paddingValues, options.packPaddings, paddedOp);
528   if (failed(newResults))
529     return failure();
530 
531   // Hoist the padding.
532   for (const auto &en : enumerate(options.hoistPaddings)) {
533     if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
534       break;
535     OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
536     auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
537     if (!padOp || en.value() == 0)
538       continue;
539 
540     // Fail hoisting if the operand shape is not fully static.
541     if (llvm::any_of(paddedOp.getShape(opOperand),
542                      [](int64_t size) { return ShapedType::isDynamic(size); }))
543       return failure();
544 
545     tensor::PadOp hoistedOp;
546     SmallVector<GenericOp> transposeOps;
547     SmallVector<int64_t> transposeVector =
548         en.index() < options.transposePaddings.size()
549             ? options.transposePaddings[en.index()]
550             : SmallVector<int64_t>{};
551 
552     FailureOr<Value> newResult = hoistPaddingOnTensors(
553         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
554     if (failed(newResult))
555       continue;
556     rewriter.replaceOp(padOp, *newResult);
557 
558     // Do not apply hoist padding to the newly introduced transpose operations.
559     for (GenericOp transposeOp : transposeOps)
560       filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
561   }
562 
563   // Replace the original operation to pad.
564   rewriter.replaceOp(linalgOp, *newResults);
565   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
566 
567   return paddedOp;
568 }
569 
570 /// Linalg tile and fuse tensor ops pattern.
571 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
572     LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
573                                       LinalgTilingAndFusionOptions options,
574                                       LinalgTransformationFilter f,
575                                       PatternBenefit benefit)
576     : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
577       filter(std::move(f)), options(std::move(options)) {}
578 
579 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
580     LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
581                                       LinalgTilingAndFusionOptions options,
582                                       LinalgTransformationFilter f,
583                                       PatternBenefit benefit)
584     : RewritePattern(opName, benefit, context), filter(std::move(f)),
585       options(std::move(options)) {}
586 
587 FailureOr<mlir::linalg::TileLoopNest>
588 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
589     Operation *op, PatternRewriter &rewriter) const {
590   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
591   if (!rootOp)
592     return failure();
593   if (failed(filter.checkAndNotify(rewriter, op)))
594     return failure();
595 
596   // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
597   if (options.tileSizes.size() < rootOp.getNumLoops())
598     return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
599 
600   // Check `tileInterchange` contains no entries or as many as `tileSizes`.
601   if (!options.tileInterchange.empty() &&
602       options.tileInterchange.size() != options.tileSizes.size())
603     return rewriter.notifyMatchFailure(
604         op, "expect the number of tile sizes and interchange dims to match");
605 
606   // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
607   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
608                                      options.tileSizes.begin() +
609                                          rootOp.getNumLoops());
610   SmallVector<int64_t> rootInterchange =
611       options.tileInterchange.empty()
612           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
613           : SmallVector<int64_t>(options.tileInterchange.begin(),
614                                  options.tileInterchange.begin() +
615                                      rootOp.getNumLoops());
616 
617   // Check `rootTileSizes` contains non-zero tile sizes.
618   if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
619     return rewriter.notifyMatchFailure(
620         op, "expect at least one non-zero tile size");
621 
622   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
623   // It has to be a permutation since the tiling cannot tile the same loop
624   // dimension multiple times.
625   if (!isPermutation(rootInterchange))
626     return rewriter.notifyMatchFailure(
627         op, "expect the tile interchange permutes the root loops");
628 
629   // Tile `rootOp` and fuse its producers.
630   FailureOr<TileLoopNest> tileLoopNest =
631       tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
632                                    rootInterchange, options.tileDistribution);
633   if (failed(tileLoopNest))
634     return rewriter.notifyMatchFailure(
635         op, "tileConsumerAndFuseProducers failed unexpectedly");
636 
637   // Replace all uses of the tiled loop operation.
638   rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
639 
640   // Apply the filter if specified.
641   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
642     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
643   return tileLoopNest;
644 }
645 
646 /// Linalg generic interchange pattern.
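/// Illustrative effect (assumed two-loop generic): with interchangeVector =
/// {1, 0}, a linalg.generic iterating over (i, j) is rewritten to iterate over
/// (j, i); the indexing maps are permuted accordingly, so the op's semantics
/// are unchanged.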
647 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
648     MLIRContext *context, ArrayRef<unsigned> interchangeVector,
649     LinalgTransformationFilter f, PatternBenefit benefit)
650     : OpRewritePattern(context, benefit), filter(std::move(f)),
651       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
652 
653 FailureOr<GenericOp>
654 mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
655     GenericOp genericOp, PatternRewriter &rewriter) const {
656   if (failed(filter.checkAndNotify(rewriter, genericOp)))
657     return failure();
658 
659   FailureOr<GenericOp> transformedOp =
660       interchangeGenericOp(rewriter, genericOp, interchangeVector);
661   if (failed(transformedOp))
662     return failure();
663 
664   // Update the transformation marker (or drop it if no replacement is set).
665   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
666   return transformedOp;
667 }
668 
669 /// Linalg generalization pattern.
670 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
671     MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
672     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
673       filter(std::move(f)) {}
674 
675 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
676     StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
677     PatternBenefit benefit)
678     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
679       filter(f.addOpNameFilter(opName)) {}
680 
681 FailureOr<GenericOp>
682 mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
683     LinalgOp linalgOp, PatternRewriter &rewriter) const {
684   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
685     return failure();
686   FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
687   if (failed(genericOp))
688     return failure();
689   filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
690   return genericOp;
691 }
692 
693 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
694     MLIRContext *context, LinalgTransformationFilter f,
695     LinalgPromotionOptions options, PatternBenefit benefit)
696     : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
697       filter(std::move(f)), options(std::move(options)) {}
698 
699 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
700     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
701     LinalgTransformationFilter f, PatternBenefit benefit)
702     : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
703       options(std::move(options)) {}
704 
705 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
706     Operation *op, PatternRewriter &rewriter) const {
707   if (failed(filter.checkAndNotify(rewriter, op)))
708     return failure();
709   if (failed(promoteSubviewsPrecondition(op, options)))
710     return failure();
711 
712   // TODO: We cannot use root update here. This pattern is creating other ops,
713   // so if the promotion fails, those need to be cleaned up, which doesn't seem
714   // to be happening here. So to fail properly, we should be cloning the op and
715   // deleting the previous op. This needs more investigation.
716   rewriter.startRootUpdate(op);
717   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
718   if (!promotedOp) {
719     rewriter.cancelRootUpdate(op);
720     return op->emitError("subview promotion failed");
721   }
722   rewriter.finalizeRootUpdate(op);
723   filter.replaceLinalgTransformationFilter(rewriter, op);
724   return success();
725 }
726 
727 mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
728     MLIRContext *context, LinalgTransformationFilter f,
729     LinalgPeelOptions options, PatternBenefit benefit)
730     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
731       filter(std::move(f)), options(std::move(options)) {}
732 
733 mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
734     StringRef opName, MLIRContext *context, LinalgPeelOptions options,
735     LinalgTransformationFilter f, PatternBenefit benefit)
736     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
737       filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
738 
739 LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
740     LinalgOp linalgOp, PatternRewriter &rewriter) const {
741   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
742     return failure();
743 
744   // Update the transformation marker even if no peeling happens for this op.
745   filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
746 
747   if (!options.loopsToPeelComputationFunction)
748     return failure();
749 
750   SmallVector<scf::ForOp, 4> loopsToPeel;
751   options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
752   peelLoops(rewriter, loopsToPeel);
753   return success();
754 }
755 
756 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
757     MLIRContext *context, LinalgTransformationFilter f,
758     LinalgVectorizationOptions options, PatternBenefit benefit)
759     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
760       filter(std::move(f)) {}
761 
762 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
763     StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
764     LinalgTransformationFilter f, PatternBenefit benefit)
765     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
766       filter(f.addOpNameFilter(opName)) {}
767 
768 LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
769     LinalgOp linalgOp, PatternRewriter &rewriter) const {
770   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
771     return failure();
772   return vectorize(rewriter, linalgOp);
773 }
774 
775 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
776     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
777   return vectorizeCopy(rewriter, copyOp);
778 }
779 
780 LogicalResult mlir::linalg::applyStagedPatterns(
781     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
782     const FrozenRewritePatternSet &stage2Patterns,
783     function_ref<LogicalResult(Operation *)> stage3Lambda) {
784   unsigned iteration = 0;
785   (void)iteration;
786   for (const auto &patterns : stage1Patterns) {
787     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
788                       << *op);
789     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
790       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
791       return failure();
792     }
793     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
794                       << *op);
795     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
796       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
797       return failure();
798     }
799     LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
800                       << *op);
801     if (stage3Lambda) {
802       if (failed(stage3Lambda(op)))
803         return failure();
804       LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
805                         << *op);
806     }
807   }
808   return success();
809 }
810 
811 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
812   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
813 }
814 
815 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
816 /// initialize with pad_val) and GenericOp (to copy contents).
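/// Illustrative sketch (assumed static types):
///   %0 = tensor.pad %src low[1, 1] high[1, 1]
///        { ^bb0(%i: index, %j: index): tensor.yield %cst : f32 }
///        : tensor<2x2xf32> to tensor<4x4xf32>
/// becomes roughly
///   %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
///   %0    = linalg.generic ...  // copies %src into the interior of %fill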
817 LogicalResult
818 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
819                                             PatternRewriter &rewriter) const {
820 
821   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
822   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
823 
824   // Bail on non-static shapes.
825   if (!inputShapedType.hasStaticShape())
826     return failure();
827   if (!resultShapedType.hasStaticShape())
828     return failure();
829 
830   // Only support padding with a constant for now, i.e. either:
831   //   1. A BBarg from a different block.
832   //   2. A value defined outside of the current block.
833   Block &block = padOp.region().front();
834   auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
835   Value padValue = yieldOp.value();
836   Operation *definingOp = padValue.getDefiningOp();
837   if (definingOp && definingOp->getBlock() == &block)
838     return failure();
839   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
840     return failure();
841 
842   // Create a tensor with the padded shape.
843   Location loc = padOp.getLoc();
846   Value initTensor = rewriter.create<InitTensorOp>(
847       loc, resultShapedType.getShape(), resultShapedType.getElementType());
848 
849   // Initialize tensor with the pad value
850   Value tmpTensor = rewriter
851                         .create<linalg::FillOp>(loc, ValueRange{padValue},
852                                                 ValueRange{initTensor})
853                         .result();
854 
855   // Copy original contents into new tensor
856   // Uses linalg.generic, but could be done with tensor.insert_slice
857   SmallVector<AffineExpr, 4> outputExprs;
858   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
859     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
860                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
861   }
862 
863   SmallVector<AffineMap, 2> transferMaps = {
864       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
865       AffineMap::get(resultShapedType.getRank(),
866                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
867 
868   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
869       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
870       getNParallelLoopsAttrs(resultShapedType.getRank()),
871       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
872         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
873       });
874 
875   return success();
876 }
877 
878 /// Fill `dest` using a FillOp with the constant padding value if possible.
879 /// Otherwise, generate a tensor::GenerateOp.
880 Value GeneralizePadOpPattern::createFillOrGenerateOp(
881     PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
882     const SmallVector<Value> &dynSizes) const {
883   auto padValue = padOp.getConstantPaddingValue();
884   if (padValue)
885     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
886 
887   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
888   auto generateOp = rewriter.create<tensor::GenerateOp>(
889       padOp.getLoc(), padOp.getResultType(), dynSizes);
890   // Copy region to new op.
891   BlockAndValueMapping bvm;
892   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
893   return generateOp;
894 }
895 
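// Illustrative result (a sketch, assuming no optimizeCopyFn kicks in): a
// tensor.pad with constant padding value %cst is lowered to
//   %init = linalg.init_tensor [...]          // padded result shape
//   %fill = linalg.fill ins(%cst) outs(%init)
//   %res  = tensor.insert_slice %src into %fill[low offsets] [source sizes] [1, ...]
// i.e. the source is copied into a pre-filled tensor of the padded shape.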
896 LogicalResult
897 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
898                                         PatternRewriter &rewriter) const {
899   // Given an OpFoldResult, return an index-typed value.
900   auto getIdxValue = [&](OpFoldResult ofr) {
901     if (auto val = ofr.dyn_cast<Value>())
902       return val;
903     return rewriter
904         .create<arith::ConstantIndexOp>(
905             padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
906         .getResult();
907   };
908 
909   auto resultType = padOp.getResultType();
910   // Compute size of InitTensorOp. Any combination of static/dynamic is
911   // supported.
912   SmallVector<Value> dynSizes;
913   SmallVector<int64_t> staticSizes;
914   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
915     if (resultType.isDynamicDim(dim)) {
916       auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
917                                                           padOp.source(), dim);
918       // Add low and high padding value.
919       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
920           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
921       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
922           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
923       dynSizes.push_back(plusHigh);
924     }
925     staticSizes.push_back(resultType.getDimSize(dim));
926   }
927 
928   // Init tensor and fill it with padding.
929   Value init = rewriter.create<InitTensorOp>(
930       padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
931   Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
932 
933   // Try to optimize the copy of the source.
934   if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
935     return success();
936 
937   // The copy could not be optimized. Generate an InsertSliceOp instead to copy
938   // the PadOp source into the filled tensor.
939   auto sourceType = padOp.getSourceType();
940   // Compute size of source of tensor::PadOp.
941   SmallVector<OpFoldResult> srcSizes;
942   for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
943     if (sourceType.isDynamicDim(dim)) {
944       srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
945           padOp.getLoc(), padOp.source(), dim));
946     } else {
947       srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
948     }
949   }
950   // Strides of InsertSliceOp are all 1.
951   SmallVector<OpFoldResult> strides(sourceType.getRank(),
952                                     rewriter.getIndexAttr(1));
953   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
954       padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
955 
956   return success();
957 }
958 
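// Illustrative effect (assumed unit strides): the pattern rewrites
//   %0 = tensor.pad %src ...
//   %1 = tensor.extract_slice %0[...]
// into
//   %s = tensor.extract_slice %src[...]
//   %1 = tensor.pad %s ...
// so only the data actually read by the slice is padded; the offsets, sizes,
// and padding amounts are recomputed by tensor::bubbleUpPadSlice.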
959 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
960     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
961   if (!sliceOp.hasUnitStride())
962     return failure();
963 
964   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
965   if (!padOp)
966     return failure();
967 
968   bool zeroSliceGuard = true;
969   if (controlFn) {
970     if (Optional<bool> control = controlFn(sliceOp))
971       zeroSliceGuard = *control;
972     else
973       return failure();
974   }
975 
976   Operation *tiledPadOp =
977       tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
978                                sliceOp.getMixedSizes(), zeroSliceGuard);
979   // All shapes are static and the data source is actually used. Rewrite into
980   // pad(extract_slice(x)).
981   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
982   return success();
983 }
984 
985 // The following are patterns for downscaling convolution ops with size-1
986 // window dimensions.
987 //
988 // Note that we'd eventually want to write such transformations in a generic
989 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
990 // and then turning back to named ops. But for now it's fine to have a few
991 // patterns matching special ops to get started.
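//
// Illustrative example (assumed shapes): a linalg.conv_2d_nhwc_hwcf with
// kernel tensor<1x3x4x8xf32> and output tensor<1x1x14x8xf32> has a size-1
// window along H, so it is rewritten into a linalg.conv_1d_nwc_wcf on
// rank-reduced operands, with the H entries of the strides and dilations
// dropped, and the 1-D result is inserted back into the original output
// tensor.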
992 
993 FailureOr<Conv1DNwcWcfOp>
994 DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
995     linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
996   if (failed(filter.checkAndNotify(rewriter, convOp)))
997     return failure();
998   if (convOp.hasBufferSemantics())
999     return failure(); // To be implemented.
1000 
1001   Value input = convOp.inputs().front();
1002   Value kernel = convOp.inputs().back();
1003   Value output = convOp.outputs().front();
1004 
1005   auto inputType = input.getType().dyn_cast<RankedTensorType>();
1006   auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1007   auto outputType = output.getType().dyn_cast<RankedTensorType>();
1008 
1009   auto kernelShape = kernelType.getShape();
1010   auto outputShape = outputType.getShape();
1011 
1012   // Only handle the case where at least one of the window dimensions is
1013   // of size 1. Other cases can rely on tiling to reduce to such cases.
1014   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1015   int64_t ohSize = outputShape[1], owSize = outputShape[2];
1016   bool removeH = (khSize == 1 && ohSize == 1);
1017   bool removeW = (kwSize == 1 && owSize == 1);
1018   if (!removeH && !removeW)
1019     return failure();
1020 
1021   // Get new shapes and types for all operands by removing the size-1
1022   // dimension.
1023   using RTTBuilder = RankedTensorType::Builder;
1024   RankedTensorType newInputType =
1025       RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1026   RankedTensorType newKernelType =
1027       RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1028   RankedTensorType newOutputType =
1029       RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1030 
1031   // Rank-reduce operands.
1032   Location loc = convOp.getLoc();
1033   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1034       rewriter, loc, input, newInputType);
1035   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1036       rewriter, loc, kernel, newKernelType);
1037   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1038       rewriter, loc, output, newOutputType);
1039 
1040   // Rank-reduce strides and dilations too.
1041   // TODO: dropDim 1-liner helper.
1042   auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1043   strides.erase(strides.begin() + (removeH ? 0 : 1));
1044   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1045 
1046   auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1047   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1048   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1049 
1050   auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1051       loc, newOutputType, ValueRange{newInput, newKernel},
1052       ValueRange{newOutput}, stridesAttr, dilationsAttr);
1053 
1054   // Insert back.
1055   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1056       rewriter, loc, conv1DOp.getResult(0), output);
1057   rewriter.replaceOp(convOp, inserted);
1058 
1059   filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1060   return conv1DOp;
1061 }
1062 
1063 FailureOr<DepthwiseConv1DNwcWcOp>
1064 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1065     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1066   if (failed(filter.checkAndNotify(rewriter, convOp)))
1067     return failure();
1068   if (convOp.hasBufferSemantics())
1069     return failure(); // To be implemented.
1070 
1071   Value input = convOp.inputs().front();
1072   Value kernel = convOp.inputs().back();
1073   Value output = convOp.outputs().front();
1074 
1075   auto inputType = input.getType().dyn_cast<RankedTensorType>();
1076   auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1077   auto outputType = output.getType().dyn_cast<RankedTensorType>();
1078 
1079   auto kernelShape = kernelType.getShape();
1080   auto outputShape = outputType.getShape();
1081 
1082   // Only handle the case where at least one of the window dimensions is
1083   // of size 1. Other cases can rely on tiling to reduce to such cases.
1084   int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1085   int64_t ohSize = outputShape[1], owSize = outputShape[2];
1086   bool removeH = (khSize == 1 && ohSize == 1);
1087   bool removeW = (kwSize == 1 && owSize == 1);
1088   if (!removeH && !removeW)
1089     return failure();
1090 
1091   // Get new shapes and types for all operands by removing the size-1
1092   // dimension.
1093   using RTTBuilder = RankedTensorType::Builder;
1094   RankedTensorType newInputType =
1095       RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1096   RankedTensorType newKernelType =
1097       RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1098   RankedTensorType newOutputType =
1099       RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1100 
1101   // Rank-reduce operands.
1102   Location loc = convOp.getLoc();
1103   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1104       rewriter, loc, input, newInputType);
1105   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1106       rewriter, loc, kernel, newKernelType);
1107   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1108       rewriter, loc, output, newOutputType);
1109 
1110   // Rank-reduce strides and dilations too.
1111   // TODO: dropDim 1-liner helper.
1112   auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1113   strides.erase(strides.begin() + (removeH ? 0 : 1));
1114   auto stridesAttr = rewriter.getI64VectorAttr(strides);
1115 
1116   auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1117   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1118   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1119 
1120   auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1121       loc, newOutputType, ValueRange{newInput, newKernel},
1122       ValueRange{newOutput}, stridesAttr, dilationsAttr);
1123 
1124   // Insert back.
1125   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1126       rewriter, loc, conv1DOp.getResult(0), output);
1127   rewriter.replaceOp(convOp, inserted);
1128 
1129   filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1130   return conv1DOp;
1131 }
1132 
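// Illustrative usage (a sketch; `ctx` and `funcOp` are assumed to exist):
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns,
//                                                LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));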
1133 void linalg::populateDecomposeConvolutionPatterns(
1134     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
1135     PatternBenefit benefit) {
1136   patterns.add<DownscaleSizeOneWindowed2DConvolution,
1137                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1138                                                   benefit);
1139 }
1140