1a70aa7bbSRiver Riddle //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle 
9a70aa7bbSRiver Riddle #include "PassDetail.h"
10a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
11a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
12a70aa7bbSRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
14f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
15a70aa7bbSRiver Riddle #include "mlir/Transforms/Passes.h"
16a70aa7bbSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
17a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
18a70aa7bbSRiver Riddle 
19a70aa7bbSRiver Riddle #define PASS_NAME "loop-coalescing"
20a70aa7bbSRiver Riddle #define DEBUG_TYPE PASS_NAME
21a70aa7bbSRiver Riddle 
22a70aa7bbSRiver Riddle using namespace mlir;
23a70aa7bbSRiver Riddle 
24a70aa7bbSRiver Riddle namespace {
25a70aa7bbSRiver Riddle struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> {
26a70aa7bbSRiver Riddle 
27a70aa7bbSRiver Riddle   /// Walk either an scf.for or an affine.for to find a band to coalesce.
28a70aa7bbSRiver Riddle   template <typename LoopOpTy>
walkLoop__anon72a84d460111::LoopCoalescingPass29a70aa7bbSRiver Riddle   static void walkLoop(LoopOpTy op) {
30a70aa7bbSRiver Riddle     // Ignore nested loops.
31a70aa7bbSRiver Riddle     if (op->template getParentOfType<LoopOpTy>())
32a70aa7bbSRiver Riddle       return;
33a70aa7bbSRiver Riddle 
34a70aa7bbSRiver Riddle     SmallVector<LoopOpTy, 4> loops;
35a70aa7bbSRiver Riddle     getPerfectlyNestedLoops(loops, op);
36a70aa7bbSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
37a70aa7bbSRiver Riddle                << "found a perfect nest of depth " << loops.size() << '\n');
38a70aa7bbSRiver Riddle 
39a70aa7bbSRiver Riddle     // Look for a band of loops that can be coalesced, i.e. perfectly nested
40a70aa7bbSRiver Riddle     // loops with bounds defined above some loop.
41a70aa7bbSRiver Riddle     // 1. For each loop, find above which parent loop its operands are
42a70aa7bbSRiver Riddle     // defined.
43a70aa7bbSRiver Riddle     SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
44a70aa7bbSRiver Riddle     for (unsigned i = 0, e = loops.size(); i < e; ++i) {
45a70aa7bbSRiver Riddle       operandsDefinedAbove[i] = i;
46a70aa7bbSRiver Riddle       for (unsigned j = 0; j < i; ++j) {
47a70aa7bbSRiver Riddle         if (areValuesDefinedAbove(loops[i].getOperands(),
48a70aa7bbSRiver Riddle                                   loops[j].getRegion())) {
49a70aa7bbSRiver Riddle           operandsDefinedAbove[i] = j;
50a70aa7bbSRiver Riddle           break;
51a70aa7bbSRiver Riddle         }
52a70aa7bbSRiver Riddle       }
53a70aa7bbSRiver Riddle       LLVM_DEBUG(llvm::dbgs()
54a70aa7bbSRiver Riddle                  << "  bounds of loop " << i << " are known above depth "
55a70aa7bbSRiver Riddle                  << operandsDefinedAbove[i] << '\n');
56a70aa7bbSRiver Riddle     }
57a70aa7bbSRiver Riddle 
58a70aa7bbSRiver Riddle     // 2. Identify bands of loops such that the operands of all of them are
59a70aa7bbSRiver Riddle     // defined above the first loop in the band.  Traverse the nest bottom-up
60a70aa7bbSRiver Riddle     // so that modifications don't invalidate the inner loops.
61a70aa7bbSRiver Riddle     for (unsigned end = loops.size(); end > 0; --end) {
62a70aa7bbSRiver Riddle       unsigned start = 0;
63a70aa7bbSRiver Riddle       for (; start < end - 1; ++start) {
64a70aa7bbSRiver Riddle         auto maxPos =
65a70aa7bbSRiver Riddle             *std::max_element(std::next(operandsDefinedAbove.begin(), start),
66a70aa7bbSRiver Riddle                               std::next(operandsDefinedAbove.begin(), end));
67a70aa7bbSRiver Riddle         if (maxPos > start)
68a70aa7bbSRiver Riddle           continue;
69a70aa7bbSRiver Riddle 
70a70aa7bbSRiver Riddle         assert(maxPos == start &&
71a70aa7bbSRiver Riddle                "expected loop bounds to be known at the start of the band");
72a70aa7bbSRiver Riddle         LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
73a70aa7bbSRiver Riddle                                 << " to " << end << '\n');
74a70aa7bbSRiver Riddle 
75a70aa7bbSRiver Riddle         auto band =
76a70aa7bbSRiver Riddle             llvm::makeMutableArrayRef(loops.data() + start, end - start);
77a70aa7bbSRiver Riddle         (void)coalesceLoops(band);
78a70aa7bbSRiver Riddle         break;
79a70aa7bbSRiver Riddle       }
80a70aa7bbSRiver Riddle       // If a band was found and transformed, keep looking at the loops above
81a70aa7bbSRiver Riddle       // the outermost transformed loop.
82a70aa7bbSRiver Riddle       if (start != end - 1)
83a70aa7bbSRiver Riddle         end = start + 1;
84a70aa7bbSRiver Riddle     }
85a70aa7bbSRiver Riddle   }
86a70aa7bbSRiver Riddle 
runOnOperation__anon72a84d460111::LoopCoalescingPass87a70aa7bbSRiver Riddle   void runOnOperation() override {
8858ceae95SRiver Riddle     func::FuncOp func = getOperation();
89a70aa7bbSRiver Riddle     func.walk([&](Operation *op) {
90a70aa7bbSRiver Riddle       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
91a70aa7bbSRiver Riddle         walkLoop(scfForOp);
92a70aa7bbSRiver Riddle       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
93a70aa7bbSRiver Riddle         walkLoop(affineForOp);
94a70aa7bbSRiver Riddle     });
95a70aa7bbSRiver Riddle   }
96a70aa7bbSRiver Riddle };
97a70aa7bbSRiver Riddle 
98a70aa7bbSRiver Riddle } // namespace
99a70aa7bbSRiver Riddle 
createLoopCoalescingPass()10058ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopCoalescingPass() {
101a70aa7bbSRiver Riddle   return std::make_unique<LoopCoalescingPass>();
102a70aa7bbSRiver Riddle }
103