1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop tiling on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 
21 using namespace mlir;
22 using namespace mlir::scf;
23 
24 /// Tile a parallel loop of the form
25 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
26 ///                                            step (%arg4, %arg5)
27 ///
28 /// into
29 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
30 ///                                            step (%arg4*tileSize[0],
31 ///                                                  %arg5*tileSize[1])
32 ///     scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
33 ///                                          min(%arg5*tileSize[1], %arg3-%i1))
34 ///                                      step (%arg4, %arg5)
35 ///
36 /// or, when no-min-max-bounds is true, into
37 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
38 ///                                            step (%arg4*tileSize[0],
39 ///                                                  %arg5*tileSize[1])
40 ///     scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
41 ///                                          %arg5*tileSize[1])
42 ///                                      step (%arg4, %arg5)
43 ///        %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
44 ///                   (%j1 * %arg5 + %i1 < %arg3)
45 ///        scf.if (%inbound)
46 ///          ....
47 ///
48 /// where the uses of %i0 and %i1 in the loop body are replaced by
49 /// %i0 + j0 and %i1 + %j1.
50 //
51 /// The old loop is replaced with the new one.
52 std::pair<ParallelOp, ParallelOp>
53 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
54                             bool noMinMaxBounds) {
55   OpBuilder b(op);
56   auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
57   SmallVector<Value, 2> tileSizeConstants;
58   tileSizeConstants.reserve(op.upperBound().size());
59   for (size_t i = 0, end = op.upperBound().size(); i != end; ++i) {
60     if (i < tileSizes.size())
61       tileSizeConstants.push_back(
62           b.create<ConstantIndexOp>(op.getLoc(), tileSizes[i]));
63     else
64       // Just pick 1 for the remaining dimensions.
65       tileSizeConstants.push_back(b.create<ConstantIndexOp>(op.getLoc(), 1));
66   }
67 
68   // Create the outer loop with adjusted steps.
69   SmallVector<Value, 2> newSteps;
70   newSteps.reserve(op.step().size());
71   for (auto step : llvm::zip(op.step(), tileSizeConstants)) {
72     newSteps.push_back(
73         b.create<MulIOp>(op.getLoc(), std::get<0>(step), std::get<1>(step)));
74   }
75   auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.lowerBound(),
76                                         op.upperBound(), newSteps);
77   b.setInsertionPointToStart(outerLoop.getBody());
78 
79   // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
80   auto minMap = AffineMap::get(
81       /*dimCount=*/3, /*symbolCount=*/0,
82       {getAffineDimExpr(/*position=*/0, b.getContext()),
83        getAffineDimExpr(/*position=*/1, b.getContext()) -
84            getAffineDimExpr(/*position=*/2, b.getContext())},
85       b.getContext());
86 
87   // Create the inner loop with adjusted bounds.
88   SmallVector<Value, 2> newBounds;
89   newBounds.reserve(op.upperBound().size());
90   bool needInboundCheck = false;
91   for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(),
92                             outerLoop.step(), outerLoop.getInductionVars(),
93                             op.step(), tileSizeConstants)) {
94     Value lowerBound, upperBound, newStep, iv, step, tileSizeConstant;
95     std::tie(lowerBound, upperBound, newStep, iv, step, tileSizeConstant) = dim;
96     // Collect the statically known loop bounds
97     auto lowerBoundConstant =
98         dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp());
99     auto upperBoundConstant =
100         dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp());
101     auto stepConstant = dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
102     auto tileSize =
103         cast<ConstantIndexOp>(tileSizeConstant.getDefiningOp()).getValue();
104     // If the loop bounds and the loop step are constant and if the number of
105     // loop iterations is an integer multiple of the tile size, we use a static
106     // bound for the inner loop.
107     if (lowerBoundConstant && upperBoundConstant && stepConstant) {
108       auto numIterations = llvm::divideCeil(upperBoundConstant.getValue() -
109                                                 lowerBoundConstant.getValue(),
110                                             stepConstant.getValue());
111       if (numIterations % tileSize == 0) {
112         newBounds.push_back(newStep);
113         continue;
114       }
115     }
116 
117     // For InboundCheck mode, just use the variable outer step
118     if (noMinMaxBounds) {
119       newBounds.push_back(newStep);
120       needInboundCheck = true;
121       continue;
122     }
123 
124     // Otherwise, we dynamically compute the bound for
125     // each iteration of the outer loop.
126     newBounds.push_back(
127         b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
128                               ValueRange{newStep, upperBound, iv}));
129   }
130   auto innerLoop = b.create<ParallelOp>(
131       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
132       op.step());
133 
134   if (noMinMaxBounds && needInboundCheck) {
135     b.setInsertionPointToStart(innerLoop.getBody());
136     // Insert in-bound check
137     Value inbound =
138         b.create<ConstantOp>(op.getLoc(), b.getIntegerType(1),
139                              b.getIntegerAttr(b.getIntegerType(1), 1));
140     for (auto dim :
141          llvm::zip(outerLoop.upperBound(), outerLoop.getInductionVars(),
142                    innerLoop.getInductionVars(), innerLoop.step())) {
143       Value outerUpperBound, outerIV, innerIV, innerStep;
144       std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim;
145       // %in_bound = %in_bound &&
146       //             (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
147       Value index = b.create<AddIOp>(
148           op.getLoc(), b.create<MulIOp>(op.getLoc(), innerIV, innerStep),
149           outerIV);
150       Value dimInbound = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::ult,
151                                           index, outerUpperBound);
152       inbound = b.create<AndOp>(op.getLoc(), inbound, dimInbound);
153     }
154     auto ifInbound = b.create<IfOp>(op.getLoc(),
155                                     /*resultTypes*/ ArrayRef<Type>{}, inbound,
156                                     /*hasElseRegion*/ false);
157     ifInbound.thenRegion().takeBody(op.region());
158     Block &thenBlock = ifInbound.thenRegion().front();
159     b.setInsertionPointToStart(innerLoop.getBody());
160     for (auto ivs : llvm::enumerate(llvm::zip(innerLoop.getInductionVars(),
161                                               outerLoop.getInductionVars()))) {
162       AddIOp newIndex = b.create<AddIOp>(op.getLoc(), std::get<0>(ivs.value()),
163                                          std::get<1>(ivs.value()));
164       thenBlock.getArgument(ivs.index())
165           .replaceAllUsesExcept(newIndex, newIndex);
166     }
167     thenBlock.eraseArguments(llvm::to_vector<4>(
168         llvm::seq((unsigned)0, thenBlock.getNumArguments())));
169   } else {
170     innerLoop.region().takeBody(op.region());
171     b.setInsertionPointToStart(innerLoop.getBody());
172     for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
173                               outerLoop.getInductionVars())) {
174       Value innerIndex = std::get<0>(ivs);
175       AddIOp newIndex =
176           b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
177       innerIndex.replaceAllUsesExcept(newIndex, newIndex);
178     }
179   }
180 
181   op.erase();
182   return std::make_pair(outerLoop, innerLoop);
183 }
184 
185 namespace {
186 struct ParallelLoopTiling
187     : public SCFParallelLoopTilingBase<ParallelLoopTiling> {
188   ParallelLoopTiling() = default;
189   explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
190                               bool noMinMaxBounds = false) {
191     this->tileSizes = tileSizes;
192     this->noMinMaxBounds = noMinMaxBounds;
193   }
194 
195   void runOnFunction() override {
196     SmallVector<ParallelOp, 2> innermostPloops;
197     getInnermostParallelLoops(getFunction().getOperation(), innermostPloops);
198     for (ParallelOp ploop : innermostPloops) {
199       // FIXME: Add reduction support.
200       if (ploop.getNumReductions() == 0)
201         tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
202     }
203   }
204 };
205 } // namespace
206 
207 std::unique_ptr<Pass>
208 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
209                                    bool noMinMaxBounds) {
210   return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
211 }
212