1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
19 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/SCF/Transforms.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/StaticValueUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/VectorOps.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <type_traits>
36 
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// Patterns read this attribute (via LinalgTransformationFilter::checkAndNotify)
// to decide whether to fire, and rewrite/remove it afterwards to avoid
// re-application.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
50 
/// Construct a filter that matches ops whose transform marker attribute is one
/// of `matchDisjunction`, and rewrites the marker to `replacement` (or removes
/// it when `replacement` is None) after a successful transformation.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}
55 
/// Same as above, additionally registering the callback `f` as an extra
/// predicate that must accept the op for the filter to match. A null callback
/// is ignored.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}
65 
/// Return success if `op` passes every registered filter function and carries
/// a transform marker compatible with `matchDisjunction`. On mismatch, notify
/// `rewriter` so the failure can be surfaced in pattern-application
/// diagnostics.
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  // Reject the op if any user-provided filter function fails on it.
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
98 
99 void mlir::linalg::LinalgTransformationFilter::
100     replaceLinalgTransformationFilter(PatternRewriter &rewriter,
101                                       Operation *op) const {
102   if (replacement.hasValue())
103     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
104                 rewriter.getStringAttr(replacement.getValue().strref()));
105   else
106     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
107                                    rewriter.getContext()));
108 }
109 
110 LinalgTilingOptions &
111 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
112   assert(!tileSizeComputationFunction && "tile sizes already set");
113   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
114   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
115     OpBuilder::InsertionGuard guard(b);
116     b.setInsertionPointToStart(
117         &op->getParentOfType<FuncOp>().getBody().front());
118     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
119       Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
120       return v;
121     }));
122   };
123   return *this;
124 }
125 
126 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
127   assert(!tileSizeComputationFunction && "tile sizes already set");
128   tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
129     SmallVector<Value, 4> tileSizes;
130     auto linalgOp = dyn_cast<LinalgOp>(op);
131     if (!linalgOp)
132       return tileSizes;
133     Location loc = linalgOp.getLoc();
134     auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
135     AffineMap map = linalgOp.getShapesToLoopsMap();
136     if (!map)
137       return tileSizes;
138     auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
139     // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
140     // size 0).
141     for (Value shapeSize : shapeSizes)
142       tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
143                               ? b.create<arith::ConstantIndexOp>(loc, 0)
144                               : b.create<arith::ConstantIndexOp>(loc, 1));
145     return tileSizes;
146   };
147   return *this;
148 }
149 
/// Helper function that tries to pad `opOperand`. Exit early and return success
/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
/// pad the operand even if it already has a static shape. Set `result` to the
/// result of the created PadTensorOp or return failure if the operand cannot be
/// padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  // Derive a static bounding box from the slice sizes: constant sizes are
  // taken directly, value sizes are bounded via getSmallestBoundingIndex.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(indexAttr.getInt());
  }
  // Pad the operand up to the static bounding box with the computed value.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}
193 
/// Pad every input/output operand of `opToPad` to a static bounding box, clone
/// the op on the padded operands, and return slices of the padded results that
/// match the original result shapes. `paddedOp` is set to the cloned op.
/// Returns failure if any operand cannot be bounded statically or if result
/// shapes cannot be reified.
FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    // `paddedOperand` stays null for operands the helper skipped (scalars or
    // no padding value); keep the original operand in that case.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the (dynamic) shapes of the original results so we can slice the
  // padded results back down to the original sizes.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    // Slice at offset 0 with stride 1; the sizes come from the reified
    // original result shapes.
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
251 
/// Linalg base tiling pattern.
/// Construct a tiling pattern restricted to ops with the given `opName`.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}
258 
/// Construct a tiling pattern that matches any op; actual applicability is
/// decided by the filter in matchAndRewriteBase.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
264 
265 /// Try to peel a loop `op` and return the new result.
266 // TODO: Add support for scf.parallel and affine.for loops.
267 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
268   return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
269       .Case<scf::ForOp>([&](scf::ForOp forOp) {
270         scf::ForOp partialIteration;
271         if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
272                                                       partialIteration)))
273           return partialIteration->getResults();
274         assert(!partialIteration && "expected that loop was not peeled");
275         return forOp->getResults();
276       })
277       .Default([&](Operation *op) { return op->getResults(); });
278 }
279 
280 /// Try to peel a TiledLoopOp and return the new result.
281 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
282                                       TiledLoopOp tiledLoop, int64_t idx) {
283   assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
284          "requested peeling of non-existing loop");
285   TiledLoopOp result;
286   if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
287     return result->getResults();
288   assert(!result && "expected that loop was not peeled");
289   return tiledLoop->getResults();
290 }
291 
/// Peel loops after tiling.
/// For each loop index listed in `options.peeledLoops`, peel the corresponding
/// loop of the tiled loop nest `res` and update `res.tensorResults` when the
/// peeled loop was the source of those results.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      // With TiledLoops every entry of `res.loops` refers to the same
      // linalg.tiled_loop op (asserted below); peel dimension `loop` of it.
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
319 
/// Tile `op` if the filter matches, peel the requested loops, and optionally
/// pad the tiled op to static shapes. On success, `result` holds the tiled op
/// and its surrounding loop nest. The original `op` is NOT replaced here;
/// derived patterns decide how to consume `result`.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, res->op, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (succeeded(newResults)) {
    rewriter.replaceOp(res->op, newResults.getValue());
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  return failure();
}
363 
364 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
365   if (tiledOp.loops.empty())
366     return tiledOp.op.getOperation()->getResults();
367   return tiledOp.loops.front()->getResults();
368 }
369 
370 static ValueRange
371 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
372   if (tiledAndFusedOp.fusedLoops.empty())
373     return tiledAndFusedOp.op.getOperation()->getResults();
374   return tiledAndFusedOp.fusedLoops.front()->getResults();
375 }
376 
/// Construct a tile-and-fuse pattern for ops named `opName`. `filter` gates
/// the anchor op; `fusedOpMarker` and `originalOpMarker` are applied to fused
/// producers and original producers respectively after the rewrite.
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
387 
/// Tile `op`, fuse the producers selected by `fusionOptions` into the tiled
/// loop nest, then tile any loops that were not part of the fusion.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect `linalgOp` plus the Linalg producers feeding the operands
  // selected by `fusionOptions.indicesToFuse`.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the selected producers in block order; the anchor op goes last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize the tile sizes once and reuse them for this instance.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Loops that participated in fusion get a zero tile
  // size so they are not tiled again.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Update the transformation markers so the pattern set does not fire again
  // on ops that were already processed.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
473 
/// Linalg padding pattern.
/// Construct a padding pattern that matches any op; applicability is decided
/// in matchAndRewrite via the filter and tensor-semantics checks.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
480 
/// Construct a padding pattern restricted to ops with the given `opName`.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
486 
/// Pad `op` to static shapes and, when hoisting depths are configured, hoist
/// the created pad ops out of the enclosing loops.
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  // Padding is only meaningful for ops operating on tensors.
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding. A failed or skipped hoist is not fatal: the unhoisted
  // pad op simply stays in place.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    // Skip operands that were not padded or whose hoisting depth is zero.
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}
531 
/// Linalg generic interchange pattern.
/// Construct an interchange pattern applying `interchangeVector` to the
/// iterators of matched linalg.generic ops.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
538 
/// Interchange the iterators of `genericOp` in place according to
/// `interchangeVector` if the filter and preconditions allow it.
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  // The op is modified in place, so the mutation is wrapped in
  // updateRootInPlace to keep the rewriter informed.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
555 
/// Linalg generalization pattern.
/// Construct a generalization pattern that matches any op; applicability is
/// decided in matchAndRewrite.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
561 
/// Construct a generalization pattern restricted to ops named `opName`.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
566 
567 LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
568     Operation *op, PatternRewriter &rewriter) const {
569   if (failed(filter.checkAndNotify(rewriter, op)))
570     return failure();
571   if (failed(generalizeNamedOpPrecondition(op)))
572     return failure();
573 
574   GenericOp genericOp = generalizeNamedOp(rewriter, op);
575   rewriter.replaceOp(op, genericOp.getResults());
576   filter.replaceLinalgTransformationFilter(rewriter, genericOp);
577   return success();
578 }
579 
/// Construct a promotion pattern that matches any op; applicability is decided
/// in matchAndRewrite.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}
585 
/// Construct a promotion pattern restricted to ops named `opName`.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}
591 
/// Promote the subviews of `op` according to `options` once the filter and
/// promotion preconditions pass. The op is updated in place.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesnt seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    // Abort the in-place update; note the TODO above about ops possibly
    // already created by the failed promotion.
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
613 
/// Construct a vectorization pattern that matches any op; applicability is
/// decided in matchAndRewrite.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}
618 
/// Construct a vectorization pattern restricted to ops named `opName`.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}
623 
624 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
625     Operation *op, PatternRewriter &rewriter) const {
626   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
627   if (!linalgOp)
628     return failure();
629   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
630     return failure();
631   SmallVector<Value> newResults;
632   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
633     return failure();
634   if (!newResults.empty())
635     rewriter.replaceOp(op, newResults);
636   else
637     rewriter.eraseOp(op);
638   return success();
639 }
640 
641 LogicalResult mlir::linalg::applyStagedPatterns(
642     Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
643     const FrozenRewritePatternSet &stage2Patterns,
644     function_ref<LogicalResult(Operation *)> stage3Lambda) {
645   unsigned iteration = 0;
646   (void)iteration;
647   for (const auto &patterns : stage1Patterns) {
648     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
649                       << *op);
650     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
651       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
652       return failure();
653     }
654     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
655                       << *op);
656     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
657       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
658       return failure();
659     }
660     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
661                       << *op);
662     if (stage3Lambda) {
663       if (failed(stage3Lambda(op)))
664         return failure();
665       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
666                         << *op);
667     }
668   }
669   return success();
670 }
671 
672 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
673   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
674 }
675 
676 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
677 /// with pad_val) and GenericOp (to copy contents).
678 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
679     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
680 
681   auto inputShapedType = padOp.source().getType().cast<ShapedType>();
682   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
683 
684   // Bail on non-static shapes.
685   if (!inputShapedType.hasStaticShape())
686     return failure();
687   if (!resultShapedType.hasStaticShape())
688     return failure();
689 
690   // Only support padding with a constant for now, i.e. either:
691   //   1. A BBarg from a different block.
692   //   2. A value defined outside of the current block.
693   Block &block = padOp.region().front();
694   auto yieldOp = cast<YieldOp>(block.getTerminator());
695   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
696   Value padValue = yieldOp.values().front();
697   Operation *definingOp = padValue.getDefiningOp();
698   if (definingOp && definingOp->getBlock() == &block)
699     return failure();
700   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
701     return failure();
702 
703   // Create tensor with the padded shape
704   Location loc = padOp.getLoc();
705   SmallVector<Value> indices(resultShapedType.getRank(),
706                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
707   Value initTensor = rewriter.create<InitTensorOp>(
708       loc, resultShapedType.getShape(), resultShapedType.getElementType());
709 
710   // Initialize tensor with the pad value
711   Value tmpTensor =
712       rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
713 
714   // Copy original contents into new tensor
715   // Uses linalg.generic, but could be done with tensor.insert_slice
716   SmallVector<AffineExpr, 4> outputExprs;
717   for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
718     outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
719                           padOp.static_low()[i].cast<IntegerAttr>().getInt());
720   }
721 
722   SmallVector<AffineMap, 2> transferMaps = {
723       rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
724       AffineMap::get(resultShapedType.getRank(),
725                      /*symbolCount=*/0, outputExprs, rewriter.getContext())};
726 
727   rewriter.replaceOpWithNewOp<linalg::GenericOp>(
728       padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
729       getNParallelLoopsAttrs(resultShapedType.getRank()),
730       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
731         nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
732       });
733 
734   return success();
735 }
736 
737 /// Filling `dest` using FillOp constant padding value if possible.
738 /// Otherwise, generate a tensor::GenerateOp.
739 Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
740     PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
741     const SmallVector<Value> &dynSizes) const {
742   auto padValue = padOp.getConstantPaddingValue();
743   if (padValue)
744     return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
745 
746   // Fill could not be optimized: Lower to tensor::GenerateOp with region.
747   auto generateOp = rewriter.create<tensor::GenerateOp>(
748       padOp.getLoc(), padOp.getResultType(), dynSizes);
749   // Copy region to new op.
750   BlockAndValueMapping bvm;
751   padOp.region().cloneInto(&generateOp.getRegion(), bvm);
752   // Rewrite linalg::YieldOp to tensor::YieldOp.
753   OpBuilder::InsertionGuard guard(rewriter);
754   auto yieldOp =
755       dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
756   assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
757   assert(yieldOp.values().size() == 1);
758   rewriter.setInsertionPoint(yieldOp);
759   rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
760   return generateOp;
761 }
762 
/// Lowers a PadTensorOp to InitTensorOp + fill (or tensor.generate) followed
/// by either a user-supplied optimized copy or a tensor.insert_slice that
/// places the source inside the filled result at the low-pad offsets.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    // Static case: materialize the attribute as an index constant.
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      // Dynamic result dim: size = source size + low pad + high pad.
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    // Record the (possibly dynamic-marker) static size for every dim.
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      // Dynamic dim: query the runtime size of the source tensor.
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      // Static dim: pass the size as an attribute.
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // Insert the source into the filled tensor, offset by the low padding.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
825 
826 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
827     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
828   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
829   if (!padOp)
830     return failure();
831   // Only unit stride supported.
832   if (!sliceOp.hasUnitStride())
833     return failure();
834 
835   Operation *tiledPadOp = padOp.getTiledImplementation(
836       rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
837       sliceOp.getMixedSizes());
838   // All shapes are static and the data source is actually used. Rewrite into
839   // pad_tensor(subtensor(x)).
840   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
841   return success();
842 }
843 
844 namespace {
845 // The following are patterns for downscaling convolution ops with size-1
846 // window dimensions.
847 //
848 // Note that we'd eventually want to write such transformations in a generic
849 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
850 // and then turning back to named ops. But for now it's fine to have a few
851 // patterns matching special ops to get started.
852 
853 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
854 /// convolution ops.
855 struct DownscaleSizeOneWindowed2DConvolution final
856     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
857   using OpRewritePattern::OpRewritePattern;
858 
859   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
860                                 PatternRewriter &rewriter) const override {
861     auto linalgOp = cast<linalg::LinalgOp>(*convOp);
862     if (linalgOp.hasBufferSemantics())
863       return failure(); // To be implemented
864 
865     Value input = convOp.inputs().front();
866     Value filter = convOp.inputs().back();
867     Value output = convOp.outputs().front();
868 
869     auto inputType = input.getType().dyn_cast<RankedTensorType>();
870     auto filterType = filter.getType().dyn_cast<RankedTensorType>();
871     auto outputType = output.getType().dyn_cast<RankedTensorType>();
872 
873     auto inputShape = inputType.getShape();
874     auto filterShape = filterType.getShape();
875     auto outputShape = outputType.getShape();
876 
877     // Only handle the case where at least one of the window dimensions is
878     // of size 1. Other cases can rely on tiling to reduce to such cases.
879     int64_t fhSize = filterShape[0], fwSize = filterShape[1];
880     int64_t ohSize = outputShape[1], owSize = outputShape[2];
881     if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
882       return failure();
883     bool removeH = ohSize == 1;
884 
885     // Get new shapes and types for all operands by removing the size-1
886     // dimension.
887 
888     SmallVector<int64_t, 3> newInputShape{
889         inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
890     auto newInputType = RankedTensorType::get(
891         newInputShape, inputType.getElementType(), inputType.getEncoding());
892 
893     SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
894                                            filterShape[2], filterShape[3]};
895     auto newFilterType = RankedTensorType::get(
896         newFilterShape, filterType.getElementType(), filterType.getEncoding());
897 
898     SmallVector<int64_t, 3> newOutputShape{
899         outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
900     auto newOutputType = RankedTensorType::get(
901         newOutputShape, outputType.getElementType(), outputType.getEncoding());
902 
903     SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
904     SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
905 
906     // Reshape all operands for 1-D convolution.
907     Location loc = convOp.getLoc();
908     Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
909         loc, newInputType, input, ioReshapeIndices);
910     Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
911         loc, newFilterType, filter, fReshapeIndices);
912     Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
913         loc, newOutputType, output, ioReshapeIndices);
914 
915     // We need to shrink the strides and dilations too.
916     auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
917     auto stridesAttr = rewriter.getI64VectorAttr(stride);
918     auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
919     auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
920 
921     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
922         loc, newOutputType, ValueRange{newInput, newFilter},
923         ValueRange{newOutput}, stridesAttr, dilationsAttr);
924 
925     rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
926         convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
927     return success();
928   };
929 };
930 
931 } // namespace
932 
933 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
934                                                   PatternBenefit benefit) {
935   patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
936                                                       benefit);
937 }
938