//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

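/// Construct the loop ranges used for tiling: apply `map` to `allShapeSizes`
/// to obtain the shape sizes in loop order, drop every loop whose tile size is
/// a constant zero, and return one Range {0, size, tileSize} per remaining
/// loop together with a map from original loop index to range index.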
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults() &&
         "expected one tile size per map result");
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase the zeros.
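  // E.g. for tile sizes [2, 0, 4], the range for loop 1 is dropped and
  // `loopIndexToRangeIndex` becomes {0 -> 0, 2 -> 1}.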
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

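/// Rewrite the `linalg.index` results of `op` to account for the tiling: each
/// index that corresponds to a tiled loop is offset by the matching tile-loop
/// induction variable from `ivs`.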
void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, allIvs);
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, this asserts at compile time; otherwise it emits an
/// assertion that is checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  if (auto attr = value.dyn_cast<Attribute>()) {
    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            value.get<Value>(), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
  Value divisorValue = materializeOpFoldResult(b, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
  // littering the IR with useless constant materializations.
  SmallVector<Value, 4> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<Value, 4> loopRanges =
      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
  Value tripCount = loopRanges[dimension];

  // Compute the tile sizes and the respective numbers of tiles.
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
    return makeComposedAffineApply(b, b.getLoc(), expr, values);
  };
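  // In index arithmetic, the computation below amounts to:
  //   a = tripCount / divisor         (number of divisor-sized groups)
  //   t = ceil(targetSize / divisor)  (target number of groups per tile)
  //   d = ceil(a / t)                 (total number of tiles)
  //   s = (a / d) * divisor           (low tile size)
  //   v = a % d                       (number of tiles of size s + divisor)
  //   u = d - v                       (number of tiles of size s)
  // E.g. for tripCount = 14, targetSize = 5 and divisor = 1: a = 14, t = 5,
  // d = 3, s = 4, v = 2, u = 1, i.e. one tile of size 4 and two tiles of
  // size 5, which together cover 1 * 4 + 2 * 5 = 14 iterations.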
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for an iteration dimension of size 15 and a target size of 8,
  // it is impossible to find two tile sizes, both divisible by 8, that fully
  // cover the original iteration space.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
/// `ParallelInsertSliceOp` of `source` into `dest` at the subset location
/// described by `subsetExtractOp`.
static void
createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
                                     tensor::ExtractSliceOp subsetExtractOp,
                                     Value source, Value dest) {
  b.create<tensor::ParallelInsertSliceOp>(
      loc, source, dest, subsetExtractOp.getMixedOffsets(),
      subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
}

/// Build an `affine_max` of all the `vals`.
static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
  return b.createOrFold<AffineMaxOp>(
      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

/// Build an `affine_min` of all the `vals`.
static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
  return b.createOrFold<AffineMinOp>(
      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
      vals);
}

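/// Tile `op` into an `scf.foreach_thread` op: every non-zero entry of
/// `numThreads` becomes one thread dimension and the corresponding loop is
/// distributed across that many threads; the tiled results are written back
/// into the destination operand via `tensor::ParallelInsertSliceOp`s created
/// in the terminator region.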
FailureOr<ForeachThreadTilingResult>
linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
                              ArrayRef<OpFoldResult> numThreads,
                              ArrayRef<int64_t> threadDimMapping) {
  Location loc = op->getLoc();
  OpBuilder::InsertionGuard g(b);
  SmallVector<Range> loopRanges = op.getIterationDomain(b);
  if (loopRanges.empty())
    return op->emitOpError("expected non-empty loop ranges");
  auto hasNonUnitStride = [](Range r) {
    return !isConstantIntValue(r.stride, 1);
  };
  if (llvm::any_of(loopRanges, hasNonUnitStride))
    return op->emitOpError("only stride-1 supported atm");
  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
  auto destOperands = op.getDestinationOperands(b);
  if (destOperands.size() != 1)
    return op->emitOpError("only single dest operand supported atm");

  SmallVector<OpFoldResult> nonZeroNumThreads =
      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
        return !isConstantIntValue(ofr, 0);
      }));
  SmallVector<Value> materializedNonZeroNumThreads =
      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
        ImplicitLocOpBuilder ilocb(loc, b);
        return materializeOpFoldResult(ilocb, ofr);
      }));

  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Operation *tiledOp = nullptr;
  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
      loc, materializedNonZeroNumThreads, threadDimMapping,
      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
        int64_t nLoops = loopRanges.size();
        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
        tiledOffsets.reserve(nLoops);
        tiledSizes.reserve(nLoops);
        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
             ++loopIdx) {
          bool overflow = loopIdx >= numThreads.size();
          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
          // Degenerate case: take the whole domain.
          if (overflow || isZero) {
            tiledOffsets.push_back(loopRanges[loopIdx].offset);
            tiledSizes.push_back(loopRanges[loopIdx].size);
            continue;
          }

          // Tiled case: compute the offset and size.
          AffineExpr i, j, M, N, O;
          bindDims(b.getContext(), i, j);
          bindSymbols(b.getContext(), M, N, O);
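          // Per-thread tile below: offset_t = offset + t * ceil(size / nT) and
          // size_t = max(0, min(size - offset_t, ceil(size / nT))), where t is
          // the threadId and nT the number of threads along this loop. E.g.
          // (with offset == 0) for size = 10 and nT = 4, the per-thread sizes
          // are 3, 3, 3 and 1.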
          Value size = loopRanges[loopIdx].size;
          Value offset = loopRanges[loopIdx].offset;
          Value threadId = threadIds[threadIdIdx];
          // TODO: more aggressive foldings.
          // Symbolic fixed max size per thread.
          // TODO: floor + 0/1 depending on case for better load-balancing.
          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
              loc, M.ceilDiv(N),
              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
          // Dynamic offset shifted by threadId * maxSizePerThread.
          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
          // Dynamic upper-bound depending on the threadId.
          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
              loc, -i + M, ValueRange{offsetPerThread, size});
          Value tileSizePerThread = buildMin(
              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
          tiledOffsets.push_back(offsetPerThread);
          // TODO: if tileSizePerThread <= 0 early exit.
          tiledSizes.push_back(
              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
          ++threadIdIdx;
        }

        SmallVector<Operation *> tiledOps =
            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
                                      /*tileDestOperands=*/true);
        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
        tiledOp = tiledOps.front();

        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
        assert(tilingInterfaceOp &&
               "Tiled op does not implement TilingInterface");

        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);

        // Create terminator with parallel subset insert operations.
        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
        OpBuilder::InsertionGuard g(b);
        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
        for (auto it :
             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
                       destOperands)) {
          createMatchingParallelSubsetInsertOp(
              b, loc,
              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
              std::get<1>(it), std::get<2>(it));
        }
      });
  return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}

/// Insert a tile `source` into the destination tensor `dest`. The position at
/// which the tile is inserted (as well as the size of the tile) is taken from
/// the given ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // There may be more tile sizes than loops; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

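  // Only keep the iterator types of loops that are materialized as tile loops,
  // i.e. those whose tile size is non-zero.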
  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity map. Otherwise, build the
  // inverse permutation map from it.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expected the number of operand values to match the op's number of "
           "inputs and outputs");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of pushing a nullptr here, try to recover the ops that
      // were used in place of the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
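  // Missing trailing tile sizes are treated as zero, i.e. the corresponding
  // loops are left untiled.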
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Pad the tile-size list with zeros: a tile size of 0 means that the
  // corresponding dimension is not tiled.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
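  // Only dimensions with a non-zero tile size get a loop; untiled dimensions
  // are handled in full by the extracted slice.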
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
  PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}

  LogicalResult matchAndRewrite(tensor::PadOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    tensor::PadOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original tensor::PadOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
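  // Tag every tiled op with a "tiled" marker so that the patterns do not tile
  // an op twice; the marker is removed again at the end of the tiling pass.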
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >::insert(patterns, options, f);
  patterns.add<PadOpTilingPattern>(ctx, options);
}

void mlir::linalg::populatePadTensorTilingPatterns(
    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  patterns.add<PadOpTilingPattern>(ctx, options);
}

static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
    this->tileSizes = tileSizes;
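    // Clear the string option so that the `loopType` enum passed here takes
    // effect via the `Default` case of the StringSwitch in runOnOperation.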
    this->loopType = "";
    this->loopTypeEnum = loopType;
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Default(loopTypeEnum);
    auto options =
        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply swap pattern after generating loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum = LinalgTilingLoopType::Loops;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
}