//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
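
// A hypothetical example of an op carrying the marker (IR abbreviated):
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
//          ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// A LinalgTransformationFilter built with matchDisjunction = {"tile"} and
// replacement = "tiled" matches this op and, once the transformation has been
// applied, rewrites the marker to "tiled" so the pattern does not fire on its
// own output again.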

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker, and no filter was specified or matching by
    // default is enabled: match.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no marker but one was expected: fail.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. The op's marker matches one of the expected filters: match.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The op's marker matches none of the expected filters: fail.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.has_value())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == *replacement;
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
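
// A minimal usage sketch: tile the first two loops by 8 and 16 and leave any
// remaining loops untiled (a tile size of 0 means "do not tile").
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0});
//
// The constants are created at the start of the enclosing function so that
// they dominate all the loops produced by tiling.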

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize)
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
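///
/// As a hypothetical example, a `tensor<4x?xf32>` operand whose dynamic
/// dimension is listed in `paddingDimensions` and traces back to a
/// `tensor.extract_slice` with a constant size upper bound of 16 would be
/// padded to `tensor<4x16xf32>`.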
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain while `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    paddedShape[shapeIdx++] = *upperBound;
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
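/// For example (sizes are illustrative), peeling
///   scf.for %i = %c0 to %c17 step %c8
/// yields a main loop over [0, 16) with step 8 whose iterations are all full,
/// plus a separate partial iteration covering [16, 17).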
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize `loops`.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (scf::ForOp loopOp : loops)
    (void)peelLoop(rewriter, loopOp);
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    Operation *loopOp = res.loops[loop];
    SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, *newResult);

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation with the padded results and mark the padded
  // operation to avoid matching it again.
  rewriter.replaceOp(linalgOp, *newResults);
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the replacement filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPeelOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Update the transformation marker even if no loops are peeled for this op.
  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

  if (!options.loopsToPeelComputationFunction)
    return failure();

  SmallVector<scf::ForOp, 4> loopsToPeel;
  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
  peelLoops(rewriter, loopsToPeel);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

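/// Apply the patterns in three stages: each entry of `stage1Patterns` is
/// driven to a fixed point by the greedy rewrite driver, followed each time
/// by `stage2Patterns` (typically canonicalization patterns) and, when set,
/// by `stage3Lambda` (typically a non-pattern transformation such as running
/// a pass). Returns failure as soon as any stage fails to converge or
/// succeed.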
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
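///
/// A hypothetical example of the rewrite (types and syntax abbreviated):
///
///   %0 = tensor.pad %src low[1, 1] high[1, 1] { tensor.yield %pad_val }
///       : tensor<2x2xf32> to tensor<4x4xf32>
///
/// becomes an InitTensorOp of type tensor<4x4xf32>, a FillOp writing %pad_val
/// into it, and a GenericOp that copies %src into the padded tensor through
/// the shifted output indexing map (d0 + 1, d1 + 1).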
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.getRegion().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.getValue();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant. Otherwise,
/// generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

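/// Rewrite a tensor::PadOp into an InitTensorOp of the padded result shape, a
/// fill of the padding value (see createFillOrGenerateOp above), and a copy of
/// the source. The copy is delegated to `optimizeCopyFn` when provided;
/// otherwise it is materialized as a tensor::InsertSliceOp.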
LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // dimensions is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized: generate a tensor::InsertSliceOp to copy
  // the PadOp source into the filled tensor.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
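//
// As a hypothetical example, a linalg.conv_2d_nhwc_hwcf whose kernel has type
// tensor<1x3x4x8xf32> (kh = 1) and whose output has type tensor<1x1x?x8xf32>
// (oh = 1) rank-reduces to a linalg.conv_1d_nwc_wcf on tensor<3x4x8xf32> and
// tensor<1x?x8xf32> operands, dropping the unit h dimension everywhere.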

FailureOr<Conv1DNwcWcfOp>
DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().dyn_cast<RankedTensorType>();
  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
  auto outputType = output.getType().dyn_cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}