1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
15 #include "mlir/Dialect/Linalg/Passes.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 #include "llvm/Support/CommandLine.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 using namespace mlir::scf;
30 
31 #define DEBUG_TYPE "linalg-tiling"
32 
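/// Returns true if `v` is produced by a ConstantIndexOp with value zero.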
33 static bool isZero(Value v) {
34   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
35     return cst.getValue() == 0;
36   return false;
37 }
38 
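/// Maps the index of a tiled loop of the original LinalgOp to the index of
/// the corresponding range (and hence of the newly created loop) once the
/// loops with zero tile size have been pruned.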
39 using LoopIndexToRangeIndexMap = DenseMap<int, int>;
40 
// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; by convention, a
// zero entry means that the corresponding loop is not tiled. This convention
// simplifies implementations by avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also returns
// a map from the loop indices of the LinalgOp to the corresponding non-empty
// range indices of the newly created loops.
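//
// For illustration (hypothetical sizes): for a 3-loop op with shape sizes
// (%s0, %s1, %s2) and tile sizes (%c2, %c0, %c8), loop 1 is not tiled; the
// returned ranges are {0, %s0, %c2} and {0, %s2, %c8}, and the returned map
// is {0 -> 0, 2 -> 1}.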
50 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
51 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
52                     ValueRange allShapeSizes, ValueRange allTileSizes) {
53   assert(allTileSizes.size() == map.getNumResults());
54   // Apply `map` to get shape sizes in loop order.
55   auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
56   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
57 
  // Traverse the tile sizes, which are in loop order, and erase the zero
  // entries.
59   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
60   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
61     if (isZero(tileSizes[idx - zerosCount])) {
62       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
63       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
64       ++zerosCount;
65       continue;
66     }
67     loopIndexToRangeIndex[idx] = idx - zerosCount;
68   }
69 
70   // Create a new range with the applied tile sizes.
71   SmallVector<Range, 4> res;
72   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
73     res.push_back(Range{b.create<ConstantIndexOp>(loc, 0), shapeSizes[idx],
74                         tileSizes[idx]});
75   return std::make_tuple(res, loopIndexToRangeIndex);
76 }
77 
78 // All indices returned by IndexOp should be invariant with respect to tiling.
79 // Therefore, if an operation is tiled, we have to transform the indices
80 // accordingly, i.e. offset them by the values of the corresponding induction
81 // variables that are captured implicitly in the body of the op.
82 //
83 // Example. `linalg.generic` before tiling:
84 //
85 // #id_2d = (i, j) -> (i, j)
86 // #pointwise_2d_trait = {
87 //   indexing_maps = [#id_2d, #id_2d],
88 //   iterator_types = ["parallel", "parallel"]
89 // }
90 // linalg.generic #pointwise_2d_trait %operand, %result {
91 //   ^bb0(%operand_in: f32, %result_in: f32):
92 //     %i = linalg.index 0 : index
93 //     %j = linalg.index 1 : index
94 //     <some operations that use %i, %j>
95 // }: memref<50x100xf32>, memref<50x100xf32>
96 //
// After the tiling pass with tile sizes 10 and 25:
98 //
99 // #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
100 //
101 // %c1 = constant 1 : index
102 // %c0 = constant 0 : index
103 // %c25 = constant 25 : index
104 // %c10 = constant 10 : index
// operand_dim_0 = memref.dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = memref.dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
//   scf.for %l = %c0 to operand_dim_1 step %c25 {
//     %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     linalg.generic #pointwise_2d_trait %4, %5 {
114 //     ^bb0(%operand_in: f32, %result_in: f32):
115 //       %i = linalg.index 0 : index
116 //       %j = linalg.index 1 : index
117 //       // Indices `k` and `l` are implicitly captured in the body.
118 //       %transformed_i = addi %i, %k : index // index `i` is offset by %k
119 //       %transformed_j = addi %j, %l : index // index `j` is offset by %l
120 //       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
121 //       <some operations that use %transformed_i, %transformed_j>
122 //     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
123 //   }
124 // }
125 //
// TODO: Investigate whether mixing implicit and explicit indices can lead to
// losing information.
128 static void
129 transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
130                   const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
131   // Skip operations that have no region attached.
132   if (op->getNumRegions() == 0)
133     return;
134   assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
135          "expected linalg operation to have one block.");
136   Block &block = op->getRegion(0).front();
137 
138   for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
139     auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
140     if (rangeIndex == loopIndexToRangeIndex.end())
141       continue;
    // Offset the index by the value of the corresponding induction variable
    // and replace all uses of the original index (except the new apply op
    // itself) with the offset value.
144     OpBuilder::InsertionGuard g(b);
145     b.setInsertionPointAfter(indexOp);
146     AffineExpr index, iv;
147     bindDims(b.getContext(), index, iv);
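    // This creates IR of the form (with %index the original index and %iv the
    // induction variable):
    //   %offset = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%index, %iv)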
148     AffineApplyOp applyOp = b.create<AffineApplyOp>(
149         indexOp.getLoc(), index + iv,
150         ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
151     indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
152   }
153 }
154 
155 template <typename LoopTy>
156 static Optional<TiledLinalgOp>
157 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
158                  const LinalgTilingOptions &options) {
159   auto nLoops = op.getNumLoops();
  // The initial tile sizes may have more entries than the op has loops; only
  // take the first nLoops entries.
161   tileSizes = tileSizes.take_front(nLoops);
162 
163   if (llvm::all_of(tileSizes, isZero))
164     return llvm::None;
165 
166   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops with padding, only support tiling along the batch
    // dimension (i.e. the first loop).
169     if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
170       return llvm::None;
171   }
172 
173   // 1. Build the tiled loop ranges.
174   auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
175   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
176   if (!shapeSizesToLoopsMap)
177     return llvm::None;
178 
179   SmallVector<Range, 4> loopRanges;
180   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
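  // For illustration: for a matmul, `allShapeSizes` is (%M, %K, %K, %N, %M,
  // %N) and `shapeSizesToLoopsMap` is (d0, ..., d5) -> (d0, d3, d1), which
  // recovers the (m, n, k) loop bounds from the flattened operand shapes.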
181   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
182       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
183 
184   SmallVector<Attribute, 4> iteratorTypes;
185   for (auto attr :
186        enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
187     if (loopIndexToRangeIndex.count(attr.index()))
188       iteratorTypes.push_back(attr.value());
189   }
  // If `interchangeVector` is empty, use the identity map. Otherwise, build
  // the inverse of the permutation map.
192   auto invPermutationMap =
193       AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
194   if (!options.interchangeVector.empty()) {
195     // Based on the pruned iterations (due to zero tile size), recompute the
196     // interchange vector.
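    // For illustration: with tile sizes (2, 0, 8), loop 1 is pruned and
    // `loopIndexToRangeIndex` is {0 -> 0, 2 -> 1}; an interchange vector of
    // [2, 0, 1] is thus recomputed to [1, 0].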
197     SmallVector<unsigned, 4> interchangeVector;
198     interchangeVector.reserve(options.interchangeVector.size());
199     for (auto pos : options.interchangeVector) {
200       auto it = loopIndexToRangeIndex.find(pos);
201       if (it == loopIndexToRangeIndex.end())
202         continue;
203       interchangeVector.push_back(it->second);
204     }
205     // Interchange vector is guaranteed to be a permutation,
206     // `inversePermutation` must succeed.
207     invPermutationMap = inversePermutation(
208         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
209     assert(invPermutationMap);
210     applyPermutationToVector(loopRanges, interchangeVector);
211     applyPermutationToVector(iteratorTypes, interchangeVector);
212   }
213 
214   // 2. Create the tiled loops.
215   LinalgOp res = op;
216   SmallVector<Value, 4> ivs, tensorResults;
217   auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
218                                   ValueRange localIvs,
219                                   ValueRange iterArgs) -> scf::ValueVector {
220     ivs.assign(localIvs.begin(), localIvs.end());
221 
222     // When an `interchangeVector` is present, it has been applied to the
223     // loop ranges and the iterator types. Apply its inverse to the
224     // resulting loop `ivs` to match the op definition.
225     SmallVector<Value, 4> interchangedIvs;
226     if (!options.interchangeVector.empty())
227       interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
228     else
229       interchangedIvs.assign(ivs.begin(), ivs.end());
230 
231     assert(op.getOutputTensorOperands().size() == iterArgs.size() &&
232            "num output tensors must match number of loop iter arguments");
233 
234     SmallVector<Value> operands = op.getInputOperands();
235     SmallVector<Value> outputBuffers = op.getOutputBufferOperands();
    // TODO: thanks to a simplifying assumption we do not need to worry about
    // the order of output buffers and tensors: there is only ever one kind.
238     assert(outputBuffers.empty() || iterArgs.empty());
239     operands.append(outputBuffers.begin(), outputBuffers.end());
240     operands.append(iterArgs.begin(), iterArgs.end());
241     auto sizeBounds =
242         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
243     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
244         b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
245 
246     // TODO: use an interface/adaptor to avoid leaking position in
247     // `tiledOperands`.
248     SmallVector<Type, 4> resultTensorTypes;
249     for (OpOperand *opOperand : op.getOutputTensorOperands())
250       resultTensorTypes.push_back(
251           tiledOperands[opOperand->getOperandNumber()].getType());
252 
253     res = op.clone(b, loc, resultTensorTypes, tiledOperands);
254 
    // Insert an `insert_slice` op for each output tensor.
256     unsigned resultIdx = 0;
257     for (OpOperand *opOperand : op.getOutputTensorOperands()) {
258       // TODO: use an interface/adaptor to avoid leaking position in
259       // `tiledOperands`.
260       Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
261       if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
262         tensorResults.push_back(b.create<tensor::InsertSliceOp>(
263             loc, sliceOp.source().getType(), res->getResult(resultIdx),
264             sliceOp.source(), sliceOp.offsets(), sliceOp.sizes(),
265             sliceOp.strides(), sliceOp.static_offsets(), sliceOp.static_sizes(),
266             sliceOp.static_strides()));
267       } else {
268         tensorResults.push_back(res->getResult(resultIdx));
269       }
270       ++resultIdx;
271     }
272     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
273   };
274   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
275                                  tiledLoopBodyBuilder, options.distribution,
276                                  options.distributionTypes);
277 
278   // 3. Transform IndexOp results w.r.t. the tiling.
279   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
280 
281   // 4. Gather the newly created loops and return them with the new op.
282   SmallVector<Operation *, 8> loops;
283   loops.reserve(ivs.size());
284   for (auto iv : ivs) {
285     if (iv.isa<BlockArgument>()) {
286       loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
287       assert(loops.back() && "no owner found for induction variable!");
288     } else {
      // TODO: Instead of pushing a nullptr, try to recover the ops that were
      // used in place of the loop.
291       loops.push_back(nullptr);
292     }
293   }
294 
295   // 5. Get the tensor results from the outermost loop if available. Otherwise
296   // use the previously captured `tensorResults`.
297   Operation *outermostLoop = nullptr;
298   for (Operation *loop : loops)
299     if ((outermostLoop = loop))
300       break;
301 
302   return TiledLinalgOp{
303       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
304 }
305 
template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
309   OpBuilder::InsertionGuard g(b);
310   b.setInsertionPoint(op);
311 
312   if (!options.tileSizeComputationFunction)
313     return llvm::None;
314 
  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
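  // For illustration: for a 3-loop op with user-provided tile sizes [%c5],
  // the vector is padded to [%c5, %c0, %c0], leaving loops 1 and 2 untiled.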
318   auto nLoops = op.getNumLoops();
319   SmallVector<Value, 4> tileSizeVector =
320       options.tileSizeComputationFunction(b, op);
321   if (tileSizeVector.size() < nLoops) {
322     auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
323     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
324   }
325 
326   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
327 }
328 
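/// Illustrative usage sketch (hypothetical `b` and `linalgOp`; the caller
/// remains responsible for replacing or erasing the original op):
///   auto options = LinalgTilingOptions()
///                      .setTileSizes({10, 25})
///                      .setLoopType(LinalgTilingLoopType::Loops);
///   if (Optional<TiledLinalgOp> tiled = tileLinalgOp(b, linalgOp, options))
///     /* use tiled->op, tiled->loops and tiled->tensorResults */;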
329 Optional<TiledLinalgOp>
330 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
331                            const LinalgTilingOptions &options) {
332   switch (options.loopType) {
333   case LinalgTilingLoopType::Loops:
334     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
335   case LinalgTilingLoopType::ParallelLoops:
336     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
337   case LinalgTilingLoopType::TiledLoops:
338     return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
  default:
    break;
340   }
341   return llvm::None;
342 }
343 
344 namespace {
345 /// Helper classes for type list expansion.
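/// For illustration, `CanonicalizationPatternList<OpA, OpB>::insert(patterns)`
/// recursively expands to `OpA::getCanonicalizationPatterns(patterns, ctx)`
/// followed by `OpB::getCanonicalizationPatterns(patterns, ctx)`, where `OpA`
/// and `OpB` stand for arbitrary op types.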
346 template <typename... OpTypes>
347 class CanonicalizationPatternList;
348 
349 template <>
350 class CanonicalizationPatternList<> {
351 public:
352   static void insert(RewritePatternSet &patterns) {}
353 };
354 
355 template <typename OpTy, typename... OpTypes>
356 class CanonicalizationPatternList<OpTy, OpTypes...> {
357 public:
358   static void insert(RewritePatternSet &patterns) {
359     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
360     CanonicalizationPatternList<OpTypes...>::insert(patterns);
361   }
362 };
363 
364 /// Helper classes for type list expansion.
365 template <typename... OpTypes>
366 class RewritePatternList;
367 
368 template <>
369 class RewritePatternList<> {
370 public:
371   static void insert(RewritePatternSet &patterns,
372                      const LinalgTilingOptions &options) {}
373 };
374 
375 template <typename OpTy, typename... OpTypes>
376 class RewritePatternList<OpTy, OpTypes...> {
377 public:
378   static void insert(RewritePatternSet &patterns,
379                      const LinalgTilingOptions &options) {
380     auto *ctx = patterns.getContext();
381     patterns.add<LinalgTilingPattern<OpTy>>(
382         ctx, options,
383         LinalgTransformationFilter(ArrayRef<Identifier>{},
384                                    Identifier::get("tiled", ctx)));
385     RewritePatternList<OpTypes...>::insert(patterns, options);
386   }
387 };
388 } // namespace
389 
390 RewritePatternSet
391 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
392   RewritePatternSet patterns(ctx);
393   populateLinalgTilingCanonicalizationPatterns(patterns);
394   return patterns;
395 }
396 
397 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
398     RewritePatternSet &patterns) {
399   auto *ctx = patterns.getContext();
400   AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
401   AffineForOp::getCanonicalizationPatterns(patterns, ctx);
402   AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
403   AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
404   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
405   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
406   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
407   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
408   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
409   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
410   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
411   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
412   CanonicalizationPatternList<
413 #define GET_OP_LIST
414 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
415       >::insert(patterns);
416 }
417 
418 /// Populate the given list with patterns that apply Linalg tiling.
419 static void insertTilingPatterns(RewritePatternSet &patterns,
420                                  const LinalgTilingOptions &options) {
421   RewritePatternList<GenericOp,
422 #define GET_OP_LIST
423 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
424                      >::insert(patterns, options);
425 }
426 
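/// Apply the tiling patterns for `loopType` to `funcOp` with the given tile
/// sizes, run the Linalg tiling canonicalization patterns, and finally drop
/// the transformation marker from all remaining Linalg ops.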
427 static void
428 applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
429                           ArrayRef<int64_t> tileSizes,
430                           ArrayRef<StringRef> distributionTypes = {}) {
431   auto options = LinalgTilingOptions()
432                      .setTileSizes(tileSizes)
433                      .setLoopType(loopType)
434                      .setDistributionTypes(distributionTypes);
435   MLIRContext *ctx = funcOp.getContext();
436   RewritePatternSet patterns(ctx);
437   insertTilingPatterns(patterns, options);
438   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
439   (void)applyPatternsAndFoldGreedily(
440       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
441   // Drop the marker.
442   funcOp.walk([](LinalgOp op) {
443     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
444   });
445 }
446 
447 namespace {
448 struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
449   LinalgTilingPass() = default;
450   LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }
451 
452   void runOnFunction() override {
453     applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
454                               tileSizes);
455   }
456 };
457 
458 struct LinalgTilingToParallelLoopsPass
459     : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
460   LinalgTilingToParallelLoopsPass() = default;
461   LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
462     tileSizes = sizes;
463   }
464 
465   void runOnFunction() override {
466     applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
467                               getFunction(), tileSizes);
468   }
469 };
470 
471 struct LinalgTilingToTiledLoopsPass
472     : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
473   LinalgTilingToTiledLoopsPass() = default;
474   LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes,
475                                ArrayRef<StringRef> types) {
476     tileSizes = sizes;
477     distributionTypes = llvm::to_vector<2>(
478         llvm::map_range(types, [](StringRef ref) { return ref.str(); }));
479   }
480 
481   void runOnFunction() override {
482     applyTilingToLoopPatterns(
483         LinalgTilingLoopType::TiledLoops, getFunction(), tileSizes,
484         llvm::to_vector<2>(
485             llvm::map_range(distributionTypes,
486                             [](std::string &str) { return StringRef(str); })));
487   }
488 };
489 
490 } // namespace
491 
492 std::unique_ptr<OperationPass<FuncOp>>
493 mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
494   return std::make_unique<LinalgTilingPass>(tileSizes);
495 }
496 
497 std::unique_ptr<OperationPass<FuncOp>>
498 mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
499   return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
500 }
501 
502 std::unique_ptr<OperationPass<FuncOp>>
503 mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes,
504                                         ArrayRef<StringRef> distributionTypes) {
505   return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes,
506                                                         distributionTypes);
507 }
508