1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <utility>
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Utils/IndexingUtils.h"
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/Transforms/FoldUtils.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 
30 #include "llvm/Support/CommandLine.h"
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 using namespace mlir::scf;
35 
36 #define DEBUG_TYPE "linalg-tiling"
37 
38 static bool isZero(Value v) {
39   if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
40     return cst.value() == 0;
41   return false;
42 }
43 
/// Computes the ranges of the loops to tile. `map` maps the flattened operand
/// shape sizes to loop-order sizes. Loops whose tile size is the constant 0
/// are pruned (convention: zero tile size means "do not tile"); the returned
/// map records, for each surviving original loop index, its position in the
/// pruned range list.
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  // `idx` walks the *original* loop indices while `idx - zerosCount` is the
  // corresponding position in the vectors after the erasures done so far.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      // Zero tile size: this loop is untiled; drop it from both vectors.
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    // Surviving loop: remember where it landed in the compacted list.
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}
72 
73 void mlir::linalg::transformIndexOps(
74     RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
75     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
76   SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
77   for (auto &en : enumerate(allIvs)) {
78     auto rangeIndex = loopIndexToRangeIndex.find(en.index());
79     if (rangeIndex == loopIndexToRangeIndex.end())
80       continue;
81     en.value() = ivs[rangeIndex->second];
82   }
83   offsetIndices(b, op, allIvs);
84 }
85 
86 /// Asserts that the given index-typed value is strictly positive. If the value
87 /// is an attribute, asserts at compile time, otherwise emits an assertion
88 /// checked at runtime.
89 static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
90                                          OpFoldResult value) {
91   if (auto attr = value.dyn_cast<Attribute>()) {
92     assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
93            "expected strictly positive tile size and divisor");
94     return;
95   }
96 
97   Value zero = b.create<arith::ConstantIndexOp>(0);
98   Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
99                                             value.get<Value>(), zero);
100   b.create<cf::AssertOp>(
101       condition,
102       b.getStringAttr("expected strictly positive tile size and divisor"));
103 }
104 
/// Emits IR that computes two tile sizes for loop `dimension` of `op` such
/// that lowTileSize * lowTripCount + highTileSize * highTripCount covers that
/// dimension's trip count (the emitted cf.assert below checks exactly this),
/// with both tile sizes built as multiples of `divisor` and the low size
/// derived from `targetSize`. Fails if `dimension` is out of range.
FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
  Value divisorValue = materializeOpFoldResult(b, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
  // littering by useless constant materialization.
  SmallVector<Value, 4> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<Value, 4> loopRanges =
      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
  Value tripCount = loopRanges[dimension];

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
    return makeComposedAffineApply(b, b.getLoc(), expr, values);
  };
  // a = tripCount floordiv divisor: iteration space in divisor-sized units.
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  // t = ceil(targetSize / divisor): target tile size in those units.
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  // d = ceil(a / t): total number of tiles.
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  // s = (a floordiv d) * divisor: low tile size, a multiple of divisor.
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  // v = a mod d: number of tiles that must be one unit larger (high tiles).
  Value v = apply(s0 % s1, {a, d});
  // u = d - v: number of tiles that keep the low size.
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  // The high tile size is exactly one divisor larger than the low one.
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for iteration dimension size of 15 and the target size 8 it is
  // impossible to find two tile sizes both divisible by 8 that fully cover the
  // original space dimension.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}
171 
172 // Insert a tile `source` into the destination tensor `dest`. The position at
173 // which the tile is inserted (as well as size of tile) is taken from a given
174 // ExtractSliceOp `sliceOp`.
175 static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
176                                    tensor::ExtractSliceOp sliceOp, Value source,
177                                    Value dest) {
178   return b.create<tensor::InsertSliceOp>(
179       loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
180       sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
181       sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
182 }
183 
/// Core tiling implementation: wraps a clone of `op` in a loop nest of kind
/// `LoopTy` driven by `tileSizes` (one per loop; the constant 0 means "do not
/// tile this loop"). Returns the tiled op, the generated loops, and the tensor
/// results to use as replacements.
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  // All-zero tile sizes: nothing to tile, return a plain clone of the op.
  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  // Keep only the iterator types of loops that survived the pruning of
  // zero-tile-size loops above.
  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  // `ivs` and `tensorResults` are captured by reference and filled in by the
  // body builder; they are read again after GenerateLoopNest::doit below.
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    // Clone the op on the tiled operands and stitch its tensor results back
    // into the destination tensors.
    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      // A block-argument IV belongs to a loop op: recover it via the block.
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of the
      // loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
310 
311 template <typename LoopTy>
312 FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
313     RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
314   OpBuilder::InsertionGuard g(b);
315   b.setInsertionPoint(op);
316 
317   if (!options.tileSizeComputationFunction)
318     return failure();
319 
320   // Enforce the convention that "tiling by zero" skips tiling a particular
321   // dimension. This convention is significantly simpler to handle instead of
322   // adjusting affine maps to account for missing dimensions.
323   auto nLoops = op.getNumLoops();
324   SmallVector<Value, 4> tileSizeVector =
325       options.tileSizeComputationFunction(b, op);
326   if (tileSizeVector.size() < nLoops) {
327     auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
328     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
329   }
330 
331   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
332 }
333 
334 FailureOr<TiledLinalgOp>
335 mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
336                            const LinalgTilingOptions &options) {
337   switch (options.loopType) {
338   case LinalgTilingLoopType::Loops:
339     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
340   case LinalgTilingLoopType::ParallelLoops:
341     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
342   default:;
343   }
344   return failure();
345 }
346 
/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Normalize untiled padding dimensions to 0.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  // Only dimensions with a non-zero tile size get a loop; `allDims` keeps
  // every dimension size for the tile offset/size computation in the body.
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}
408 
409 namespace {
410 struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
411   PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
412       : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}
413 
414   LogicalResult matchAndRewrite(tensor::PadOp op,
415                                 PatternRewriter &rewriter) const override {
416     if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
417       return failure();
418     tensor::PadOp newPadOp;
419     LoopNest loopNest;
420     if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
421       return failure();
422     newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
423                       rewriter.getUnitAttr());
424     // Replace all uses of the original tensor::PadOp.
425     rewriter.replaceOp(op, loopNest.getResults()[0]);
426     return success();
427   }
428 
429   LinalgTilingOptions options;
430 };
431 } // namespace
432 
433 namespace {
434 /// Helper classes for type list expansion.
435 template <typename... OpTypes>
436 class CanonicalizationPatternList;
437 
438 template <>
439 class CanonicalizationPatternList<> {
440 public:
441   static void insert(RewritePatternSet &patterns) {}
442 };
443 
444 template <typename OpTy, typename... OpTypes>
445 class CanonicalizationPatternList<OpTy, OpTypes...> {
446 public:
447   static void insert(RewritePatternSet &patterns) {
448     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
449     CanonicalizationPatternList<OpTypes...>::insert(patterns);
450   }
451 };
452 } // namespace
453 
454 RewritePatternSet
455 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
456   RewritePatternSet patterns(ctx);
457   populateLinalgTilingCanonicalizationPatterns(patterns);
458   return patterns;
459 }
460 
461 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
462     RewritePatternSet &patterns) {
463   auto *ctx = patterns.getContext();
464   AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
465   AffineForOp::getCanonicalizationPatterns(patterns, ctx);
466   AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
467   AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
468   arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
469 
470   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
471   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
472 
473   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
474   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
475 
476   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
477   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
478   tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
479 
480   InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
481   tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
482   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
483 
484   CanonicalizationPatternList<
485 #define GET_OP_LIST
486 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
487       >::insert(patterns);
488 }
489 
490 /// Populate the given list with patterns that apply Linalg tiling.
491 static void insertTilingPatterns(RewritePatternSet &patterns,
492                                  const LinalgTilingOptions &options) {
493   auto *ctx = patterns.getContext();
494   LinalgTransformationFilter f(ArrayRef<StringAttr>{},
495                                StringAttr::get(ctx, "tiled"));
496   TilingPatterns<GenericOp,
497 #define GET_OP_LIST
498 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
499                  >::insert(patterns, options, f);
500   patterns.add<PadOpTilingPattern>(ctx, options);
501 }
502 
503 void mlir::linalg::populatePadTensorTilingPatterns(
504     RewritePatternSet &patterns, const LinalgTilingOptions &options) {
505   auto *ctx = patterns.getContext();
506   patterns.add<PadOpTilingPattern>(ctx, options);
507 }
508 
509 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
510   MLIRContext *ctx = funcOp.getContext();
511   RewritePatternSet patterns(ctx);
512   patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
513   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
514   (void)applyPatternsAndFoldGreedily(
515       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
516 }
517 
518 namespace {
519 struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
520   LinalgTilingPass() = default;
521   LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
522     this->tileSizes = tileSizes;
523     this->loopType = "";
524     this->loopTypeEnum = loopType;
525   }
526 
527   void runOnOperation() override {
528     func::FuncOp funcOp = getOperation();
529     LinalgTilingLoopType type =
530         llvm::StringSwitch<LinalgTilingLoopType>(loopType)
531             .Case("for", LinalgTilingLoopType::Loops)
532             .Case("affine", LinalgTilingLoopType::AffineLoops)
533             .Case("parallel", LinalgTilingLoopType::ParallelLoops)
534             .Default(loopTypeEnum);
535     auto options =
536         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
537     MLIRContext *ctx = funcOp.getContext();
538     RewritePatternSet patterns(ctx);
539     insertTilingPatterns(patterns, options);
540     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
541     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
542     (void)applyPatternsAndFoldGreedily(
543         funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
544     // Drop the marker.
545     funcOp.walk([](LinalgOp op) {
546       op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
547     });
548 
549     // Apply swap pattern after generating loop nest and running
550     // canonicalizations.
551     applyExtractSliceOfPadTensorSwapPattern(funcOp);
552   }
553 
554   LinalgTilingLoopType loopTypeEnum;
555 };
556 
557 } // namespace
558 
559 std::unique_ptr<OperationPass<func::FuncOp>>
560 mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
561                              linalg::LinalgTilingLoopType loopType) {
562   return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
563 }
564