1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the linalg dialect Tiling pass.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Dialect/LoopOps/LoopOps.h"
23 #include "mlir/EDSC/Helpers.h"
24 #include "mlir/IR/AffineExpr.h"
25 #include "mlir/IR/AffineExprVisitor.h"
26 #include "mlir/IR/AffineMap.h"
27 #include "mlir/IR/OpImplementation.h"
28 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
29 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
30 #include "mlir/Dialect/Linalg/Passes.h"
31 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
32 #include "mlir/Dialect/Linalg/Utils/Utils.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Support/STLExtras.h"
36 #include "mlir/Transforms/FoldUtils.h"
37 
38 #include "llvm/Support/CommandLine.h"
39 
40 using namespace mlir;
41 using namespace mlir::edsc;
42 using namespace mlir::edsc::intrinsics;
43 using namespace mlir::linalg;
44 using namespace mlir::linalg::intrinsics;
45 using namespace mlir::loop;
46 
47 #define DEBUG_TYPE "linalg-tiling"
48 
// Command-line options consumed by the standalone -linalg-tile pass
// registration at the bottom of this file.
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
// Comma-separated tile sizes, one per loop; the convention used throughout
// this file is that a size of 0 means "do not tile this loop".
static llvm::cl::list<unsigned>
    clTileSizes("linalg-tile-sizes",
                llvm::cl::desc("Tile sizes by which to tile linalg operations"),
                llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
                llvm::cl::cat(clOptionsCategory));
// When set, every tiled view is promoted into a scoped local buffer.
static llvm::cl::opt<bool> clPromoteFullTileViews(
    "linalg-tile-promote-full-tile-views",
    llvm::cl::desc("Create scoped local buffers for tiled views "),
    llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
59 
60 static bool isZero(Value *v) {
61   return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
62          cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
63 }
64 
65 // Creates a number of ranges equal to the number of non-zero in `tileSizes`.
66 // One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
67 // one entry per surrounding loop. It uses zero as the convention that a
68 // particular loop is not tiled. This convention simplifies implementations by
69 // avoiding affine map manipulations.
70 // The returned ranges correspond to the loop ranges, in the proper order, that
71 // are tiled and for which new loops will be created.
72 static SmallVector<SubViewOp::Range, 4>
73 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
74                     ArrayRef<Value *> allViewSizes,
75                     ArrayRef<Value *> allTileSizes, OperationFolder &folder) {
76   assert(allTileSizes.size() == map.getNumResults());
77   // Apply `map` to get view sizes in loop order.
78   auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
79   SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
80 
81   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
82   for (int idx = tileSizes.size() - 1; idx >= 0; --idx) {
83     if (isZero(tileSizes[idx])) {
84       viewSizes.erase(viewSizes.begin() + idx);
85       tileSizes.erase(tileSizes.begin() + idx);
86     }
87   }
88 
89   // Create a new range with the applied tile sizes.
90   SmallVector<SubViewOp::Range, 4> res;
91   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
92     res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx],
93                                    tileSizes[idx]});
94   }
95   return res;
96 }
97 
98 namespace {
99 // Helper visitor to determine whether an AffineExpr is tiled.
100 // This is achieved by traversing every AffineDimExpr with position `pos` and
101 // checking whether the corresponding `tileSizes[pos]` is non-zero.
102 // This also enforces only positive coefficients occur in multiplications.
103 //
104 // Example:
105 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
106 //
107 struct TileCheck : public AffineExprVisitor<TileCheck> {
108   TileCheck(ArrayRef<Value *> tileSizes)
109       : isTiled(false), tileSizes(tileSizes) {}
110 
111   void visitDimExpr(AffineDimExpr expr) {
112     isTiled |= !isZero(tileSizes[expr.getPosition()]);
113   }
114   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
115     visit(expr.getLHS());
116     visit(expr.getRHS());
117     if (expr.getKind() == mlir::AffineExprKind::Mul)
118       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
119              "nonpositive multipliying coefficient");
120   }
121   bool isTiled;
122   ArrayRef<Value *> tileSizes;
123 };
124 } // namespace
125 
126 static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
127   if (!expr)
128     return false;
129   TileCheck t(tileSizes);
130   t.visit(expr);
131   return t.isTiled;
132 }
133 
// Returns true if any result expression of `map` involves a dimension whose
// corresponding entry in `tileSizes` is non-zero (i.e. the map varies with a
// tiled loop). A null map is never tiled.
// NOTE(review): the previous comment referred to `viewIndex`/`linalgOp`,
// which are not parameters of this function; it described the caller's usage.
static bool isTiled(AffineMap map, ArrayRef<Value *> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
144 
// Creates the tiled (sub)views for every input/output view of `linalgOp`.
// `ivs` holds the induction variables of the tile loops (one per non-zero
// entry of `tileSizes`), `tileSizes` has one entry per loop (zero = untiled),
// and `viewSizes` are the full view sizes. Views whose indexing map involves
// no tiled loop are returned unchanged; the others become SubViewOps covering
// the current tile. Temporary min/max values with no remaining uses are
// erased before returning.
static SmallVector<Value *, 4>
makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
               ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes,
               ArrayRef<Value *> viewSizes, OperationFolder &folder) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](Value *v) { return !isZero(v); })) &&
         "expected as many ivs as non-zero sizes");

  using edsc::intrinsics::select;
  using edsc::op::operator+;
  using edsc::op::operator<;

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subviews.
  SmallVector<Value *, 8> mins, maxes;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    if (isZero(tileSizes[idx])) {
      // Untiled loop: cover the full [0, viewSize) extent.
      mins.push_back(constant_index(folder, 0));
      maxes.push_back(viewSizes[idx]);
    } else {
      // Tiled loop: the current tile spans [iv, iv + tileSize).
      ValueHandle lb(ivs[idxIvs++]), step(tileSizes[idx]);
      mins.push_back(lb);
      maxes.push_back(lb + step);
    }
  }

  auto *op = linalgOp.getOperation();

  SmallVector<Value *, 4> res;
  res.reserve(op->getNumOperands());
  auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
  for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
       ++viewIndex) {
    Value *view = *(viewIteratorBegin + viewIndex);
    unsigned viewRank = view->getType().cast<ViewType>().getRank();
    auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
    // If the view is not tiled, we can use it as is.
    if (!isTiled(map, tileSizes)) {
      res.push_back(view);
      continue;
    }

    // Construct a new subview for the tile.
    SmallVector<SubViewOp::Range, 4> subViewOperands;
    subViewOperands.reserve(viewRank * 3);
    for (unsigned r = 0; r < viewRank; ++r) {
      if (!isTiled(map.getSubMap({r}), tileSizes)) {
        // Untiled dimension: take the full range of the view with step 1.
        subViewOperands.push_back(SubViewOp::Range{
            constant_index(folder, 0), linalg::intrinsics::dim(view, r),
            constant_index(folder, 1)});
        continue;
      }

      auto m = map.getSubMap({r});
      auto *min = applyMapToValues(b, loc, m, mins, folder).front();
      auto *max = applyMapToValues(b, loc, m, maxes, folder).front();
      // Tiling creates a new slice at the proper index, the slice step is 1
      // (i.e. the slice view does not subsample, stepping occurs in the loop).
      subViewOperands.push_back(
          SubViewOp::Range{min, max, constant_index(folder, 1)});
    }
    res.push_back(b.create<SubViewOp>(loc, view, subViewOperands));
  }

  // Traverse the mins/maxes and erase those that don't have uses left.
  mins.append(maxes.begin(), maxes.end());
  for (auto *v : mins)
    if (v->use_empty())
      v->getDefiningOp()->erase();

  return res;
}
218 
219 static AffineMap getAffineDifferenceMap(MLIRContext *context) {
220   AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
221   return AffineMap::get(2, 0, {d0 - d1});
222 }
223 
224 static Value *allocBuffer(Type elementType, Value *size) {
225   if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
226     return buffer_alloc(
227         BufferType::get(size->getContext(), elementType, cst.getValue()));
228   return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
229 }
230 
231 // Performs promotion of a `subView` into a local buffer of the size of the
232 // *ranges* of the `subView`. This produces a buffer whose size may be bigger
233 // than the actual size of the `subView` at the boundaries.
234 // This is related to the full/partial tile problem.
235 // Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
236 // `partialLocalView` such that:
237 //   * `buffer` is always the size of the full tile.
238 //   * `fullLocalView` is a dense contiguous view into that buffer.
239 //   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
240 //     that corresponds to the size of `subView` and accounting for boundary
241 //     effects.
242 // The point of the full tile buffer is that constant static tile sizes are
243 // folded and result in a buffer type with statically known size and alignment
244 // properties.
245 // To account for general boundary effects, padding must be performed on the
246 // boundary tiles. For now this is done with an unconditional `fill` op followed
247 // by a partial `copy` op.
248 static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
249                                            SubViewOp subView,
250                                            OperationFolder &folder) {
251   auto zero = constant_index(folder, 0);
252   auto one = constant_index(folder, 1);
253 
254   auto viewType = subView.getViewType();
255   auto rank = viewType.getRank();
256   Value *allocSize = one;
257   SmallVector<Value *, 8> fullRanges, partialRanges;
258   fullRanges.reserve(rank);
259   partialRanges.reserve(rank);
260   for (auto en : llvm::enumerate(subView.getRanges())) {
261     auto rank = en.index();
262     auto rangeValue = en.value();
263     Value *d =
264         isa<linalg::DimOp>(rangeValue.max->getDefiningOp())
265             ? rangeValue.max
266             : applyMapToValues(b, loc, getAffineDifferenceMap(b.getContext()),
267                                {rangeValue.max, rangeValue.min}, folder)
268                   .front();
269     allocSize = muli(folder, allocSize, d).getValue();
270     fullRanges.push_back(range(folder, zero, d, one));
271     partialRanges.push_back(
272         range(folder, zero, linalg::intrinsics::dim(subView, rank), one));
273   }
274   auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
275   auto fullLocalView = view(buffer, fullRanges);
276   auto partialLocalView = slice(fullLocalView, partialRanges);
277   return PromotionInfo{buffer, fullLocalView, partialLocalView};
278 }
279 
// Performs promotion of a view `v` into a local buffer of the size of the
// view. This produces a buffer whose size is exactly the size of `v`.
// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
// `partialLocalView` such that:
//   * `buffer` is always the size of the view.
//   * `partialLocalView` is a dense contiguous view into that buffer.
//   * `fullLocalView` is equal to `partialLocalView`.
// The point of the full tile buffer is that constant static tile sizes are
// folded and result in a buffer type with statically known size and alignment
// properties.
static PromotionInfo promotePartialTileBuffer(OpBuilder &b, Location loc,
                                              Value *v,
                                              OperationFolder &folder) {
  auto zero = constant_index(folder, 0);
  auto one = constant_index(folder, 1);

  auto viewType = v->getType().cast<ViewType>();
  auto rank = viewType.getRank();
  // Accumulate the flat element count across all dimensions.
  Value *allocSize = one;
  SmallVector<Value *, 8> partialRanges;
  partialRanges.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    Value *d = linalg::intrinsics::dim(v, r);
    allocSize = muli(folder, allocSize, d).getValue();
    partialRanges.push_back(range(folder, zero, d, one));
  }
  auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
  // Exact-size promotion: the full and partial local views coincide.
  auto partialLocalView = view(folder, buffer, partialRanges);
  return PromotionInfo{buffer, partialLocalView, partialLocalView};
}
310 
// Promotes each view in `views` into a local buffer, returning one
// PromotionInfo per view (in the same order). SubView-defined views get a
// full-tile buffer (possibly larger than the subview at boundaries); other
// views get an exactly-sized buffer. All promoted buffers are filled with
// 0.0f and then the original contents are copied into the partial local view.
SmallVector<PromotionInfo, 8>
mlir::linalg::promoteLinalgViews(OpBuilder &b, Location loc,
                                 ArrayRef<Value *> views,
                                 OperationFolder &folder) {
  if (views.empty())
    return {};

  ScopedContext scope(b, loc);
  SmallVector<PromotionInfo, 8> res;
  res.reserve(views.size());
  DenseMap<Value *, PromotionInfo> promotionInfo;
  // 1. Allocate one promotion buffer per view.
  for (auto *v : views) {
    PromotionInfo pi;
    if (auto subView = dyn_cast<SubViewOp>(v->getDefiningOp()))
      pi = promoteFullTileBuffer(b, loc, subView, folder);
    else
      pi = promotePartialTileBuffer(b, loc, v, folder);
    promotionInfo.insert(std::make_pair(v, pi));
    res.push_back(pi);
  }

  // 2. Fill each full local view (padding for boundary tiles).
  for (auto *v : views) {
    auto info = promotionInfo.find(v);
    if (info == promotionInfo.end())
      continue;
    auto viewType = v->getType().cast<ViewType>();
    // TODO(ntv): value to fill with should be related to the operation.
    // For now, just use APFloat(0.0f).
    auto t = viewType.getElementType().cast<FloatType>();
    Value *fillVal = constant_float(folder, APFloat(0.0f), t);
    // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
    // view that is different from the partial local view and we are on the
    // boundary.
    fill(info->second.fullLocalView, fillVal);
  }

  // 3. Copy the original view contents into the (partial) local views.
  for (auto *v : views) {
    auto info = promotionInfo.find(v);
    if (info == promotionInfo.end())
      continue;
    copy(v, info->second.partialLocalView);
  }
  return res;
}
355 
// Tiles `op` by `tileSizes` (one SSA value per loop; a zero constant means
// "do not tile this loop"). When `viewsToPromote` is non-empty and selects at
// least one view, the selected tiled views are promoted into local buffers
// with fill/copy/write-back handling. Returns the new tiled op together with
// the loop nest that was created; this overload always returns a value.
llvm::Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
                           OperationFolder &folder,
                           ArrayRef<bool> viewsToPromote) {
  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle instead of
  // adjusting affine maps to account for missing dimensions.
  assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
                 op.getNumWindowLoops() ==
             tileSizes.size() &&
         "expected matching number of tile sizes and loops");

  OpBuilder builder(op.getOperation());
  ScopedContext scope(builder, op.getLoc());
  // 2. Build the tiled loop ranges.
  auto viewSizes = getViewSizes(op);
  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (asserted in the inverse calculation).
  auto viewSizesToLoopsMap =
      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
  assert(viewSizesToLoopsMap && "expected invertible map");
  auto loopRanges =
      makeTiledLoopRanges(scope.getBuilder(), scope.getLocation(),
                          viewSizesToLoopsMap, viewSizes, tileSizes, folder);

  // 3. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<IndexHandle, 4> ivs(loopRanges.size());
  auto pivs = makeIndexHandlePointers(ivs);
  // The lambda body below runs inside the innermost loop of the nest.
  LoopNestRangeBuilder(pivs, loopRanges)([&] {
    auto b = ScopedContext::getBuilder();
    auto loc = ScopedContext::getLocation();
    SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
    auto views =
        makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);

    // If no promotion, we are done.
    auto promote = !viewsToPromote.empty() &&
                   llvm::any_of(llvm::make_range(viewsToPromote.begin(),
                                                 viewsToPromote.end()),
                                [](bool b) { return b; });
    if (!promote) {
      auto operands = getAssumedNonViewOperands(op);
      views.append(operands.begin(), operands.end());
      res = op.create(b, loc, views, op.getAttrs());
      return;
    }

    // 4. Filter the subset of views that need to be promoted.
    SmallVector<Value *, 8> filteredViews;
    filteredViews.reserve(views.size());
    assert((viewsToPromote.empty() || views.size() == viewsToPromote.size()) &&
           "expected viewsToPromote to be empty or of the same size as view");
    for (auto it : llvm::zip(views, viewsToPromote)) {
      if (!std::get<1>(it))
        continue;
      filteredViews.push_back(std::get<0>(it));
    }

    // 5. Promote the specified views and use them in the new op.
    auto promotedBufferAndViews =
        promoteLinalgViews(b, loc, filteredViews, folder);
    // opViews: the views the new op operates on (promoted where selected);
    // writebackViews: partial local views to copy back for outputs.
    SmallVector<Value *, 8> opViews(views.size(), nullptr);
    SmallVector<Value *, 8> writebackViews(views.size(), nullptr);
    for (unsigned i = 0, promotedIdx = 0, e = opViews.size(); i < e; ++i) {
      if (viewsToPromote[i]) {
        opViews[i] = promotedBufferAndViews[promotedIdx].fullLocalView;
        writebackViews[i] =
            promotedBufferAndViews[promotedIdx].partialLocalView;
        promotedIdx++;
      } else {
        opViews[i] = views[i];
      }
    }
    auto operands = getAssumedNonViewOperands(op);
    opViews.append(operands.begin(), operands.end());
    res = op.create(b, loc, opViews, op.getAttrs());

    // 6. Emit write-back for the promoted output views: copy the partial view.
    for (unsigned i = 0, e = writebackViews.size(); i < e; ++i) {
      bool isOutput = res.getIndexOfOutput(opViews[i]).hasValue();
      if (writebackViews[i] && isOutput)
        copy(writebackViews[i], views[i]);
    }

    // 7. Dealloc local buffers.
    for (const auto &pi : promotedBufferAndViews)
      buffer_dealloc(pi.buffer);
  });

  // 8. Gather the newly created loops and return them with the new op.
  SmallVector<ForOp, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs)
    loops.push_back(loop::getForInductionVarOwner(iv));

  return TiledLinalgOp{res, loops};
}
454 
455 llvm::Optional<TiledLinalgOp>
456 mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
457                            OperationFolder &folder,
458                            ArrayRef<bool> viewsToPromote) {
459   if (tileSizes.empty())
460     return llvm::None;
461 
462   // The following uses the convention that "tiling by zero" skips tiling a
463   // particular dimension. This convention is significantly simpler to handle
464   // instead of adjusting affine maps to account for missing dimensions.
465   auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
466                 op.getNumWindowLoops();
467   tileSizes = tileSizes.take_front(nLoops);
468   // If only 0 tilings are left, then return.
469   if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
470     return llvm::None;
471 
472   // Create a builder for tile size constants.
473   OpBuilder builder(op);
474   ScopedContext scope(builder, op.getLoc());
475 
476   // Materialize concrete tile size values to pass the generic tiling function.
477   SmallVector<Value *, 8> tileSizeValues;
478   tileSizeValues.reserve(tileSizes.size());
479   for (auto ts : tileSizes)
480     tileSizeValues.push_back(constant_index(folder, ts));
481   // Pad tile sizes with zero values to enforce our convention.
482   if (tileSizeValues.size() < nLoops) {
483     for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
484       tileSizeValues.push_back(constant_index(folder, 0));
485   }
486 
487   return tileLinalgOp(op, tileSizeValues, folder, viewsToPromote);
488 }
489 
// Tiles every LinalgOp in function `f` by `tileSizes`, optionally promoting
// all of each op's views, then performs a dead-code sweep over leftover
// side-effect-free linalg ops with no uses.
static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes,
                          bool promoteViews) {
  OperationFolder folder;
  f.walk<LinalgOp>([promoteViews, tileSizes, &folder](LinalgOp op) {
    // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
    // nothing.
    SmallVector<bool, 8> viewsToPromote(op.getNumInputsAndOutputs(),
                                        promoteViews);
    auto opLoopsPair = tileLinalgOp(op, tileSizes, folder, viewsToPromote);
    // If tiling occurred successfully, erase old op.
    if (opLoopsPair)
      op.erase();
  });
  // Cleanup: drop linalg ops that are now unused and have no side effects.
  f.walk<LinalgOp>([](LinalgOp op) {
    if (!op.getOperation()->hasNoSideEffect())
      return;
    if (op.getOperation()->use_empty())
      op.erase();
  });
}
510 
namespace {
// Function pass that tiles all linalg ops in a function with the configured
// tile sizes, optionally promoting the tiled views into local buffers.
struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews);

  void runOnFunction() {
    tileLinalgOps(getFunction(), tileSizes, promoteViews);
  }

  // One tile size per loop; zero means "do not tile this loop".
  SmallVector<int64_t, 8> tileSizes;
  // Whether to promote every tiled view into a scoped local buffer.
  bool promoteViews;
};
} // namespace
524 
525 LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews) {
526   this->tileSizes.assign(sizes.begin(), sizes.end());
527   this->promoteViews = promoteViews;
528 }
529 
// Factory for the linalg tiling pass with explicit tile sizes and promotion
// behavior; used for programmatic pass-pipeline construction.
std::unique_ptr<FunctionPassBase>
mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                                     bool promoteViews) {
  return std::make_unique<LinalgTilingPass>(tileSizes, promoteViews);
}
535 
// Registers the pass under -linalg-tile; the constructor lambda reads the
// tile sizes and promotion flag from the command-line options above.
static PassRegistration<LinalgTilingPass>
    pass("linalg-tile", "Tile operations in the linalg dialect", [] {
      auto pass = std::make_unique<LinalgTilingPass>();
      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
      pass->promoteViews = clPromoteFullTileViews;
      return pass;
    });
543