1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 #include "llvm/Support/CommandLine.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 using namespace mlir::scf;
30 
31 #define DEBUG_TYPE "linalg-tiling"
32 
33 static bool isZero(Value v) {
34   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
35     return cst.getValue() == 0;
36   return false;
37 }
38 
39 using LoopIndexToRangeIndexMap = DenseMap<int, int>;
40 
41 // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
42 // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
43 // one entry per surrounding loop. It uses zero as the convention that a
44 // particular loop is not tiled. This convention simplifies implementations by
45 // avoiding affine map manipulations.
46 // The returned ranges correspond to the loop ranges, in the proper order, that
47 // are tiled and for which new loops will be created. Also the function returns
48 // a map from loop indices of the LinalgOp to the corresponding non-empty range
49 // indices of newly created loops.
50 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
51 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
52                     ValueRange allShapeSizes, ValueRange allTileSizes) {
53   assert(allTileSizes.size() == map.getNumResults());
54   // Apply `map` to get shape sizes in loop order.
55   auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
56   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
57 
58   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
59   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
60   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
61     if (isZero(tileSizes[idx - zerosCount])) {
62       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
63       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
64       ++zerosCount;
65       continue;
66     }
67     loopIndexToRangeIndex[idx] = idx - zerosCount;
68   }
69 
70   // Create a new range with the applied tile sizes.
71   SmallVector<Range, 4> res;
72   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
73     res.push_back(Range{b.create<ConstantIndexOp>(loc, 0), shapeSizes[idx],
74                         tileSizes[idx]});
75   return std::make_tuple(res, loopIndexToRangeIndex);
76 }
77 
78 // All indices returned by IndexOp should be invariant with respect to tiling.
79 // Therefore, if an operation is tiled, we have to transform the indices
80 // accordingly, i.e. offset them by the values of the corresponding induction
81 // variables that are captured implicitly in the body of the op.
82 //
83 // Example. `linalg.generic` before tiling:
84 //
85 // #id_2d = (i, j) -> (i, j)
86 // #pointwise_2d_trait = {
87 //   indexing_maps = [#id_2d, #id_2d],
88 //   iterator_types = ["parallel", "parallel"]
89 // }
90 // linalg.generic #pointwise_2d_trait %operand, %result {
91 //   ^bb0(%operand_in: f32, %result_in: f32):
92 //     %i = linalg.index 0 : index
93 //     %j = linalg.index 1 : index
94 //     <some operations that use %i, %j>
95 // }: memref<50x100xf32>, memref<50x100xf32>
96 //
97 // After tiling pass with tiles sizes 10 and 25:
98 //
99 // #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
100 //
101 // %c1 = constant 1 : index
102 // %c0 = constant 0 : index
103 // %c25 = constant 25 : index
104 // %c10 = constant 10 : index
105 // operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
106 // operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
107 // scf.for %k = %c0 to operand_dim_0 step %c10 {
108 //   scf.for %l = %c0 to operand_dim_1 step %c25 {
109 //     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
110 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
111 //     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
112 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
113 //     linalg.generic pointwise_2d_trait %4, %5 {
114 //     ^bb0(%operand_in: f32, %result_in: f32):
115 //       %i = linalg.index 0 : index
116 //       %j = linalg.index 1 : index
117 //       // Indices `k` and `l` are implicitly captured in the body.
118 //       %transformed_i = addi %i, %k : index // index `i` is offset by %k
119 //       %transformed_j = addi %j, %l : index // index `j` is offset by %l
120 //       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
121 //       <some operations that use %transformed_i, %transformed_j>
122 //     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
123 //   }
124 // }
125 //
126 // TODO: Investigate whether mixing implicit and explicit indices
127 // does not lead to losing information.
128 static void
129 transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
130                   const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
131   // Skip operations that have no region attached.
132   if (op->getNumRegions() == 0)
133     return;
134   assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
135          "expected linalg operation to have one block.");
136   Block &block = op->getRegion(0).front();
137 
138   for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
139     auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
140     if (rangeIndex == loopIndexToRangeIndex.end())
141       continue;
142     // Offset the index by the value of the corresponding induction variable and
143     // replace all uses of the previous value.
144     OpBuilder::InsertionGuard g(b);
145     b.setInsertionPointAfter(indexOp);
146     AffineExpr index, iv;
147     bindDims(b.getContext(), index, iv);
148     AffineApplyOp applyOp = b.create<AffineApplyOp>(
149         indexOp.getLoc(), index + iv,
150         ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
151     indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
152   }
153 }
154 
155 template <typename LoopTy>
156 static Optional<TiledLinalgOp>
157 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
158                  const LinalgTilingOptions &options) {
159   auto nLoops = op.getNumLoops();
160   // Initial tile sizes may be too big, only take the first nLoops.
161   tileSizes = tileSizes.take_front(nLoops);
162 
163   if (llvm::all_of(tileSizes, isZero))
164     return llvm::None;
165 
166   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
167     // For conv op only support tiling along batch dimension (which is the first
168     // loop).
169     if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
170       return llvm::None;
171   }
172 
173   // 1. Build the tiled loop ranges.
174   auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
175   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
176   if (!shapeSizesToLoopsMap)
177     return llvm::None;
178 
179   SmallVector<Range, 4> loopRanges;
180   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
181   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
182       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
183 
184   SmallVector<Attribute, 4> iteratorTypes;
185   for (auto attr :
186        enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
187     if (loopIndexToRangeIndex.count(attr.index()))
188       iteratorTypes.push_back(attr.value());
189   }
190   // If interchangeVector is empty, use the identity. Build the permutation map
191   // otherwise.
192   auto invPermutationMap =
193       AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
194   if (!options.interchangeVector.empty()) {
195     // Based on the pruned iterations (due to zero tile size), recompute the
196     // interchange vector.
197     SmallVector<unsigned, 4> interchangeVector;
198     interchangeVector.reserve(options.interchangeVector.size());
199     for (auto pos : options.interchangeVector) {
200       auto it = loopIndexToRangeIndex.find(pos);
201       if (it == loopIndexToRangeIndex.end())
202         continue;
203       interchangeVector.push_back(it->second);
204     }
205     // Interchange vector is guaranteed to be a permutation,
206     // `inversePermutation` must succeed.
207     invPermutationMap = inversePermutation(
208         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
209     assert(invPermutationMap);
210     applyPermutationToVector(loopRanges, interchangeVector);
211     applyPermutationToVector(iteratorTypes, interchangeVector);
212   }
213 
214   // 2. Create the tiled loops.
215   LinalgOp res = op;
216   SmallVector<Value, 4> ivs, tensorResults;
217   auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
218                                   ValueRange localIvs,
219                                   ValueRange iterArgs) -> scf::ValueVector {
220     ivs.assign(localIvs.begin(), localIvs.end());
221 
222     // When an `interchangeVector` is present, it has been applied to the
223     // loop ranges and the iterator types. Apply its inverse to the
224     // resulting loop `ivs` to match the op definition.
225     SmallVector<Value, 4> interchangedIvs;
226     if (!options.interchangeVector.empty())
227       interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
228     else
229       interchangedIvs.assign(ivs.begin(), ivs.end());
230 
231     assert(op.getOutputTensorOperands().size() == iterArgs.size() &&
232            "num output tensors must match number of loop iter arguments");
233 
234     SmallVector<Value> operands = op.getInputOperands();
235     SmallVector<Value> outputBuffers = op.getOutputBufferOperands();
236     // TODO: thanks to simplifying assumption we do not need to worry about
237     // order of output buffers and tensors: there is only ever one kind.
238     assert(outputBuffers.empty() || iterArgs.empty());
239     operands.append(outputBuffers.begin(), outputBuffers.end());
240     operands.append(iterArgs.begin(), iterArgs.end());
241     auto sizeBounds =
242         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
243     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
244         b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
245     auto nonShapedOperands = op.getAssumedNonShapedOperands();
246     tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());
247 
248     // TODO: use an interface/adaptor to avoid leaking position in
249     // `tiledOperands`.
250     SmallVector<Type, 4> resultTensorTypes;
251     for (OpOperand *opOperand : op.getOutputTensorOperands())
252       resultTensorTypes.push_back(
253           tiledOperands[opOperand->getOperandNumber()].getType());
254 
255     res = op.clone(b, loc, resultTensorTypes, tiledOperands);
256 
257     // Insert a insert_slice for each output tensor.
258     unsigned resultIdx = 0;
259     for (OpOperand *opOperand : op.getOutputTensorOperands()) {
260       // TODO: use an interface/adaptor to avoid leaking position in
261       // `tiledOperands`.
262       Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
263       if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
264         tensorResults.push_back(b.create<tensor::InsertSliceOp>(
265             loc, sliceOp.source().getType(), res->getResult(resultIdx),
266             sliceOp.source(), sliceOp.offsets(), sliceOp.sizes(),
267             sliceOp.strides(), sliceOp.static_offsets(), sliceOp.static_sizes(),
268             sliceOp.static_strides()));
269       } else {
270         tensorResults.push_back(res->getResult(resultIdx));
271       }
272       ++resultIdx;
273     }
274     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
275   };
276   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
277                                  tiledLoopBodyBuilder, options.distribution,
278                                  options.distributionTypes);
279 
280   // 3. Transform IndexOp results w.r.t. the tiling.
281   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
282 
283   // 4. Gather the newly created loops and return them with the new op.
284   SmallVector<Operation *, 8> loops;
285   loops.reserve(ivs.size());
286   for (auto iv : ivs) {
287     if (iv.isa<BlockArgument>()) {
288       loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
289       assert(loops.back() && "no owner found for induction variable!");
290     } else {
291       // TODO: Instead of doing this, try to recover the ops used instead of the
292       // loop.
293       loops.push_back(nullptr);
294     }
295   }
296 
297   // 5. Get the tensor results from the outermost loop if available. Otherwise
298   // use the previously captured `tensorResults`.
299   Operation *outermostLoop = nullptr;
300   for (Operation *loop : loops)
301     if ((outermostLoop = loop))
302       break;
303 
304   return TiledLinalgOp{
305       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
306 }
307 
308 template <typename LoopTy>
309 Optional<TiledLinalgOp> static tileLinalgOpImpl(
310     OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
311   OpBuilder::InsertionGuard g(b);
312   b.setInsertionPoint(op);
313 
314   if (!options.tileSizeComputationFunction)
315     return llvm::None;
316 
317   // Enforce the convention that "tiling by zero" skips tiling a particular
318   // dimension. This convention is significantly simpler to handle instead of
319   // adjusting affine maps to account for missing dimensions.
320   auto nLoops = op.getNumLoops();
321   SmallVector<Value, 4> tileSizeVector =
322       options.tileSizeComputationFunction(b, op);
323   if (tileSizeVector.size() < nLoops) {
324     auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
325     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
326   }
327 
328   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
329 }
330 
331 Optional<TiledLinalgOp>
332 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
333                            const LinalgTilingOptions &options) {
334   switch (options.loopType) {
335   case LinalgTilingLoopType::Loops:
336     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
337   case LinalgTilingLoopType::ParallelLoops:
338     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
339   case LinalgTilingLoopType::TiledLoops:
340     return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
341   default:;
342   }
343   return llvm::None;
344 }
345 
346 namespace {
347 /// Helper classes for type list expansion.
348 template <typename... OpTypes>
349 class CanonicalizationPatternList;
350 
351 template <>
352 class CanonicalizationPatternList<> {
353 public:
354   static void insert(RewritePatternSet &patterns) {}
355 };
356 
357 template <typename OpTy, typename... OpTypes>
358 class CanonicalizationPatternList<OpTy, OpTypes...> {
359 public:
360   static void insert(RewritePatternSet &patterns) {
361     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
362     CanonicalizationPatternList<OpTypes...>::insert(patterns);
363   }
364 };
365 
366 /// Helper classes for type list expansion.
367 template <typename... OpTypes>
368 class RewritePatternList;
369 
370 template <>
371 class RewritePatternList<> {
372 public:
373   static void insert(RewritePatternSet &patterns,
374                      const LinalgTilingOptions &options) {}
375 };
376 
377 template <typename OpTy, typename... OpTypes>
378 class RewritePatternList<OpTy, OpTypes...> {
379 public:
380   static void insert(RewritePatternSet &patterns,
381                      const LinalgTilingOptions &options) {
382     auto *ctx = patterns.getContext();
383     patterns.add<LinalgTilingPattern<OpTy>>(
384         ctx, options,
385         LinalgTransformationFilter(ArrayRef<Identifier>{},
386                                    Identifier::get("tiled", ctx)));
387     RewritePatternList<OpTypes...>::insert(patterns, options);
388   }
389 };
390 } // namespace
391 
392 RewritePatternSet
393 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
394   RewritePatternSet patterns(ctx);
395   populateLinalgTilingCanonicalizationPatterns(patterns);
396   return patterns;
397 }
398 
399 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
400     RewritePatternSet &patterns) {
401   auto *ctx = patterns.getContext();
402   AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
403   AffineForOp::getCanonicalizationPatterns(patterns, ctx);
404   AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
405   AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
406   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
407   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
408   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
409   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
410   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
411   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
412   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
413   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
414   CanonicalizationPatternList<
415 #define GET_OP_LIST
416 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
417       >::insert(patterns);
418 }
419 
420 /// Populate the given list with patterns that apply Linalg tiling.
421 static void insertTilingPatterns(RewritePatternSet &patterns,
422                                  const LinalgTilingOptions &options) {
423   RewritePatternList<GenericOp,
424 #define GET_OP_LIST
425 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
426                      >::insert(patterns, options);
427 }
428 
429 static void
430 applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
431                           ArrayRef<int64_t> tileSizes,
432                           ArrayRef<StringRef> distributionTypes = {}) {
433   auto options = LinalgTilingOptions()
434                      .setTileSizes(tileSizes)
435                      .setLoopType(loopType)
436                      .setDistributionTypes(distributionTypes);
437   MLIRContext *ctx = funcOp.getContext();
438   RewritePatternSet patterns(ctx);
439   insertTilingPatterns(patterns, options);
440   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
441   (void)applyPatternsAndFoldGreedily(
442       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
443   // Drop the marker.
444   funcOp.walk([](LinalgOp op) {
445     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
446   });
447 }
448 
449 namespace {
450 struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
451   LinalgTilingPass() = default;
452   LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }
453 
454   void runOnFunction() override {
455     applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
456                               tileSizes);
457   }
458 };
459 
460 struct LinalgTilingToParallelLoopsPass
461     : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
462   LinalgTilingToParallelLoopsPass() = default;
463   LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
464     tileSizes = sizes;
465   }
466 
467   void runOnFunction() override {
468     applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
469                               getFunction(), tileSizes);
470   }
471 };
472 
473 struct LinalgTilingToTiledLoopsPass
474     : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
475   LinalgTilingToTiledLoopsPass() = default;
476   LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes,
477                                ArrayRef<StringRef> types) {
478     tileSizes = sizes;
479     distributionTypes = llvm::to_vector<2>(
480         llvm::map_range(types, [](StringRef ref) { return ref.str(); }));
481   }
482 
483   void runOnFunction() override {
484     applyTilingToLoopPatterns(
485         LinalgTilingLoopType::TiledLoops, getFunction(), tileSizes,
486         llvm::to_vector<2>(
487             llvm::map_range(distributionTypes,
488                             [](std::string &str) { return StringRef(str); })));
489   }
490 };
491 
492 } // namespace
493 
494 std::unique_ptr<OperationPass<FuncOp>>
495 mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
496   return std::make_unique<LinalgTilingPass>(tileSizes);
497 }
498 
499 std::unique_ptr<OperationPass<FuncOp>>
500 mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
501   return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
502 }
503 
504 std::unique_ptr<OperationPass<FuncOp>>
505 mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes,
506                                         ArrayRef<StringRef> distributionTypes) {
507   return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes,
508                                                         distributionTypes);
509 }
510