//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`: one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; by convention, a
// zero entry means that the corresponding loop is not tiled. This convention
// simplifies implementations by avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also returns
// a map from loop indices of the LinalgOp to the corresponding range indices
// of the newly created loops.
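// For example (illustrative values): given loop-ordered tileSizes
// [10, 0, 25], the range for loop 1 is dropped because a zero tile size means
// "do not tile", ranges are created for loops 0 and 2 only, and the returned
// map is {0 -> 0, 2 -> 1}.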
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase zeros
  // everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<ConstantIndexOp>(loc, 0), shapeSizes[idx],
                        tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
//   indexing_maps = [#id_2d, #id_2d],
//   iterator_types = ["parallel", "parallel"]
// }
// linalg.generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%operand_in: f32, %result_in: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling with tile sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
// #map = (d0, d1) -> (d0 + d1)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// %operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// %operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to %operand_dim_0 step %c10 {
//   scf.for %l = %c0 to %operand_dim_1 step %c25 {
//     %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     linalg.generic #pointwise_2d_trait %4, %5 {
//     ^bb0(%operand_in: f32, %result_in: f32):
//       %i = linalg.index 0 : index
//       %j = linalg.index 1 : index
//       // Indices `k` and `l` are implicitly captured in the body.
//       %transformed_i = affine.apply #map(%i, %k) // index `i` offset by %k
//       %transformed_j = affine.apply #map(%j, %l) // index `j` offset by %l
//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j.
//       <some operations that use %transformed_i, %transformed_j>
//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//   }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices can lead to
// a loss of information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  // Skip operations that have no region attached.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &block = op->getRegion(0).front();

  for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Offset the index by the value of the corresponding induction variable
    // and replace all uses of the previous value.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, iv;
    bindDims(b.getContext(), index, iv);
    AffineApplyOp applyOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), index + iv,
        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
  }
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

  // Tiling IndexedGenericOp is not supported; such ops are expected to be
  // canonicalized to GenericOp before tiling.
  if (isa<IndexedGenericOp>(op))
    return llvm::None;

  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops with padding, only tiling along the batch dimension (the
    // first loop) is supported.
    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
      return llvm::None;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return llvm::None;

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If `interchangeVector` is empty, use the identity map. Otherwise, build
  // the permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the iterations pruned by zero tile sizes, recompute the
    // interchange vector.
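    // For example (illustrative values): with tile sizes [32, 0, 16], loops 0
    // and 2 map to range indices 0 and 1, respectively. An interchange vector
    // of [2, 0, 1] is then pruned to [1, 0]: loop 1 is dropped and the
    // surviving positions are remapped through `loopIndexToRangeIndex`.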
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
                                  ValueRange localIvs,
                                  ValueRange iterArgs) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    assert(op.getOutputTensorOperands().size() == iterArgs.size() &&
           "num output tensors must match number of loop iter arguments");

    SmallVector<Value> operands = op.getInputOperands();
    SmallVector<Value> outputBuffers = op.getOutputBufferOperands();
    // TODO: Thanks to a simplifying assumption, we do not need to worry about
    // the order of output buffers and tensors: there is only ever one kind.
    assert(outputBuffers.empty() || iterArgs.empty());
    operands.append(outputBuffers.begin(), outputBuffers.end());
    operands.append(iterArgs.begin(), iterArgs.end());
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
    auto nonShapedOperands = op.getAssumedNonShapedOperands();
    tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());

    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert a subtensor_insert for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      // TODO: use an interface/adaptor to avoid leaking position in
      // `tiledOperands`.
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
        tensorResults.push_back(b.create<SubTensorInsertOp>(
            loc, subtensor.source().getType(), res->getResult(resultIdx),
            subtensor.source(), subtensor.offsets(), subtensor.sizes(),
            subtensor.strides(), subtensor.static_offsets(),
            subtensor.static_sizes(), subtensor.static_strides()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of pushing a nullptr, try to recover the ops that were
      // created in place of the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return llvm::None;

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
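  // For example (illustrative sizes): with 4 loops and computed tile sizes
  // [32, 16], the vector is padded to [32, 16, 0, 0], leaving the two
  // innermost loops untiled.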
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  case LinalgTilingLoopType::TiledLoops:
    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
  default:
    break;
  }
  return llvm::None;
}
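
// Example usage (a sketch; `builder`, `linalgOp`, and the tile sizes below
// are illustrative, not taken from a particular caller):
//
//   LinalgTilingOptions options =
//       LinalgTilingOptions()
//           .setTileSizes({32, 0, 16})
//           .setLoopType(LinalgTilingLoopType::Loops);
//   if (Optional<TiledLinalgOp> tiled =
//           mlir::linalg::tileLinalgOp(builder, linalgOp, options)) {
//     // `tiled->op` is the tiled LinalgOp; `tiled->loops` holds the generated
//     // loops (an entry may be null when no loop op was materialized).
//   }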

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
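
// For illustration (OpA and OpB are placeholder op types):
// CanonicalizationPatternList<OpA, OpB>::insert(patterns) expands, via the
// recursive specializations above, to
//   OpA::getCanonicalizationPatterns(patterns, ctx);
//   OpB::getCanonicalizationPatterns(patterns, ctx);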

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};

template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<Identifier>{},
                                   Identifier::get("tiled", ctx)));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
}

static void
applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
                          ArrayRef<int64_t> tileSizes,
                          ArrayRef<StringRef> distributionTypes = {}) {
  auto options = LinalgTilingOptions()
                     .setTileSizes(tileSizes)
                     .setLoopType(loopType)
                     .setDistributionTypes(distributionTypes);
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  insertTilingPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
  // Drop the marker.
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}
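
// The passes below drive `applyTilingToLoopPatterns` over each function. Their
// command-line registration is generated from the Linalg Passes.td, which is
// the authoritative reference for the pass mnemonics and tile-size options.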

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
                              tileSizes);
  }
};

struct LinalgTilingToParallelLoopsPass
    : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
  LinalgTilingToParallelLoopsPass() = default;
  LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
    tileSizes = sizes;
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
                              getFunction(), tileSizes);
  }
};

struct LinalgTilingToTiledLoopsPass
    : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
  LinalgTilingToTiledLoopsPass() = default;
  LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes,
                               ArrayRef<StringRef> types) {
    tileSizes = sizes;
    distributionTypes = llvm::to_vector<2>(
        llvm::map_range(types, [](StringRef ref) { return ref.str(); }));
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(
        LinalgTilingLoopType::TiledLoops, getFunction(), tileSizes,
        llvm::to_vector<2>(
            llvm::map_range(distributionTypes,
                            [](std::string &str) { return StringRef(str); })));
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes,
                                        ArrayRef<StringRef> distributionTypes) {
  return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes,
                                                        distributionTypes);
}
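
// Example: adding the tiling pass to a pass pipeline (a sketch; `pm` is an
// assumed PassManager over a module containing functions, and the tile sizes
// are illustrative):
//
//   pm.addNestedPass<FuncOp>(mlir::createLinalgTilingPass({32, 32}));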