1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/SCF/Transforms.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/StaticValueUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/VectorOps.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <type_traits>
36 
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// LinalgTransformationFilter reads this attribute (as a StringAttr) to decide
// whether a pattern may apply to an op, and sets or removes it after a
// successful rewrite to steer subsequent pattern applications.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
50 
51 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
52     ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
53     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
54       replacement(replacement), matchByDefault(false) {}
55 
56 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
57     FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
58     Optional<StringAttr> replacement)
59     : filters(),
60       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
61       replacement(replacement), matchByDefault(false) {
62   if (f)
63     filters.push_back(f);
64 }
65 
66 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
67     PatternRewriter &rewriter, Operation *op) const {
68   if (llvm::any_of(filters,
69                    [&](const FilterFunction &f) { return failed(f(op)); }))
70     return failure();
71 
72   auto attr = op->template getAttrOfType<StringAttr>(
73       LinalgTransforms::kLinalgTransformMarker);
74 
75   if (!attr) {
76     // 1. Has no filter case and matchDisjunction is empty.
77     if (matchDisjunction.empty() || matchByDefault)
78       return success();
79 
80     // 2. Has no filter but was expecting a filter.
81     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
82       diag << " does not have any filter from list: ";
83       interleaveComma(matchDisjunction, diag);
84     });
85   }
86 
87   // 4. Match explicit filter.
88   for (auto filter : matchDisjunction)
89     if (attr.getValue() == filter)
90       return success();
91 
92   // 5. Fail to match.
93   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
94     diag << " does not have any filter from list: ";
95     interleaveComma(matchDisjunction, diag);
96   });
97 }
98 
99 void mlir::linalg::LinalgTransformationFilter::
100     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
101                                       Operation *op) const {
102   if (replacement.hasValue())
103     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
104                 replacement.getValue());
105   else
106     op->removeAttr(
107         rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
108 }
109 
110 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
111     Operation *op) const {
112   if (!replacement)
113     return false;
114   auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
115                   .dyn_cast<StringAttr>();
116   return attr && attr == replacement.getValue();
117 }
118 
119 LinalgTilingOptions &
120 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
121   assert(!tileSizeComputationFunction && "tile sizes already set");
122   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
123   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
124     OpBuilder::InsertionGuard guard(b);
125     b.setInsertionPointToStart(
126         &op->getParentOfType<FuncOp>().getBody().front());
127     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
128       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
129       return v;
130     }));
131   };
132   return *this;
133 }
134 
135 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
136   assert(!tileSizeComputationFunction && "tile sizes already set");
137   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
138     SmallVector<Value, 4> tileSizes;
139     auto linalgOp = dyn_cast<LinalgOp>(op);
140     if (!linalgOp)
141       return tileSizes;
142     Location loc = linalgOp.getLoc();
143     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
144     AffineMap map = linalgOp.getShapesToLoopsMap();
145     if (!map)
146       return tileSizes;
147     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
148     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
149     // size 0).
150     for (Value shapeSize : shapeSizes)
151       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
152                               ? b.create<arith::ConstantIndexOp>(loc, 0)
153                               : b.create<arith::ConstantIndexOp>(loc, 1));
154     return tileSizes;
155   };
156   return *this;
157 }
158 
159 /// Helper function that tries to pad `opOperand`. Exit early for scalar
160 /// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
161 /// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
162 /// has a static shape. Set `result` to the result of the created PadTensorOp or
163 /// and return success if the operand either has been padded to a static shape
164 /// or already had a static shape and failure otherwise.
165 static LogicalResult padOperandToSmallestStaticBoundingBox(
166     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
167     const PaddingValueComputationFunction &paddingFunc,
168     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
169   // Get the shape of the operand and check if it has a dynamic shape. Only
170   // return failure if the operand is not a scalar and has a dynamic shape.
171   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
172   bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
173 
174   // Cannot pad scalar operands.
175   if (shape.empty())
176     return success();
177 
178   // Cannot pad if the padding value is unknown.
179   FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
180   if (failed(paddingValue))
181     return failure(hasDynamicShape);
182 
183   // Cannot construct a static bounding box if the operand is not defined by an
184   // ExtractSliceOp.
185   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
186   if (!sliceOp)
187     return failure(hasDynamicShape);
188 
189   // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
190   llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
191 
192   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
193   SmallVector<int64_t> staticSizes;
194   staticSizes.reserve(shape.size());
195   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
196   for (auto en : enumerate(shapedOp.getMixedSizes())) {
197     // Skip dropped dimensions.
198     if (droppedDims.contains(en.index()))
199       continue;
200     // If the size is an attribute add it directly to `staticSizes`.
201     if (en.value().is<Attribute>()) {
202       staticSizes.push_back(
203           en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
204       continue;
205     }
206     // Otherwise, try to compute a constant upper bound for the size value.
207     FailureOr<int64_t> upperBound =
208         getConstantUpperBoundForIndex(en.value().get<Value>());
209     if (failed(upperBound)) {
210       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
211       return failure();
212     }
213     staticSizes.push_back(upperBound.getValue());
214   }
215   assert(staticSizes.size() == shape.size() &&
216          "expect the dynamic and static ranks to match");
217 
218   // Pad the operand to the bounding box defined by `staticSizes`.
219   auto staticTensorType = RankedTensorType::get(
220       staticSizes, getElementTypeOrSelf(opOperand->get()));
221   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
222   result =
223       makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
224                             opOperand->get(), paddingValue.getValue(), nofold);
225   return success();
226 }
227 
/// Creates `paddedOp`, a clone of `opToPad` whose input and output operands are
/// padded to static bounding boxes via padOperandToSmallestStaticBoundingBox.
/// Operands that need no padding are passed through unchanged. The clone is
/// inserted right after `opToPad`.
///
/// Returns one tensor.extract_slice per result of `paddedOp`. Each slice has
/// zero offsets, unit strides, and the reified result sizes of `opToPad`, so
/// together they can replace the results of `opToPad` (LinalgPaddingPattern
/// uses them that way). Returns failure if some operand cannot be padded or
/// the result shapes of `opToPad` cannot be reified. `opToPad` itself is left
/// in place.
FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    // A null `paddedOperand` means no padding was created for this operand.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the original result shapes; they become the sizes of the slices
  // extracted from the padded results below.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  // The result types are taken from the (possibly padded) output operands.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
285 
/// Linalg base tiling pattern.
/// Constructs the pattern with root operation name `opName`; `filter` gates
/// which ops are tiled and `options` configures the tiling.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

/// Constructs the pattern without a root operation name (MatchAnyOpTypeTag);
/// the LinalgOp check is performed in matchAndRewriteBase.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
298 
299 /// Try to peel a loop `op` and return the new result.
300 // TODO: Add support for scf.parallel and affine.for loops.
301 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
302   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
303       .Case<scf::ForOp>([&](scf::ForOp forOp) {
304         scf::ForOp partialIteration;
305         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
306                                                       partialIteration)))
307           return partialIteration->getResults();
308         assert(!partialIteration && "expected that loop was not peeled");
309         return forOp->getResults();
310       })
311       .Default([&](Operation *op) { return op->getResults(); });
312 }
313 
314 /// Try to peel a TiledLoopOp and return the new result.
315 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
316                                       TiledLoopOp tiledLoop, int64_t idx) {
317   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
318          "requested peeling of non-existing loop");
319   TiledLoopOp result;
320   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
321     return result->getResults();
322   assert(!result && "expected that loop was not peeled");
323   return tiledLoop->getResults();
324 }
325 
/// Peel loops after tiling.
/// Peels every loop index listed in `options.peeledLoops`. With
/// LinalgTilingLoopType::TiledLoops, all entries of `res.loops` must be the
/// same TiledLoopOp, and the index selects the dimension of that op to peel.
/// Otherwise, the loop op at that index is peeled directly (only scf.for is
/// currently handled by peelLoop). `res.tensorResults` is updated when it
/// referred to the results of the peeled loop op.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    // Only redirect `res.tensorResults` when it is exactly the result list of
    // the loop op that was just peeled.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
353 
354 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
355     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
356   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
357   if (!linalgOp)
358     return failure();
359   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
360     return failure();
361 
362   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
363 
364   if (!res)
365     return failure();
366   // Clear filter to stop recursive pattern application.
367   filter.replaceLinalgTransformationFilter(rewriter, res->op);
368 
369   // Peel loops.
370   peelLoops(rewriter, *res, options);
371 
372   result = *res;
373   return success();
374 }
375 
376 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
377   if (tiledOp.loops.empty())
378     return tiledOp.op.getOperation()->getResults();
379   return tiledOp.loops.front()->getResults();
380 }
381 
382 static ValueRange
383 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
384   if (tiledAndFusedOp.fusedLoops.empty())
385     return tiledAndFusedOp.op.getOperation()->getResults();
386   return tiledAndFusedOp.fusedLoops.front()->getResults();
387 }
388 
/// Constructs a tile-and-fuse pattern with root operation name `opName`.
/// `filter` gates which ops are matched and marks the tiled root op, while
/// `fusedOpMarker` and `originalOpMarker` are applied to the fused producers
/// and to the original ops, respectively, after a successful rewrite.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
399 
/// Tiles the LinalgOp `op` and fuses into the tiled loops the LinalgOp
/// producers it depends on (per `dependenceGraph`) through operands listed in
/// `fusionOptions.indicesToFuse`. Loops not covered by fusion are then tiled
/// separately if any of their tile sizes may be non-zero. Uses of `op` are
/// redirected to the resulting values. Markers are updated as follows:
/// `filter` on the tiled root op, `fusedOpMarker` on the fused producers, and
/// `originalOpMarker` on the original producers and on `op`.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the root op plus the LinalgOp producers feeding the operands that
  // the fusion options ask to fuse.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Order the fusion candidates by block position. Only ops preceding `op` in
  // its block are considered, and the root op is appended last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize the tile sizes once and reuse the same values for both the
  // fused tiling and the unfused tiling below.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops;
  // loop dims already tiled by fusion get tile size zero here, all others keep
  // their original tile size.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  // A tile size that is not a constant index is conservatively treated as
  // non-zero.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  // `op` is kept (only its uses are redirected); it receives
  // `originalOpMarker` at the end.
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
485 
/// Linalg padding pattern.
/// Constructs the pattern without a root operation name (MatchAnyOpTypeTag);
/// the LinalgOp check is performed in matchAndRewrite.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Constructs the pattern with root operation name `opName`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
498 
/// Pads a LinalgOp with tensor semantics via rewriteAsPaddedOp, then
/// optionally hoists the created pad ops. The hoisting depth of each operand
/// comes from `options.paddingHoistComputationFunction`. Finally, `op` is
/// replaced by the slices of the padded results and the filter marker on the
/// padded op is updated.
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  // `depths` stays empty (no hoisting) when no hoist function is configured.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  // Operands not produced by a PadTensorOp, a depth of 0, or a failed hoisting
  // attempt all leave the corresponding operand untouched.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}
543 
/// Linalg tile and fuse tensor ops pattern.
/// Constructs the pattern without a root operation name (MatchAnyOpTypeTag);
/// the LinalgOp check is performed in matchAndRewrite.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Constructs the pattern with root operation name `opName`.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}
560 
561 LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
562     Operation *op, PatternRewriter &rewriter) const {
563   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
564   if (!rootOp)
565     return failure();
566   if (failed(filter.checkAndNotify(rewriter, op)))
567     return failure();
568 
569   // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
570   if (options.tileSizes.size() < rootOp.getNumLoops())
571     return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
572 
573   // Check `tileInterchange` contains no entries or as many as `tileSizes`.
574   if (!options.tileInterchange.empty() &&
575       options.tileInterchange.size() != options.tileSizes.size())
576     return rewriter.notifyMatchFailure(
577         op, "expect the number of tile sizes and interchange dims to match");
578 
579   // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
580   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
581                                      options.tileSizes.begin() +
582                                          rootOp.getNumLoops());
583   SmallVector<int64_t> rootInterchange =
584       options.tileInterchange.empty()
585           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
586           : SmallVector<int64_t>(options.tileInterchange.begin(),
587                                  options.tileInterchange.begin() +
588                                      rootOp.getNumLoops());
589 
590   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
591   // It has to be a permutation since the tiling cannot tile the same loop
592   // dimension multiple times.
593   if (!isPermutation(rootInterchange))
594     return rewriter.notifyMatchFailure(
595         op, "expect the tile interchange permutes the root loops");
596 
597   // Tile `rootOp` and fuse its producers.
598   FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
599       rewriter, rootOp, rootTileSizes, rootInterchange);
600   if (failed(tileLoopNest))
601     return rewriter.notifyMatchFailure(
602         op, "tileConsumerAndFuseProducers failed unexpectedly");
603 
604   // Replace all uses of the tiled loop operation.
605   rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
606 
607   // Apply the filter if specified.
608   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
609     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
610   return failure();
611 }
612 
/// Linalg generic interchange pattern.
/// `interchangeVector` is copied into the pattern and later checked and
/// applied by matchAndRewrite.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
619 
620 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
621     GenericOp genericOp, PatternRewriter &rewriter) const {
622   if (failed(filter.checkAndNotify(rewriter, genericOp)))
623     return failure();
624   if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
625     return failure();
626 
627   // TODO: figure out how this interplays with named ops. In particular this
628   // should break the named op property.
629   rewriter.updateRootInPlace(genericOp, [&]() {
630     interchangeGenericOp(rewriter, genericOp, interchangeVector);
631     // New filter if specified.
632     filter.replaceLinalgTransformationFilter(rewriter, genericOp);
633   });
634   return success();
635 }
636 
/// Linalg generalization pattern.
/// Constructs the pattern without a root operation name (MatchAnyOpTypeTag).
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

/// Constructs the pattern with root operation name `opName`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
647 
648 LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
649     Operation *op, PatternRewriter &rewriter) const {
650   if (failed(filter.checkAndNotify(rewriter, op)))
651     return failure();
652   if (failed(generalizeNamedOpPrecondition(op)))
653     return failure();
654 
655   GenericOp genericOp = generalizeNamedOp(rewriter, op);
656   rewriter.replaceOp(op, genericOp.getResults());
657   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
658   return success();
659 }
660 
/// Constructs the promotion pattern without a root operation name
/// (MatchAnyOpTypeTag). Note the (filter, options) argument order, which
/// differs from the named-op overload below.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Constructs the promotion pattern with root operation name `opName`; this
/// overload takes (options, filter).
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
672 
/// Promotes the subviews used by `op` in place via promoteSubViews, wrapped in
/// a root update of `op`. If promotion fails, the root update is cancelled and
/// an error is emitted on `op` (which converts to a failure result). On
/// success, the transformation marker of `op` is updated.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
694 
/// Constructs the vectorization pattern without a root operation name
/// (MatchAnyOpTypeTag); the LinalgOp check is performed in matchAndRewrite.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

/// Constructs the vectorization pattern with root operation name `opName`.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
704 
705 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
706     Operation *op, PatternRewriter &rewriter) const {
707   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
708   if (!linalgOp)
709     return failure();
710   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
711     return failure();
712   SmallVector<Value> newResults;
713   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
714     return failure();
715   if (!newResults.empty())
716     rewriter.replaceOp(op, newResults);
717   else
718     rewriter.eraseOp(op);
719   return success();
720 }
721 
722 LogicalResult mlir::linalg::applyStagedPatterns(
723     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
724     const FrozenRewritePatternSet &stage2Patterns,
725     function_ref<LogicalResult(Operation *)> stage3Lambda) {
726   unsigned iteration = 0;
727   (void)iteration;
728   for (const auto &patterns : stage1Patterns) {
729     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
730                       << *op);
731     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
732       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
733       return failure();
734     }
735     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
736                       << *op);
737     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
738       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
739       return failure();
740     }
741     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
742                       << *op);
743     if (stage3Lambda) {
744       if (failed(stage3Lambda(op)))
745         return failure();
746       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
747                         << *op);
748     }
749   }
750   return success();
751 }
752 
753 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
754   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
755 }
756 
757 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
758 /// with pad_val) and GenericOp (to copy contents).
759 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
760     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
761 
762   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
763   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
764 
765   // Bail on non-static shapes.
766   if (!inputShapedType.hasStaticShape())
767     return failure();
768   if (!resultShapedType.hasStaticShape())
769     return failure();
770 
771   // Only support padding with a constant for now, i.e. either:
772   //   1. A BBarg from a different block.
773   //   2. A value defined outside of the current block.
774   Block &block = padOp.region().front();
775   auto yieldOp = cast<YieldOp>(block.getTerminator());
776   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
777   Value padValue = yieldOp.values().front();
778   Operation *definingOp = padValue.getDefiningOp();
779   if (definingOp && definingOp->getBlock() == &block)
780     return failure();
781   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
782     return failure();
783 
784   // Create tensor with the padded shape
785   Location loc = padOp.getLoc();
786   SmallVector<Value> indices(resultShapedType.getRank(),
787                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
788   Value initTensor = rewriter.create<InitTensorOp>(
789       loc, resultShapedType.getShape(), resultShapedType.getElementType());
790 
791   // Initialize tensor with the pad value
792   Value tmpTensor =
793       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
794 
795   // Copy original contents into new tensor
796   // Uses linalg.generic, but could be done with tensor.insert_slice
797   SmallVector<AffineExpr, 4> outputExprs;
798   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
799     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
800                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
801   }
802 
803   SmallVector<AffineMap, 2> transferMaps = {
804       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
805       AffineMap::get(resultShapedType.getRank(),
806                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
807 
808   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
809       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
810       getNParallelLoopsAttrs(resultShapedType.getRank()),
811       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
812         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
813       });
814 
815   return success();
816 }
817 
818 /// Filling `dest` using FillOp constant padding value if possible.
819 /// Otherwise, generate a tensor::GenerateOp.
820 Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
821     PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
822     const SmallVector<Value> &dynSizes) const {
823   auto padValue = padOp.getConstantPaddingValue();
824   if (padValue)
825     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
826 
827   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
828   auto generateOp = rewriter.create<tensor::GenerateOp>(
829       padOp.getLoc(), padOp.getResultType(), dynSizes);
830   // Copy region to new op.
831   BlockAndValueMapping bvm;
832   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
833   // Rewrite linalg::YieldOp to tensor::YieldOp.
834   OpBuilder::InsertionGuard guard(rewriter);
835   auto yieldOp =
836       dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
837   assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
838   assert(yieldOp.values().size() == 1);
839   rewriter.setInsertionPoint(yieldOp);
840   rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
841   return generateOp;
842 }
843 
844 LogicalResult
845 GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
846                                               PatternRewriter &rewriter) const {
847   // Given an OpFoldResult, return an index-typed value.
848   auto getIdxValue = [&](OpFoldResult ofr) {
849     if (auto val = ofr.dyn_cast<Value>())
850       return val;
851     return rewriter
852         .create<arith::ConstantIndexOp>(
853             padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
854         .getResult();
855   };
856 
857   auto resultType = padOp.getResultType();
858   // Compute size of InitTensorOp. Any combination of static/dynamic is
859   // supported.
860   SmallVector<Value> dynSizes;
861   SmallVector<int64_t> staticSizes;
862   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
863     if (resultType.isDynamicDim(dim)) {
864       auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
865                                                           padOp.source(), dim);
866       // Add low and high padding value.
867       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
868           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
869       auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
870           padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
871       dynSizes.push_back(plusHigh);
872     }
873     staticSizes.push_back(resultType.getDimSize(dim));
874   }
875 
876   // Init tensor and fill it with padding.
877   Value init = rewriter.create<InitTensorOp>(
878       padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
879   Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
880 
881   // Try optimize the copy of source.
882   if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
883     return success();
884 
885   // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
886   // for copying the PadOp source.
887   auto sourceType = padOp.getSourceType();
888   // Compute size of source of PadTensorOp.
889   SmallVector<OpFoldResult> srcSizes;
890   for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
891     if (sourceType.isDynamicDim(dim)) {
892       srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
893           padOp.getLoc(), padOp.source(), dim));
894     } else {
895       srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
896     }
897   }
898   // Strides of InsertSliceOp are all 1.
899   SmallVector<OpFoldResult> strides(sourceType.getRank(),
900                                     rewriter.getIndexAttr(1));
901   rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
902       padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
903 
904   return success();
905 }
906 
907 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
908     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
909   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
910   if (!padOp)
911     return failure();
912   // Only unit stride supported.
913   if (!sliceOp.hasUnitStride())
914     return failure();
915 
916   Operation *tiledPadOp =
917       padOp
918           .getTiledImplementation(
919               rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
920               sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
921           .front();
922   // All shapes are static and the data source is actually used. Rewrite into
923   // pad_tensor(subtensor(x)).
924   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
925   return success();
926 }
927 
928 namespace {
929 // The following are patterns for downscaling convolution ops with size-1
930 // window dimensions.
931 //
932 // Note that we'd eventually want to write such transformations in a generic
933 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
934 // and then turning back to named ops. But for now it's fine to have a few
935 // patterns matching special ops to get started.
936 
937 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
938 /// convolution ops.
939 struct DownscaleSizeOneWindowed2DConvolution final
940     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
941   DownscaleSizeOneWindowed2DConvolution(
942       MLIRContext *context,
943       LinalgTransformationFilter filter = LinalgTransformationFilter(),
944       PatternBenefit benefit = 1)
945       : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}
946 
947   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
948                                 PatternRewriter &rewriter) const override {
949     if (failed(filter.checkAndNotify(rewriter, convOp)))
950       return failure();
951     if (convOp.hasBufferSemantics())
952       return failure(); // To be implemented
953 
954     Value input = convOp.inputs().front();
955     Value kernel = convOp.inputs().back();
956     Value output = convOp.outputs().front();
957 
958     auto inputType = input.getType().dyn_cast<RankedTensorType>();
959     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
960     auto outputType = output.getType().dyn_cast<RankedTensorType>();
961 
962     auto kernelShape = kernelType.getShape();
963     auto outputShape = outputType.getShape();
964 
965     // Only handle the case where at least one of the window dimensions is
966     // of size 1. Other cases can rely on tiling to reduce to such cases.
967     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
968     int64_t ohSize = outputShape[1], owSize = outputShape[2];
969     bool removeH = (khSize == 1 && ohSize == 1);
970     bool removeW = (kwSize == 1 && owSize == 1);
971     if (!removeH && !removeW)
972       return failure();
973 
974     // Get new shapes and types for all operands by removing the size-1
975     // dimension.
976     using RTTBuilder = RankedTensorType::Builder;
977     RankedTensorType newInputType =
978         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
979     RankedTensorType newKernelType =
980         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
981     RankedTensorType newOutputType =
982         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
983 
984     // Rank-reduce operands.
985     Location loc = convOp.getLoc();
986     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
987         rewriter, loc, input, newInputType);
988     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
989         rewriter, loc, kernel, newKernelType);
990     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
991         rewriter, loc, output, newOutputType);
992 
993     // Rank-reduce strides and dilations too.
994     // TODO: dropDim 1-liner helper.
995     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
996     strides.erase(strides.begin() + (removeH ? 0 : 1));
997     auto stridesAttr = rewriter.getI64VectorAttr(strides);
998 
999     auto dilations =
1000         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1001     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1002     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1003 
1004     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1005         loc, newOutputType, ValueRange{newInput, newKernel},
1006         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1007 
1008     // Insert back.
1009     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1010         rewriter, loc, conv1DOp.getResult(0), output);
1011     rewriter.replaceOp(convOp, inserted);
1012 
1013     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1014     return success();
1015   };
1016 
1017 private:
1018   /// LinalgTransformMarker handles special attribute manipulations.
1019   LinalgTransformationFilter filter;
1020 };
1021 
1022 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1023 /// dimensions into 1-D depthwise convolution ops.
1024 struct DownscaleDepthwiseConv2DNhwcHwcOp final
1025     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
1026   DownscaleDepthwiseConv2DNhwcHwcOp(
1027       MLIRContext *context,
1028       LinalgTransformationFilter filter = LinalgTransformationFilter(),
1029       PatternBenefit benefit = 1)
1030       : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
1031         filter(filter) {}
1032 
1033   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
1034                                 PatternRewriter &rewriter) const override {
1035     if (failed(filter.checkAndNotify(rewriter, convOp)))
1036       return failure();
1037     if (convOp.hasBufferSemantics())
1038       return failure(); // To be implemented
1039 
1040     Value input = convOp.inputs().front();
1041     Value kernel = convOp.inputs().back();
1042     Value output = convOp.outputs().front();
1043 
1044     auto inputType = input.getType().dyn_cast<RankedTensorType>();
1045     auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1046     auto outputType = output.getType().dyn_cast<RankedTensorType>();
1047 
1048     auto kernelShape = kernelType.getShape();
1049     auto outputShape = outputType.getShape();
1050 
1051     // Only handle the case where at least one of the window dimensions is
1052     // of size 1. Other cases can rely on tiling to reduce to such cases.
1053     int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1054     int64_t ohSize = outputShape[1], owSize = outputShape[2];
1055     bool removeH = (khSize == 1 && ohSize == 1);
1056     bool removeW = (kwSize == 1 && owSize == 1);
1057     if (!removeH && !removeW)
1058       return failure();
1059 
1060     // Get new shapes and types for all operands by removing the size-1
1061     // dimension.
1062     using RTTBuilder = RankedTensorType::Builder;
1063     RankedTensorType newInputType =
1064         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1065     RankedTensorType newKernelType =
1066         RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1067     RankedTensorType newOutputType =
1068         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1069 
1070     // Rank-reduce operands.
1071     Location loc = convOp.getLoc();
1072     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1073         rewriter, loc, input, newInputType);
1074     Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1075         rewriter, loc, kernel, newKernelType);
1076     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1077         rewriter, loc, output, newOutputType);
1078 
1079     // Rank-reduce strides and dilations too.
1080     // TODO: dropDim 1-liner helper.
1081     auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1082     strides.erase(strides.begin() + (removeH ? 0 : 1));
1083     auto stridesAttr = rewriter.getI64VectorAttr(strides);
1084 
1085     auto dilations =
1086         llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1087     dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1088     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1089 
1090     auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1091         loc, newOutputType, ValueRange{newInput, newKernel},
1092         ValueRange{newOutput}, stridesAttr, dilationsAttr);
1093 
1094     // Insert back.
1095     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1096         rewriter, loc, conv1DOp.getResult(0), output);
1097     rewriter.replaceOp(convOp, inserted);
1098 
1099     filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1100     return success();
1101   };
1102 
1103 private:
1104   /// LinalgTransformMarker handles special attribute manipulations.
1105   LinalgTransformationFilter filter;
1106 };
1107 
1108 } // namespace
1109 
1110 void linalg::populateDecomposeConvolutionPatterns(
1111     RewritePatternSet &patterns, LinalgTransformationFilter filter,
1112     PatternBenefit benefit) {
1113   patterns.add<DownscaleSizeOneWindowed2DConvolution,
1114                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1115                                                   benefit);
1116 }
1117