//===- ParallelLoopTiling.cpp - Tiles scf.parallel ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop tiling on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Transforms/RegionUtils.h"
20 #include "llvm/Support/CommandLine.h"
21 
22 using namespace mlir;
23 using namespace mlir::scf;
24 
25 /// Tile a parallel loop of the form
26 ///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
27 ///                                             step (%arg4, %arg5)
28 ///
29 /// into
30 ///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
31 ///                                             step (%arg4*tileSize[0],
32 ///                                                   %arg5*tileSize[1])
33 ///     loop.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%j0)
34 ///                                           min(tileSize[1], %arg3-%j1))
35 ///                                        step (%arg4, %arg5)
36 /// The old loop is replaced with the new one.
37 void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
38   OpBuilder b(op);
39   auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
40   SmallVector<Value, 2> tileSizeConstants;
41   tileSizeConstants.reserve(op.upperBound().size());
42   for (size_t i = 0, end = op.upperBound().size(); i != end; ++i) {
43     if (i < tileSizes.size())
44       tileSizeConstants.push_back(
45           b.create<ConstantIndexOp>(op.getLoc(), tileSizes[i]));
46     else
47       // Just pick 1 for the remaining dimensions.
48       tileSizeConstants.push_back(b.create<ConstantIndexOp>(op.getLoc(), 1));
49   }
50 
51   // Create the outer loop with adjusted steps.
52   SmallVector<Value, 2> newSteps;
53   newSteps.reserve(op.step().size());
54   for (auto step : llvm::zip(op.step(), tileSizeConstants)) {
55     newSteps.push_back(
56         b.create<MulIOp>(op.getLoc(), std::get<0>(step), std::get<1>(step)));
57   }
58   auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.lowerBound(),
59                                         op.upperBound(), newSteps);
60   b.setInsertionPointToStart(outerLoop.getBody());
61 
62   // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
63   // FIXME: Instead of using min, we want to replicate the tail. This would give
64   // the inner loop constant bounds for easy vectorization.
65   auto minMap = AffineMap::get(
66       /*dimCount=*/3, /*symbolCount=*/0,
67       {getAffineDimExpr(/*position=*/0, b.getContext()),
68        getAffineDimExpr(/*position=*/1, b.getContext()) -
69            getAffineDimExpr(/*position=*/2, b.getContext())},
70       b.getContext());
71 
72   // Create the inner loop with adjusted bounds.
73   SmallVector<Value, 2> newBounds;
74   newBounds.reserve(op.upperBound().size());
75   for (auto bounds : llvm::zip(tileSizeConstants, outerLoop.upperBound(),
76                                outerLoop.getInductionVars())) {
77     newBounds.push_back(b.create<AffineMinOp>(
78         op.getLoc(), b.getIndexType(), minMap,
79         ValueRange{std::get<0>(bounds), std::get<1>(bounds),
80                    std::get<2>(bounds)}));
81   }
82   auto innerLoop = b.create<ParallelOp>(
83       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
84       op.step());
85 
86   // Steal the body of the old parallel loop and erase it.
87   innerLoop.region().takeBody(op.region());
88   op.erase();
89 }
90 
91 /// Get a list of most nested parallel loops. Assumes that ParallelOps are only
92 /// directly nested.
93 static bool getInnermostNestedLoops(Block *block,
94                                     SmallVectorImpl<ParallelOp> &loops) {
95   bool hasInnerLoop = false;
96   for (auto parallelOp : block->getOps<ParallelOp>()) {
97     hasInnerLoop = true;
98     if (!getInnermostNestedLoops(parallelOp.getBody(), loops))
99       loops.push_back(parallelOp);
100   }
101   return hasInnerLoop;
102 }
103 
104 namespace {
105 struct ParallelLoopTiling
106     : public LoopParallelLoopTilingBase<ParallelLoopTiling> {
107   ParallelLoopTiling() = default;
108   explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes) {
109     this->tileSizes = tileSizes;
110   }
111 
112   void runOnFunction() override {
113     SmallVector<ParallelOp, 2> mostNestedParallelOps;
114     for (Block &block : getFunction()) {
115       getInnermostNestedLoops(&block, mostNestedParallelOps);
116     }
117     for (ParallelOp pLoop : mostNestedParallelOps) {
118       tileParallelLoop(pLoop, tileSizes);
119     }
120   }
121 };
122 } // namespace
123 
124 std::unique_ptr<Pass>
125 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes) {
126   return std::make_unique<ParallelLoopTiling>(tileSizes);
127 }
128