//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
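
// For illustration only: a minimal sketch of how this marker shows up in the
// IR while transformations are staged. The filter value "tile" and the matmul
// types are assumptions for the example, not values used in this file:
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
//            ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//            outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>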

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no transformation marker and the matchDisjunction is empty (or
    // matching by default was requested).
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no transformation marker but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match the marker against the explicit filter list.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The marker matches none of the filters: fail.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
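
// A minimal usage sketch; the tile sizes {8, 16, 4} are assumptions for
// illustration, not values used anywhere in this file:
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 4});
//
// The arith.constant ops materialized by the callback are inserted at the
// start of the enclosing FuncOp's entry block so that they dominate all of
// the tiled code.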

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
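
// Sketch of the intended effect, assuming a linalg op with a mix of static
// and dynamic loop ranges: static ranges get tile size 0 (left untiled) and
// dynamic ranges get tile size 1, so every dimension of the tiled op becomes
// static. For example:
//
//   LinalgTilingOptions().scalarizeDynamicDims();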

/// Helper function that tries to pad `opOperand`. Exit early and return
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// try to pad the operand even if it already has a static shape. Set `result`
/// to the result of the created PadTensorOp or return failure if the operand
/// cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute, add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}
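
// For illustration only; the types and the upper bound of 4 are assumptions
// for the example. Given
//
//   %slice = tensor.extract_slice %t[0, 0] [%sz, 8] [1, 1]
//       : tensor<?x8xf32> to tensor<?x8xf32>
//
// and a constant upper bound of 4 for %sz, the operand is padded up to the
// static bounding box (padding region elided):
//
//   %padded = linalg.pad_tensor %slice low[0, 0] high[%h, 0] { ... }
//       : tensor<?x8xf32> to tensor<4x8xf32>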

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
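
// Conceptual sketch with all types assumed for illustration: after cloning
// onto padded operands, each padded result %r : tensor<4x16xf32> is trimmed
// back to the original (possibly dynamic) size with
//
//   %res = tensor.extract_slice %r[0, 0] [%d0, %d1] [1, 1]
//       : tensor<4x16xf32> to tensor<?x?xf32>
//
// where %d0 and %d1 come from reifying the original result shapes.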

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
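
// For example (an illustrative sketch; the bounds are assumed): peeling
//
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
//
// produces a main loop over [0, 8) with step 4 plus a separate partial
// iteration that handles the remaining 2 elements; the results of the
// partial iteration become the replacement values for the original loop.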

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op view is always an
    // OpOperand. We could assert, but simply continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. The tile sizes of the already fused loops are set
  // to zero so that they are not tiled a second time.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Truncate the tile sizes to the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile only if there is at least one non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
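
// A hedged sketch of the effect; the interchange vector {1, 0} and the maps
// are assumed for illustration: with interchangeVector = {1, 0}, the two
// loops of a rank-2 linalg.generic are swapped, so an indexing map
// (d0, d1) -> (d0, d1) becomes (d0, d1) -> (d1, d0) and the iterator types
// are permuted accordingly.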

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}
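
// Illustrative sketch; the choice of op is an assumption for the example. A
// named op such as
//
//   %0 = linalg.matmul ins(%a, %b : ...) outs(%c : ...) -> ...
//
// is rewritten into the equivalent linalg.generic with explicit indexing
// maps, iterator types, and a scalar multiply-accumulate body.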

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
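
// A minimal usage sketch; the pattern sets and the surrounding pass are
// assumptions for illustration only:
//
//   SmallVector<FrozenRewritePatternSet, 4> stage1;
//   stage1.emplace_back(std::move(transformationPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2)))
//     signalPassFailure();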

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
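
// Illustrative sketch of the rewrite; all shapes and the pad value are
// assumptions for the example. A pad of a tensor<2x3xf32> by 1 in each
// dimension on both sides becomes roughly
//
//   %init = linalg.init_tensor [4, 5] : tensor<4x5xf32>
//   %fill = linalg.fill(%pad_val, %init)
//       : f32, tensor<4x5xf32> -> tensor<4x5xf32>
//   %copy = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0 + 1, d1 + 1)>], ...}
//       ins(%src : tensor<2x3xf32>) outs(%fill : tensor<4x5xf32>) ...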

/// Fill `dest` using a FillOp if the padding value of `padOp` is a constant.
/// Otherwise, generate a tensor::GenerateOp that yields the padding value.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadTensorOp source into the filled tensor.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the PadTensorOp source.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}
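
// Illustrative sketch; the shapes and the constant pad value are assumptions
// for the example. With a constant pad value, the generalized form of a pad
// from tensor<2x3xf32> to tensor<4x5xf32> with low padding [1, 1] is
//
//   %init = linalg.init_tensor [4, 5] : tensor<4x5xf32>
//   %fill = linalg.fill(%cst, %init) : f32, tensor<4x5xf32> -> tensor<4x5xf32>
//   %res  = tensor.insert_slice %src into %fill[1, 1] [2, 3] [1, 1]
//       : tensor<2x3xf32> into tensor<4x5xf32>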

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
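
// Conceptual sketch of the swap; the offsets, sizes, and types are assumed
// for illustration. Instead of slicing the padded result, as in
//
//   %p = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<8x8xf32>
//   %s = tensor.extract_slice %p[2, 2] [4, 4] [1, 1]
//
// the pattern computes the corresponding slice of %src first and pads only
// that slice, which is what getTiledImplementation materializes.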

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value filter = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto inputShape = inputType.getShape();
    auto filterShape = filterType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
      return failure();
    bool removeH = ohSize == 1;

    // Get new shapes and types for all operands by removing the size-1
    // dimension.

    SmallVector<int64_t, 3> newInputShape{
        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
    auto newInputType = RankedTensorType::get(
        newInputShape, inputType.getElementType(), inputType.getEncoding());

    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
                                           filterShape[2], filterShape[3]};
    auto newFilterType = RankedTensorType::get(
        newFilterShape, filterType.getElementType(), filterType.getEncoding());

    SmallVector<int64_t, 3> newOutputShape{
        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
    auto newOutputType = RankedTensorType::get(
        newOutputShape, outputType.getElementType(), outputType.getEncoding());

    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};

    // Reshape all operands for 1-D convolution.
    Location loc = convOp.getLoc();
    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newInputType, input, ioReshapeIndices);
    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newFilterType, filter, fReshapeIndices);
    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newOutputType, output, ioReshapeIndices);

    // We need to shrink the strides and dilations too.
    auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
    auto stridesAttr = rewriter.getI64VectorAttr(stride);
    auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newFilter},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
    return success();
  }
};
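
// Illustrative sketch; the shapes are assumptions for the example. A
// convolution whose filter height and output height are both 1, e.g.
//
//   linalg.conv_2d_nhwc_hwcf ins(%in : tensor<1x1x16x4xf32>,
//                                %f : tensor<1x3x4x8xf32>) ...
//
// collapses the unit H dimension of all operands and becomes
//
//   linalg.conv_1d_nwc_wcf ins(%in2 : tensor<1x16x4xf32>,
//                              %f2 : tensor<3x4x8xf32>) ...
//
// with the corresponding single stride and dilation retained; the result is
// expanded back to the original 4-D output type.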

} // namespace

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
                                                      benefit);
}
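
// A minimal usage sketch from inside a pass; the surrounding pass and the
// explicit benefit of 1 are assumptions for illustration:
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();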