1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko 
13c25b20c0SAlex Zinenko #include "PassDetail.h"
14e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
15c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Passes.h"
16c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h"
17c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Transforms.h"
18c25b20c0SAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h"
19c25b20c0SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
20c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h"
21c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h"
22c25b20c0SAlex Zinenko 
23c25b20c0SAlex Zinenko using namespace mlir;
24c25b20c0SAlex Zinenko using namespace mlir::scf;
25c25b20c0SAlex Zinenko 
26c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps.
27c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) {
28c25b20c0SAlex Zinenko   auto walkResult =
29c25b20c0SAlex Zinenko       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
30c25b20c0SAlex Zinenko   return walkResult.wasInterrupted();
31c25b20c0SAlex Zinenko }
32c25b20c0SAlex Zinenko 
33c25b20c0SAlex Zinenko /// Verify equal iteration spaces.
34c25b20c0SAlex Zinenko static bool equalIterationSpaces(ParallelOp firstPloop,
35c25b20c0SAlex Zinenko                                  ParallelOp secondPloop) {
36c25b20c0SAlex Zinenko   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
37c25b20c0SAlex Zinenko     return false;
38c25b20c0SAlex Zinenko 
39c25b20c0SAlex Zinenko   auto matchOperands = [&](const OperandRange &lhs,
40c25b20c0SAlex Zinenko                            const OperandRange &rhs) -> bool {
41c25b20c0SAlex Zinenko     // TODO: Extend this to support aliases and equal constants.
42c25b20c0SAlex Zinenko     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
43c25b20c0SAlex Zinenko   };
44*c0342a2dSJacques Pienaar   return matchOperands(firstPloop.getLowerBound(),
45*c0342a2dSJacques Pienaar                        secondPloop.getLowerBound()) &&
46*c0342a2dSJacques Pienaar          matchOperands(firstPloop.getUpperBound(),
47*c0342a2dSJacques Pienaar                        secondPloop.getUpperBound()) &&
48*c0342a2dSJacques Pienaar          matchOperands(firstPloop.getStep(), secondPloop.getStep());
49c25b20c0SAlex Zinenko }
50c25b20c0SAlex Zinenko 
51c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns
52c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second
53c25b20c0SAlex Zinenko /// loop reads.
54c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex(
55c25b20c0SAlex Zinenko     ParallelOp firstPloop, ParallelOp secondPloop,
56c25b20c0SAlex Zinenko     const BlockAndValueMapping &firstToSecondPloopIndices) {
57c25b20c0SAlex Zinenko   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
58e2310704SJulian Gross   firstPloop.getBody()->walk([&](memref::StoreOp store) {
59c25b20c0SAlex Zinenko     bufferStores[store.getMemRef()].push_back(store.indices());
60c25b20c0SAlex Zinenko   });
61e2310704SJulian Gross   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
62c25b20c0SAlex Zinenko     // Stop if the memref is defined in secondPloop body. Careful alias analysis
63c25b20c0SAlex Zinenko     // is needed.
64c25b20c0SAlex Zinenko     auto *memrefDef = load.getMemRef().getDefiningOp();
65c4a04059SChristian Sigg     if (memrefDef && memrefDef->getBlock() == load->getBlock())
66c25b20c0SAlex Zinenko       return WalkResult::interrupt();
67c25b20c0SAlex Zinenko 
68c25b20c0SAlex Zinenko     auto write = bufferStores.find(load.getMemRef());
69c25b20c0SAlex Zinenko     if (write == bufferStores.end())
70c25b20c0SAlex Zinenko       return WalkResult::advance();
71c25b20c0SAlex Zinenko 
72c25b20c0SAlex Zinenko     // Allow only single write access per buffer.
73c25b20c0SAlex Zinenko     if (write->second.size() != 1)
74c25b20c0SAlex Zinenko       return WalkResult::interrupt();
75c25b20c0SAlex Zinenko 
76c25b20c0SAlex Zinenko     // Check that the load indices of secondPloop coincide with store indices of
77c25b20c0SAlex Zinenko     // firstPloop for the same memrefs.
78c25b20c0SAlex Zinenko     auto storeIndices = write->second.front();
79c25b20c0SAlex Zinenko     auto loadIndices = load.indices();
80c25b20c0SAlex Zinenko     if (storeIndices.size() != loadIndices.size())
81c25b20c0SAlex Zinenko       return WalkResult::interrupt();
82c25b20c0SAlex Zinenko     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
83c25b20c0SAlex Zinenko       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
84c25b20c0SAlex Zinenko           loadIndices[i])
85c25b20c0SAlex Zinenko         return WalkResult::interrupt();
86c25b20c0SAlex Zinenko     }
87c25b20c0SAlex Zinenko     return WalkResult::advance();
88c25b20c0SAlex Zinenko   });
89c25b20c0SAlex Zinenko   return !walkResult.wasInterrupted();
90c25b20c0SAlex Zinenko }
91c25b20c0SAlex Zinenko 
92c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and
93c25b20c0SAlex Zinenko /// write patterns.
94c25b20c0SAlex Zinenko static LogicalResult
95c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
96c25b20c0SAlex Zinenko                    const BlockAndValueMapping &firstToSecondPloopIndices) {
97c25b20c0SAlex Zinenko   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
98c25b20c0SAlex Zinenko                                             firstToSecondPloopIndices))
99c25b20c0SAlex Zinenko     return failure();
100c25b20c0SAlex Zinenko 
101c25b20c0SAlex Zinenko   BlockAndValueMapping secondToFirstPloopIndices;
102c25b20c0SAlex Zinenko   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
103c25b20c0SAlex Zinenko                                 firstPloop.getBody()->getArguments());
104c25b20c0SAlex Zinenko   return success(haveNoReadsAfterWriteExceptSameIndex(
105c25b20c0SAlex Zinenko       secondPloop, firstPloop, secondToFirstPloopIndices));
106c25b20c0SAlex Zinenko }
107c25b20c0SAlex Zinenko 
108c25b20c0SAlex Zinenko static bool
109c25b20c0SAlex Zinenko isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
110c25b20c0SAlex Zinenko               const BlockAndValueMapping &firstToSecondPloopIndices) {
111c25b20c0SAlex Zinenko   return !hasNestedParallelOp(firstPloop) &&
112c25b20c0SAlex Zinenko          !hasNestedParallelOp(secondPloop) &&
113c25b20c0SAlex Zinenko          equalIterationSpaces(firstPloop, secondPloop) &&
114c25b20c0SAlex Zinenko          succeeded(verifyDependencies(firstPloop, secondPloop,
115c25b20c0SAlex Zinenko                                       firstToSecondPloopIndices));
116c25b20c0SAlex Zinenko }
117c25b20c0SAlex Zinenko 
118c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body.
119c25b20c0SAlex Zinenko static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
120c25b20c0SAlex Zinenko                         OpBuilder b) {
121c25b20c0SAlex Zinenko   BlockAndValueMapping firstToSecondPloopIndices;
122c25b20c0SAlex Zinenko   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
123c25b20c0SAlex Zinenko                                 secondPloop.getBody()->getArguments());
124c25b20c0SAlex Zinenko 
125c25b20c0SAlex Zinenko   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
126c25b20c0SAlex Zinenko     return;
127c25b20c0SAlex Zinenko 
128c25b20c0SAlex Zinenko   b.setInsertionPointToStart(secondPloop.getBody());
129c25b20c0SAlex Zinenko   for (auto &op : firstPloop.getBody()->without_terminator())
130c25b20c0SAlex Zinenko     b.clone(op, firstToSecondPloopIndices);
131c25b20c0SAlex Zinenko   firstPloop.erase();
132c25b20c0SAlex Zinenko }
133c25b20c0SAlex Zinenko 
134c25b20c0SAlex Zinenko void mlir::scf::naivelyFuseParallelOps(Region &region) {
135c25b20c0SAlex Zinenko   OpBuilder b(region);
136c25b20c0SAlex Zinenko   // Consider every single block and attempt to fuse adjacent loops.
137c25b20c0SAlex Zinenko   for (auto &block : region) {
138c25b20c0SAlex Zinenko     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
139c25b20c0SAlex Zinenko     // Not using `walk()` to traverse only top-level parallel loops and also
140c25b20c0SAlex Zinenko     // make sure that there are no side-effecting ops between the parallel
141c25b20c0SAlex Zinenko     // loops.
142c25b20c0SAlex Zinenko     bool noSideEffects = true;
143c25b20c0SAlex Zinenko     for (auto &op : block) {
144c25b20c0SAlex Zinenko       if (auto ploop = dyn_cast<ParallelOp>(op)) {
145c25b20c0SAlex Zinenko         if (noSideEffects) {
146c25b20c0SAlex Zinenko           ploopChains.back().push_back(ploop);
147c25b20c0SAlex Zinenko         } else {
148c25b20c0SAlex Zinenko           ploopChains.push_back({ploop});
149c25b20c0SAlex Zinenko           noSideEffects = true;
150c25b20c0SAlex Zinenko         }
151c25b20c0SAlex Zinenko         continue;
152c25b20c0SAlex Zinenko       }
153c25b20c0SAlex Zinenko       // TODO: Handle region side effects properly.
154c25b20c0SAlex Zinenko       noSideEffects &=
155c25b20c0SAlex Zinenko           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
156c25b20c0SAlex Zinenko     }
157c25b20c0SAlex Zinenko     for (ArrayRef<ParallelOp> ploops : ploopChains) {
158c25b20c0SAlex Zinenko       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
159c25b20c0SAlex Zinenko         fuseIfLegal(ploops[i], ploops[i + 1], b);
160c25b20c0SAlex Zinenko     }
161c25b20c0SAlex Zinenko   }
162c25b20c0SAlex Zinenko }
163c25b20c0SAlex Zinenko 
164c25b20c0SAlex Zinenko namespace {
165c25b20c0SAlex Zinenko struct ParallelLoopFusion
1664bcd08ebSStephan Herhut     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
167c25b20c0SAlex Zinenko   void runOnOperation() override {
168c25b20c0SAlex Zinenko     getOperation()->walk([&](Operation *child) {
169c25b20c0SAlex Zinenko       for (Region &region : child->getRegions())
170c25b20c0SAlex Zinenko         naivelyFuseParallelOps(region);
171c25b20c0SAlex Zinenko     });
172c25b20c0SAlex Zinenko   }
173c25b20c0SAlex Zinenko };
174c25b20c0SAlex Zinenko } // namespace
175c25b20c0SAlex Zinenko 
176c25b20c0SAlex Zinenko std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
177c25b20c0SAlex Zinenko   return std::make_unique<ParallelLoopFusion>();
178c25b20c0SAlex Zinenko }
179