1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
15 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 
29 #include "llvm/Support/CommandLine.h"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 using namespace mlir::scf;
36 
37 #define DEBUG_TYPE "linalg-tiling"
38 
39 static bool isZero(Value v) {
40   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
41     return cst.getValue() == 0;
42   return false;
43 }
44 
45 using LoopIndexToRangeIndexMap = DenseMap<int, int>;
46 
47 // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
48 // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
49 // one entry per surrounding loop. It uses zero as the convention that a
50 // particular loop is not tiled. This convention simplifies implementations by
51 // avoiding affine map manipulations.
52 // The returned ranges correspond to the loop ranges, in the proper order, that
53 // are tiled and for which new loops will be created. Also the function returns
54 // a map from loop indices of the LinalgOp to the corresponding non-empty range
55 // indices of newly created loops.
56 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
57 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
58                     ValueRange allShapeSizes, ValueRange allTileSizes) {
59   assert(allTileSizes.size() == map.getNumResults());
60   // Apply `map` to get shape sizes in loop order.
61   auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
62   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
63 
64   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
65   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
66   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
67     if (isZero(tileSizes[idx - zerosCount])) {
68       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
69       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
70       ++zerosCount;
71       continue;
72     }
73     loopIndexToRangeIndex[idx] = idx - zerosCount;
74   }
75 
76   // Create a new range with the applied tile sizes.
77   SmallVector<Range, 4> res;
78   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
79     res.push_back(
80         Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
81   return std::make_tuple(res, loopIndexToRangeIndex);
82 }
83 
84 // All indices returned by IndexOp should be invariant with respect to tiling.
85 // Therefore, if an operation is tiled, we have to transform the indices
86 // accordingly, i.e. offset them by the values of the corresponding induction
87 // variables that are captured implicitly in the body of the op.
88 //
89 // Example. `linalg.generic` before tiling:
90 //
91 // #id_2d = (i, j) -> (i, j)
92 // #pointwise_2d_trait = {
93 //   indexing_maps = [#id_2d, #id_2d],
94 //   iterator_types = ["parallel", "parallel"]
95 // }
96 // linalg.generic #pointwise_2d_trait %operand, %result {
97 //   ^bb0(%operand_in: f32, %result_in: f32):
98 //     %i = linalg.index 0 : index
99 //     %j = linalg.index 1 : index
100 //     <some operations that use %i, %j>
101 // }: memref<50x100xf32>, memref<50x100xf32>
102 //
103 // After tiling pass with tiles sizes 10 and 25:
104 //
105 // #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
106 //
107 // %c1 = constant 1 : index
108 // %c0 = constant 0 : index
109 // %c25 = constant 25 : index
110 // %c10 = constant 10 : index
111 // operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
112 // operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
113 // scf.for %k = %c0 to operand_dim_0 step %c10 {
114 //   scf.for %l = %c0 to operand_dim_1 step %c25 {
115 //     %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
116 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
117 //     %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
118 //       : memref<50x100xf32> to memref<?x?xf32, #strided>
119 //     linalg.generic pointwise_2d_trait %4, %5 {
120 //     ^bb0(%operand_in: f32, %result_in: f32):
121 //       %i = linalg.index 0 : index
122 //       %j = linalg.index 1 : index
123 //       // Indices `k` and `l` are implicitly captured in the body.
124 //       %transformed_i = addi %i, %k : index // index `i` is offset by %k
125 //       %transformed_j = addi %j, %l : index // index `j` is offset by %l
126 //       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
127 //       <some operations that use %transformed_i, %transformed_j>
128 //     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
129 //   }
130 // }
131 //
132 // TODO: Investigate whether mixing implicit and explicit indices
133 // does not lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  // Skip operations that have no region attached.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &block = op->getRegion(0).front();

  for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
    // Dimensions without an entry in the map correspond to untiled loops; the
    // index they produce is unchanged by tiling and needs no adjustment.
    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Offset the index by the value of the corresponding induction variable and
    // replace all uses of the previous value. `ivs` is indexed by range index,
    // hence the lookup through `rangeIndex->second`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, iv;
    bindDims(b.getContext(), index, iv);
    AffineApplyOp applyOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), index + iv,
        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
    // Exclude `applyOp` itself from the replacement so its own operand keeps
    // referring to the original index.
    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
  }
}
160 
/// Tiles `op` with the given `tileSizes` (one entry per loop, zero meaning
/// "do not tile this loop") and materializes the tiled computation inside a
/// loop nest of type `LoopTy`. Returns llvm::None when nothing is tiled or
/// the op is unsupported; otherwise returns the tiled op, the generated loops
/// and, for ops with tensor semantics, the loop-nest results.
template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  // All-zero tile sizes: nothing to tile.
  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

  // Indexed generic operations are expected to have been canonicalized (to
  // plain generic ops) before tiling; bail out if one reaches this point.
  if (isa<IndexedGenericOp>(op))
    return llvm::None;

  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv op only support tiling along batch dimension (which is the first
    // loop).
    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
      return llvm::None;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return llvm::None;

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  // Keep only the iterator types of the loops that are actually tiled.
  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  // 2. Create the tiled loops. `ivs` and `tensorResults` are captured by the
  // body builder and read back after loop-nest generation.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
                                  ValueRange localIvs,
                                  ValueRange iterArgs) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    assert(op.getNumOutputTensors() == iterArgs.size() &&
           "num output tensors must match number of loop iter arguments");

    // Assemble the operand list: inputs, then outputs (buffers or loop iter
    // args, never both), then non-shaped operands.
    auto operands = llvm::to_vector<4>(op.getInputs());
    SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
    // TODO: thanks to simplifying assumption we do not need to worry about
    // order of output buffers and tensors: there is only ever one kind.
    assert(outputBuffers.empty() || iterArgs.empty());
    operands.append(outputBuffers.begin(), outputBuffers.end());
    operands.append(iterArgs.begin(), iterArgs.end());
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
    auto nonShapedOperands = op.getAssumedNonShapedOperands();
    tiledOperands.append(nonShapedOperands.begin(), nonShapedOperands.end());

    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert a subtensor_insert for each output tensor. Results produced from
    // a subtensor are written back into the subtensor's source at the same
    // offsets/sizes/strides; other results are yielded unchanged.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
      // TODO: use an interface/adaptor to avoid leaking position in
      // `tiledOperands`.
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
        tensorResults.push_back(b.create<SubTensorInsertOp>(
            loc, subtensor.source().getType(), res->getResult(resultIdx),
            subtensor.source(), subtensor.offsets(), subtensor.sizes(),
            subtensor.strides(), subtensor.static_offsets(),
            subtensor.static_sizes(), subtensor.static_strides()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of the
      // loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
316 
317 template <typename LoopTy>
318 Optional<TiledLinalgOp> static tileLinalgOpImpl(
319     OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
320   OpBuilder::InsertionGuard g(b);
321   b.setInsertionPoint(op);
322   ScopedContext scope(b, op.getLoc());
323 
324   if (!options.tileSizeComputationFunction)
325     return llvm::None;
326 
327   // Enforce the convention that "tiling by zero" skips tiling a particular
328   // dimension. This convention is significantly simpler to handle instead of
329   // adjusting affine maps to account for missing dimensions.
330   auto nLoops = op.getNumLoops();
331   SmallVector<Value, 4> tileSizeVector =
332       options.tileSizeComputationFunction(b, op);
333   if (tileSizeVector.size() < nLoops) {
334     auto zero = std_constant_index(0);
335     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
336   }
337 
338   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
339 }
340 
341 Optional<TiledLinalgOp>
342 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
343                            const LinalgTilingOptions &options) {
344   switch (options.loopType) {
345   case LinalgTilingLoopType::Loops:
346     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
347   case LinalgTilingLoopType::ParallelLoops:
348     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
349   case LinalgTilingLoopType::TiledLoops:
350     return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
351   default:;
352   }
353   return llvm::None;
354 }
355 
356 namespace {
357 /// Helper classes for type list expansion.
358 template <typename... OpTypes>
359 class CanonicalizationPatternList;
360 
361 template <>
362 class CanonicalizationPatternList<> {
363 public:
364   static void insert(RewritePatternSet &patterns) {}
365 };
366 
367 template <typename OpTy, typename... OpTypes>
368 class CanonicalizationPatternList<OpTy, OpTypes...> {
369 public:
370   static void insert(RewritePatternSet &patterns) {
371     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
372     CanonicalizationPatternList<OpTypes...>::insert(patterns);
373   }
374 };
375 
376 /// Helper classes for type list expansion.
377 template <typename... OpTypes>
378 class RewritePatternList;
379 
380 template <>
381 class RewritePatternList<> {
382 public:
383   static void insert(RewritePatternSet &patterns,
384                      const LinalgTilingOptions &options) {}
385 };
386 
387 template <typename OpTy, typename... OpTypes>
388 class RewritePatternList<OpTy, OpTypes...> {
389 public:
390   static void insert(RewritePatternSet &patterns,
391                      const LinalgTilingOptions &options) {
392     auto *ctx = patterns.getContext();
393     patterns.add<LinalgTilingPattern<OpTy>>(
394         ctx, options,
395         LinalgTransformationFilter(ArrayRef<Identifier>{},
396                                    Identifier::get("tiled", ctx)));
397     RewritePatternList<OpTypes...>::insert(patterns, options);
398   }
399 };
400 } // namespace
401 
402 RewritePatternSet
403 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
404   RewritePatternSet patterns(ctx);
405   populateLinalgTilingCanonicalizationPatterns(patterns);
406   return patterns;
407 }
408 
/// Collects canonicalization patterns that are useful after Linalg tiling:
/// affine and scf loop canonicalizations, subview/subtensor/cast foldings,
/// and the canonicalization patterns of all Linalg structured ops.
void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
  // Expand over the TableGen-generated list of Linalg structured ops.
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}
428 
/// Populate the given list with patterns that apply Linalg tiling. GenericOp
/// is listed explicitly, followed by the TableGen-generated list of Linalg
/// structured ops.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
}
437 
438 static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
439                                       FuncOp funcOp,
440                                       ArrayRef<int64_t> tileSizes) {
441   auto options =
442       LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
443   MLIRContext *ctx = funcOp.getContext();
444   RewritePatternSet patterns(ctx);
445   insertTilingPatterns(patterns, options);
446   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
447   (void)applyPatternsAndFoldGreedily(
448       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
449   // Drop the marker.
450   funcOp.walk([](LinalgOp op) {
451     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
452   });
453 }
454 
namespace {
/// Function pass that tiles Linalg ops to scf.for loops, using the tile sizes
/// provided as a pass option.
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  // `tileSizes` is the pass option declared on the generated base class.
  LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
                              tileSizes);
  }
};

/// Function pass that tiles Linalg ops to scf.parallel loops.
struct LinalgTilingToParallelLoopsPass
    : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
  LinalgTilingToParallelLoopsPass() = default;
  LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
    tileSizes = sizes;
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
                              getFunction(), tileSizes);
  }
};

/// Function pass that tiles Linalg ops to linalg.tiled_loop ops.
struct LinalgTilingToTiledLoopsPass
    : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
  LinalgTilingToTiledLoopsPass() = default;
  LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::TiledLoops, getFunction(),
                              tileSizes);
  }
};

} // namespace
491 
/// Creates a pass that tiles Linalg ops to scf.for loops with the given tile
/// sizes.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingPass>(tileSizes);
}
496 
/// Creates a pass that tiles Linalg ops to scf.parallel loops with the given
/// tile sizes.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
}
501 
/// Creates a pass that tiles Linalg ops to linalg.tiled_loop ops with the
/// given tile sizes.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes);
}
506