//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Op has no marker and matchDisjunction is empty: the filter matches.
    if (matchDisjunction.empty())
      return success();

    // 2. Op has no marker but one was expected: notify and fail.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match the marker against the explicit disjunction.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The marker matches none of the expected values: notify and fail.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}
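
// As an illustrative sketch (not exercised in this file), a filter that
// matches ops tagged "tile" and retags the rewritten op "tiled" could be
// built as follows, with `ctx` the MLIRContext:
//
//   LinalgTransformationFilter filter(Identifier::get("tile", ctx),
//                                     Identifier::get("tiled", ctx));
//
// checkAndNotify then succeeds only on ops carrying
// __internal_linalg_transform__ = "tile", and
// replaceLinalgTransformationFilter relabels the result so the pattern does
// not fire on it again.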

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
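
// For instance (illustrative), requesting 8x16 tiles for a 2-D op:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});
//
// The constants are materialized at the start of the enclosing function so
// that they dominate all uses inside the generated loop nest.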

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
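
// For example (illustrative), for a linalg op whose loop bounds derive from a
// tensor<4x?xf32> operand, this computes tile sizes [0, 1]: the static loop is
// left untiled and the dynamic loop is tiled by 1, i.e. scalarized.
//
//   LinalgTilingOptions options;
//   options.scalarizeDynamicDims();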

/// Helper function that tries to pad `opOperand`. Exits early, returning
/// success, for scalar operands or when `paddingFunc` returns failure (i.e. no
/// padding value is known for this operand). Otherwise, tries to pad the
/// operand even if it already has a static shape. Sets `result` to the result
/// of the created PadTensorOp, or returns failure if the operand cannot be
/// padded to a static shape.
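///
/// A minimal sketch of the two callbacks (hypothetical helper names, for
/// illustration only): pad every operand with the zero of its element type
/// and mark only the first operand nofold.
///
///   PaddingValueComputationFunction paddingFunc =
///       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
///     // materializeZero is a hypothetical helper that creates an arith
///     // constant zero of the given element type.
///     return materializeZero(b, opOperand.getOwner()->getLoc(),
///                            getElementTypeOrSelf(opOperand.get()));
///   };
///   PaddingNoFoldComputationFunction nofoldFunc =
///       [](OpOperand &opOperand) { return opOperand.getOperandNumber() == 0; };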
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A smallest bounding index must exist for every size. For now, return an
    // error if we can't find one.
    if (!indexAttr) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(indexAttr.getInt());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update them in place where padding
  // succeeds.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
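
// Callers typically use this as follows (sketch; `b`, `opToPad`, and the two
// callbacks are assumed to exist):
//
//   LinalgOp paddedOp;
//   FailureOr<SmallVector<Value>> newResults =
//       rewriteAsPaddedOp(b, opToPad, paddingFunc, nofoldFunc, paddedOp);
//   if (succeeded(newResults)) {
//     // Replace the results of `opToPad` with newResults.getValue().
//   }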

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel the loop `op` and return the new results. Return the original
/// results if peeling is not possible.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel the loop at index `idx` of a TiledLoopOp and return the new
/// results. Return the original results if peeling is not possible.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling, as requested by `options.peeledLoops`.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
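
// Peeling is requested via the tiling options, e.g. (illustrative, assuming
// the `peeledLoops` member is directly assignable):
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});
//   options.peeledLoops = {0}; // Split off the partial iteration of loop 0.
//
// so that the main loop body only ever executes full tiles.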

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res->op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, res->op, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (succeeded(newResults)) {
    rewriter.replaceOp(res->op, newResults.getValue());
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Padding failed: return failure and do not propagate the TiledLinalgOp to
  // `result`.
  return failure();
}
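
// A typical instantiation (sketch; assumes the LinalgTilingPattern<OpTy>
// wrapper declared in Transforms.h, which forwards to matchAndRewriteBase):
//
//   patterns.add<LinalgTilingPattern<MatmulOp>>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 16, 4}),
//       LinalgTransformationFilter(Identifier::get("tile", ctx),
//                                  Identifier::get("tiled", ctx)));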

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue gracefully if this
    // is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Drop tile sizes beyond the number of loops of the op.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile only if at least one tile size is non-zero.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
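
// For example (illustrative), fusing only the producer that feeds operand 2
// of the tiled op:
//
//   LinalgFusionOptions fusionOptions;
//   fusionOptions.indicesToFuse.insert(2);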

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
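
// For example (illustrative), swapping the two loops of a 2-D generic op:
//
//   patterns.add<GenericOpInterchangePattern>(
//       ctx, /*interchangeVector=*/ArrayRef<unsigned>{1, 0});
//
// Iterator types and indexing maps are permuted consistently, so the op
// computes the same result with a different loop order.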

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
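
// A minimal sketch of the intended use (variable names illustrative): run one
// frozen set per stage-1 step, canonicalize via the stage-2 patterns after
// each step, then run a final non-pattern cleanup as stage 3.
//
//   SmallVector<FrozenRewritePatternSet, 4> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   stage1.emplace_back(std::move(vectorizationPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   (void)applyStagedPatterns(funcOp, stage1, stage2,
//                             [](Operation *op) { return success(); });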

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the padding value) and GenericOp (to copy the contents).
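///
/// For example (illustrative, simplified IR), padding a tensor<4xf32> by 2
/// low and 3 high:
///
///   %0 = linalg.pad_tensor %arg0 low[2] high[3] ...
///       : tensor<4xf32> to tensor<9xf32>
///
/// becomes roughly:
///
///   %i = linalg.init_tensor [9] : tensor<9xf32>
///   %f = linalg.fill(%pad, %i) : f32, tensor<9xf32> -> tensor<9xf32>
///   %r = linalg.generic ... ins(%arg0) outs(%f), where the output indexing
///        map writes the source into the window starting at offset 2.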
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value, if one exists.
/// Otherwise, generate a tensor::GenerateOp that carries over the pad region.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1 && "expected single yield operand");
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the sizes of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate a tensor::InsertSliceOp instead to
  // copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

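/// For example (illustrative, simplified IR):
///
///   %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<8x8xf32>
///   %1 = tensor.extract_slice %0[2, 2] [4, 4] [1, 1]
///       : tensor<8x8xf32> to tensor<4x4xf32>
///
/// is rewritten so that the pad is computed directly on the relevant tile of
/// %src instead of the full tensor, via the pad op's tiled implementation.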
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}