1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
29 
30 using namespace mlir;
31 using namespace mlir::edsc;
32 using namespace mlir::edsc::intrinsics;
33 using namespace mlir::linalg;
34 using namespace mlir::scf;
35 
36 using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
37 
38 #define DEBUG_TYPE "linalg-tiling"
39 
40 static bool isZero(Value v) {
41   return isa_and_nonnull<ConstantIndexOp>(v.getDefiningOp()) &&
42          cast<ConstantIndexOp>(v.getDefiningOp()).getValue() == 0;
43 }
44 
// Maps the index of a loop of the original LinalgOp to the position of its
// range among the tiled loops only (those with a non-zero tile size). Loops
// that are not tiled have no entry in the map.
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
46 
47 // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
48 // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
49 // one entry per surrounding loop. It uses zero as the convention that a
50 // particular loop is not tiled. This convention simplifies implementations by
51 // avoiding affine map manipulations.
52 // The returned ranges correspond to the loop ranges, in the proper order, that
53 // are tiled and for which new loops will be created. Also the function returns
54 // a map from loop indices of the LinalgOp to the corresponding non-empty range
55 // indices of newly created loops.
56 static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
57 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
58                     ArrayRef<Value> allViewSizes, ArrayRef<Value> allTileSizes,
59                     OperationFolder *folder) {
60   assert(allTileSizes.size() == map.getNumResults());
61   // Apply `map` to get view sizes in loop order.
62   auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
63   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
64 
65   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
66   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
67   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
68     if (isZero(tileSizes[idx - zerosCount])) {
69       viewSizes.erase(viewSizes.begin() + idx - zerosCount);
70       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
71       ++zerosCount;
72       continue;
73     }
74     loopIndexToRangeIndex[idx] = idx - zerosCount;
75   }
76 
77   // Create a new range with the applied tile sizes.
78   SmallVector<SubViewOp::Range, 4> res;
79   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
80     res.push_back(SubViewOp::Range{folded_std_constant_index(folder, 0),
81                                    viewSizes[idx], tileSizes[idx]});
82   }
83   return std::make_tuple(res, loopIndexToRangeIndex);
84 }
85 
86 namespace {
87 
88 // Helper visitor to determine whether an AffineExpr is tiled.
89 // This is achieved by traversing every AffineDimExpr with position `pos` and
90 // checking whether the corresponding `tileSizes[pos]` is non-zero.
91 // This also enforces only positive coefficients occur in multiplications.
92 //
93 // Example:
94 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
95 //
96 struct TileCheck : public AffineExprVisitor<TileCheck> {
97   TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
98 
99   void visitDimExpr(AffineDimExpr expr) {
100     isTiled |= !isZero(tileSizes[expr.getPosition()]);
101   }
102   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
103     visit(expr.getLHS());
104     visit(expr.getRHS());
105     if (expr.getKind() == mlir::AffineExprKind::Mul)
106       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
107              "nonpositive multiplying coefficient");
108   }
109   bool isTiled;
110   ArrayRef<Value> tileSizes;
111 };
112 
113 } // namespace
114 
115 // IndexedGenericOp explicitly uses induction variables in the loop body. The
116 // values of the indices that are used in the loop body for any given access of
117 // input/output memref before `subview` op was applied should be invariant with
118 // respect to tiling.
119 //
120 // Therefore, if the operation is tiled, we have to transform the indices
121 // accordingly, i.e. offset them by the values of the corresponding induction
122 // variables that are captured implicitly in the body of the op.
123 //
124 // Example. `linalg.indexed_generic` before tiling:
125 //
126 // #id_2d = (i, j) -> (i, j)
127 // #pointwise_2d_trait = {
128 //   indexing_maps = [#id_2d, #id_2d],
129 //   iterator_types = ["parallel", "parallel"],
130 //   n_views = [1, 1]
131 // }
132 // linalg.indexed_generic #pointwise_2d_trait %operand, %result {
133 //   ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
134 //     <some operations that use %i, %j>
135 // }: memref<50x100xf32>, memref<50x100xf32>
136 //
137 // After tiling pass with tiles sizes 10 and 25:
138 //
139 // #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
140 //
141 // %c1 = constant 1 : index
142 // %c0 = constant 0 : index
143 // %c25 = constant 25 : index
144 // %c10 = constant 10 : index
145 // operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
146 // operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
147 // scf.for %k = %c0 to operand_dim_0 step %c10 {
148 //   scf.for %l = %c0 to operand_dim_1 step %c25 {
149 //     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
150 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
151 //     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
152 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
153 //     linalg.indexed_generic pointwise_2d_trait %4, %5 {
154 //     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
155 //       // Indices `k` and `l` are implicitly captured in the body.
156 //       %transformed_i = addi %i, %k : index // index `i` is offset by %k
157 //       %transformed_j = addi %j, %l : index // index `j` is offset by %l
158 //       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
159 //       <some operations that use %transformed_i, %transformed_j>
160 //     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
161 //   }
162 // }
163 //
164 // TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
165 // does not lead to losing information.
166 static void transformIndexedGenericOpIndices(
167     OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
168     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
169   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
170   auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
171   if (!indexedGenericOp)
172     return;
173 
174   // `linalg.indexed_generic` comes in two flavours. One has a region with a
175   // single block that defines the loop body. The other has a `fun` attribute
176   // that refers to an existing function symbol. The `fun` function call will be
177   // inserted in the loop body in that case.
178   //
179   // TODO(pifon): Add support for `linalg.indexed_generic` with `fun` attribute.
180   auto &region = indexedGenericOp.region();
181   if (region.empty()) {
182     indexedGenericOp.emitOpError("expected a region");
183     return;
184   }
185   auto &block = region.getBlocks().front();
186 
187   OpBuilder::InsertionGuard g(b);
188   b.setInsertionPointToStart(&block);
189   for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
190     auto rangeIndex = loopIndexToRangeIndex.find(i);
191     if (rangeIndex == loopIndexToRangeIndex.end())
192       continue;
193     Value oldIndex = block.getArgument(i);
194     // Offset the index argument `i` by the value of the corresponding induction
195     // variable and replace all uses of the previous value.
196     Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
197                                       ivs[rangeIndex->second]);
198     for (auto &use : oldIndex.getUses()) {
199       if (use.getOwner() == newIndex.getDefiningOp())
200         continue;
201       use.set(newIndex);
202     }
203   }
204 }
205 
206 static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
207   if (!expr)
208     return false;
209   TileCheck t(tileSizes);
210   t.visit(expr);
211   return t.isTiled;
212 }
213 
214 // Checks whether the view with index `viewIndex` within `linalgOp` varies with
215 // respect to a non-zero `tileSize`.
216 static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
217   if (!map)
218     return false;
219   for (unsigned r = 0; r < map.getNumResults(); ++r)
220     if (isTiled(map.getResult(r), tileSizes))
221       return true;
222   return false;
223 }
224 
225 static SmallVector<Value, 4>
226 makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
227                ArrayRef<Value> ivs, ArrayRef<Value> tileSizes,
228                ArrayRef<Value> viewSizes, OperationFolder *folder) {
229   assert(linalgOp.hasBufferSemantics() &&
230          "expected linalg op with buffer semantics");
231   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
232                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
233                            [](Value v) { return !isZero(v); })) &&
234          "expected as many ivs as non-zero sizes");
235 
236   using namespace edsc::op;
237 
238   // Construct (potentially temporary) mins and maxes on which to apply maps
239   // that define tile subviews.
240   SmallVector<Value, 8> lbs, subViewSizes;
241   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
242     bool isTiled = !isZero(tileSizes[idx]);
243     lbs.push_back(isTiled ? ivs[idxIvs++]
244                           : (Value)folded_std_constant_index(folder, 0));
245     subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]);
246   }
247 
248   auto *op = linalgOp.getOperation();
249 
250   SmallVector<Value, 4> res;
251   res.reserve(op->getNumOperands());
252   auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
253   for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
254        ++viewIndex) {
255     Value view = *(viewIteratorBegin + viewIndex);
256     auto viewType = view.getType().cast<MemRefType>();
257     unsigned rank = viewType.getRank();
258     auto mapAttr = linalgOp.indexing_maps()[viewIndex];
259     auto map = mapAttr.cast<AffineMapAttr>().getValue();
260     // If the view is not tiled, we can use it as is.
261     if (!isTiled(map, tileSizes)) {
262       res.push_back(view);
263       continue;
264     }
265 
266     // Construct a new subview for the tile.
267     SmallVector<Value, 4> offsets, sizes, strides;
268     offsets.reserve(rank);
269     sizes.reserve(rank);
270     strides.reserve(rank);
271     for (unsigned r = 0; r < rank; ++r) {
272       if (!isTiled(map.getSubMap({r}), tileSizes)) {
273         offsets.push_back(folded_std_constant_index(folder, 0));
274         sizes.push_back(std_dim(view, r));
275         strides.push_back(folded_std_constant_index(folder, 1));
276         continue;
277       }
278 
279       // Tiling creates a new slice at the proper index, the slice step is 1
280       // (i.e. the slice view does not subsample, stepping occurs in the loop).
281       auto m = map.getSubMap({r});
282       auto offset = applyMapToValues(b, loc, m, lbs, folder).front();
283       offsets.push_back(offset);
284       auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front();
285 
286       // The size of the subview should be trimmed to avoid out-of-bounds
287       // accesses, unless we statically know the subview size divides the view
288       // size evenly.
289       int64_t viewSize = viewType.getDimSize(r);
290       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
291       if (ShapedType::isDynamic(viewSize) || !sizeCst ||
292           (viewSize % sizeCst.getValue()) != 0) {
293         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
294         auto minMap = AffineMap::get(
295             /*dimCount=*/3, /*symbolCount=*/0,
296             {getAffineDimExpr(/*position=*/0, b.getContext()),
297              getAffineDimExpr(/*position=*/1, b.getContext()) -
298                  getAffineDimExpr(/*position=*/2, b.getContext())},
299             b.getContext());
300         auto d = folded_std_dim(folder, view, r);
301         size = folded_affine_min(folder, b.getIndexType(), minMap,
302                                  ValueRange{size, d, offset});
303       }
304 
305       sizes.push_back(size);
306       strides.push_back(folded_std_constant_index(folder, 1));
307     }
308 
309     res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides));
310   }
311 
312   // Traverse the mins/maxes and erase those that don't have uses left.
313   // This is a special type of folding that we only apply when `folder` is
314   // defined.
315   if (folder)
316     for (auto v : llvm::concat<Value>(lbs, subViewSizes))
317       if (v.use_empty())
318         v.getDefiningOp()->erase();
319 
320   return res;
321 }
322 
323 template <typename LoopTy>
324 Optional<TiledLinalgOp> static tileLinalgOpImpl(
325     OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
326     ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
327   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
328   // 1. Enforce the convention that "tiling by zero" skips tiling a particular
329   // dimension. This convention is significantly simpler to handle instead of
330   // adjusting affine maps to account for missing dimensions.
331   assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
332                  op.getNumWindowLoops() ==
333              tileSizes.size() &&
334          "expected matching number of tile sizes and loops");
335 
336   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
337     // For conv op only support tiling along batch dimension (which is the first
338     // loop).
339     if (convOp.padding() &&
340         !llvm::all_of(tileSizes.drop_front(),
341                       [](Value val) { return isZero(val); }))
342       return llvm::None;
343   }
344 
345   // If interchangeVector is empty, use the identity. Build the permutation map
346   // otherwise.
347   auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
348       tileSizes.size(), ScopedContext::getContext());
349   if (!interchangeVector.empty())
350     invPermutationMap = inversePermutation(AffineMap::getPermutationMap(
351         interchangeVector, ScopedContext::getContext()));
352   if (!invPermutationMap)
353     return llvm::None;
354 
355   OpBuilder::InsertionGuard g(b);
356   b.setInsertionPoint(op);
357   ScopedContext scope(b, op.getLoc());
358   // 2. Build the tiled loop ranges.
359   auto viewSizes = getViewSizes(b, op);
360   // The flattened loopToOperandRangesMaps is expected to be an invertible
361   // permutation map (asserted in the inverse calculation).
362   auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
363   auto maps = llvm::to_vector<8>(
364       llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
365   auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
366   if (!viewSizesToLoopsMap)
367     return llvm::None;
368 
369   SmallVector<SubViewOp::Range, 4> loopRanges;
370   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
371   std::tie(loopRanges, loopIndexToRangeIndex) =
372       makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
373                           viewSizes, tileSizes, folder);
374   if (!interchangeVector.empty())
375     applyPermutationToVector(loopRanges, interchangeVector);
376 
377   // 3. Create the tiled loops.
378   LinalgOp res = op;
379   SmallVector<Value, 4> ivs(loopRanges.size());
380   // Convert SubViewOp::Range to linalg_range.
381   SmallVector<Value, 4> linalgRanges;
382   for (auto &range : loopRanges) {
383     linalgRanges.push_back(
384         linalg_range(range.offset, range.size, range.stride));
385   }
386   GenericLoopNestRangeBuilder<LoopTy>(ivs, linalgRanges)([&] {
387     auto &b = ScopedContext::getBuilderRef();
388     auto loc = ScopedContext::getLocation();
389     SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
390 
391     // If we have to apply a permutation to the tiled loop nest, we have to
392     // reorder the induction variables This permutation is the right one
393     // assuming that loopRanges have previously been permuted by
394     // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
395     // that one: (d0,d1,d2)->(d2,d0,d1)
396     if (!interchangeVector.empty())
397       ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
398 
399     auto views =
400         makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
401     auto operands = getAssumedNonViewOperands(op);
402     views.append(operands.begin(), operands.end());
403     res = op.clone(b, loc, views);
404   });
405 
406   // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
407   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
408 
409   // 5. Gather the newly created loops and return them with the new op.
410   SmallVector<Operation *, 8> loops;
411   loops.reserve(ivs.size());
412   for (auto iv : ivs) {
413     loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
414     assert(loops.back() && "no owner found for induction variable!");
415   }
416 
417   return TiledLinalgOp{res, loops};
418 }
419 
420 template <typename LoopTy>
421 static Optional<TiledLinalgOp>
422 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
423                  ArrayRef<unsigned> interchangeVector,
424                  OperationFolder *folder) {
425   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
426   if (tileSizes.empty())
427     return llvm::None;
428 
429   // The following uses the convention that "tiling by zero" skips tiling a
430   // particular dimension. This convention is significantly simpler to handle
431   // instead of adjusting affine maps to account for missing dimensions.
432   auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
433                 op.getNumWindowLoops();
434   tileSizes = tileSizes.take_front(nLoops);
435   // If only 0 tilings are left, then return.
436   if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
437     return llvm::None;
438 
439   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
440     // For conv op only support tiling along batch dimension (which is the first
441     // loop).
442     if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(),
443                                           [](int64_t val) { return val == 0; }))
444       return llvm::None;
445   }
446 
447   // Create a builder for tile size constants.
448   OpBuilder::InsertionGuard g(b);
449   b.setInsertionPoint(op);
450   ScopedContext scope(b, op.getLoc());
451 
452   // Materialize concrete tile size values to pass the generic tiling function.
453   SmallVector<Value, 8> tileSizeValues;
454   tileSizeValues.reserve(tileSizes.size());
455   for (auto ts : tileSizes)
456     tileSizeValues.push_back(folded_std_constant_index(folder, ts));
457   // Pad tile sizes with zero values to enforce our convention.
458   if (tileSizeValues.size() < nLoops) {
459     for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
460       tileSizeValues.push_back(folded_std_constant_index(folder, 0));
461   }
462 
463   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeValues, interchangeVector,
464                                   folder);
465 }
466 
467 Optional<TiledLinalgOp>
468 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
469                            ArrayRef<unsigned> interchangeVector,
470                            OperationFolder *folder) {
471   return tileLinalgOpImpl<scf::ForOp>(b, op, tileSizes, interchangeVector,
472                                       folder);
473 }
474 
475 Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
476     OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
477     ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
478   return tileLinalgOpImpl<scf::ParallelOp>(b, op, tileSizes, interchangeVector,
479                                            folder);
480 }
481 
482 Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
483     OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
484     ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
485   return tileLinalgOpImpl<scf::ForOp>(b, op, tileSizes, interchangeVector,
486                                       folder);
487 }
488 
489 Optional<TiledLinalgOp> mlir::linalg::tileLinalgOpToParallelLoops(
490     OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
491     ArrayRef<unsigned> interchangeVector, OperationFolder *folder) {
492   return tileLinalgOpImpl<scf::ParallelOp>(b, op, tileSizes, interchangeVector,
493                                            folder);
494 }
495 
496 template <typename LoopTy>
497 static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
498   OpBuilder b(f);
499   OperationFolder folder(f.getContext());
500   f.walk([tileSizes, &b, &folder](LinalgOp op) {
501     if (!op.hasBufferSemantics())
502       return;
503     auto opLoopsPair = tileLinalgOpImpl<LoopTy>(
504         b, op, tileSizes, /*interchangeVector=*/{}, &folder);
505     // If tiling occurred successfully, erase old op.
506     if (opLoopsPair)
507       op.erase();
508   });
509   f.walk([](LinalgOp op) {
510     if (isOpTriviallyDead(op))
511       op.erase();
512   });
513 }
514 
515 namespace {
516 struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
517   LinalgTilingPass() = default;
518   LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }
519 
520   void runOnFunction() override {
521     tileLinalgOps<scf::ForOp>(getFunction(), tileSizes);
522   }
523 };
524 
525 struct LinalgTilingToParallelLoopsPass
526     : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
527   LinalgTilingToParallelLoopsPass() = default;
528   LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
529     tileSizes = sizes;
530   }
531 
532   void runOnFunction() override {
533     tileLinalgOps<scf::ParallelOp>(getFunction(), tileSizes);
534   }
535 };
536 
537 } // namespace
538 
539 std::unique_ptr<OperationPass<FuncOp>>
540 mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
541   return std::make_unique<LinalgTilingPass>(tileSizes);
542 }
543 
544 std::unique_ptr<OperationPass<FuncOp>>
545 mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
546   return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
547 }
548