1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko 
13c25b20c0SAlex Zinenko #include "PassDetail.h"
14c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Passes.h"
15c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h"
16c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Transforms.h"
17c25b20c0SAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h"
18c25b20c0SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
19c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h"
20c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h"
21c25b20c0SAlex Zinenko 
22c25b20c0SAlex Zinenko using namespace mlir;
23c25b20c0SAlex Zinenko using namespace mlir::scf;
24c25b20c0SAlex Zinenko 
25c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps.
26c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) {
27c25b20c0SAlex Zinenko   auto walkResult =
28c25b20c0SAlex Zinenko       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29c25b20c0SAlex Zinenko   return walkResult.wasInterrupted();
30c25b20c0SAlex Zinenko }
31c25b20c0SAlex Zinenko 
32c25b20c0SAlex Zinenko /// Verify equal iteration spaces.
33c25b20c0SAlex Zinenko static bool equalIterationSpaces(ParallelOp firstPloop,
34c25b20c0SAlex Zinenko                                  ParallelOp secondPloop) {
35c25b20c0SAlex Zinenko   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36c25b20c0SAlex Zinenko     return false;
37c25b20c0SAlex Zinenko 
38c25b20c0SAlex Zinenko   auto matchOperands = [&](const OperandRange &lhs,
39c25b20c0SAlex Zinenko                            const OperandRange &rhs) -> bool {
40c25b20c0SAlex Zinenko     // TODO: Extend this to support aliases and equal constants.
41c25b20c0SAlex Zinenko     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42c25b20c0SAlex Zinenko   };
43c25b20c0SAlex Zinenko   return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
44c25b20c0SAlex Zinenko          matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
45c25b20c0SAlex Zinenko          matchOperands(firstPloop.step(), secondPloop.step());
46c25b20c0SAlex Zinenko }
47c25b20c0SAlex Zinenko 
48c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns
49c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second
50c25b20c0SAlex Zinenko /// loop reads.
51c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex(
52c25b20c0SAlex Zinenko     ParallelOp firstPloop, ParallelOp secondPloop,
53c25b20c0SAlex Zinenko     const BlockAndValueMapping &firstToSecondPloopIndices) {
54c25b20c0SAlex Zinenko   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
55c25b20c0SAlex Zinenko   firstPloop.getBody()->walk([&](StoreOp store) {
56c25b20c0SAlex Zinenko     bufferStores[store.getMemRef()].push_back(store.indices());
57c25b20c0SAlex Zinenko   });
58c25b20c0SAlex Zinenko   auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) {
59c25b20c0SAlex Zinenko     // Stop if the memref is defined in secondPloop body. Careful alias analysis
60c25b20c0SAlex Zinenko     // is needed.
61c25b20c0SAlex Zinenko     auto *memrefDef = load.getMemRef().getDefiningOp();
62c25b20c0SAlex Zinenko     if (memrefDef && memrefDef->getBlock() == load.getOperation()->getBlock())
63c25b20c0SAlex Zinenko       return WalkResult::interrupt();
64c25b20c0SAlex Zinenko 
65c25b20c0SAlex Zinenko     auto write = bufferStores.find(load.getMemRef());
66c25b20c0SAlex Zinenko     if (write == bufferStores.end())
67c25b20c0SAlex Zinenko       return WalkResult::advance();
68c25b20c0SAlex Zinenko 
69c25b20c0SAlex Zinenko     // Allow only single write access per buffer.
70c25b20c0SAlex Zinenko     if (write->second.size() != 1)
71c25b20c0SAlex Zinenko       return WalkResult::interrupt();
72c25b20c0SAlex Zinenko 
73c25b20c0SAlex Zinenko     // Check that the load indices of secondPloop coincide with store indices of
74c25b20c0SAlex Zinenko     // firstPloop for the same memrefs.
75c25b20c0SAlex Zinenko     auto storeIndices = write->second.front();
76c25b20c0SAlex Zinenko     auto loadIndices = load.indices();
77c25b20c0SAlex Zinenko     if (storeIndices.size() != loadIndices.size())
78c25b20c0SAlex Zinenko       return WalkResult::interrupt();
79c25b20c0SAlex Zinenko     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
80c25b20c0SAlex Zinenko       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
81c25b20c0SAlex Zinenko           loadIndices[i])
82c25b20c0SAlex Zinenko         return WalkResult::interrupt();
83c25b20c0SAlex Zinenko     }
84c25b20c0SAlex Zinenko     return WalkResult::advance();
85c25b20c0SAlex Zinenko   });
86c25b20c0SAlex Zinenko   return !walkResult.wasInterrupted();
87c25b20c0SAlex Zinenko }
88c25b20c0SAlex Zinenko 
89c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and
90c25b20c0SAlex Zinenko /// write patterns.
91c25b20c0SAlex Zinenko static LogicalResult
92c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
93c25b20c0SAlex Zinenko                    const BlockAndValueMapping &firstToSecondPloopIndices) {
94c25b20c0SAlex Zinenko   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
95c25b20c0SAlex Zinenko                                             firstToSecondPloopIndices))
96c25b20c0SAlex Zinenko     return failure();
97c25b20c0SAlex Zinenko 
98c25b20c0SAlex Zinenko   BlockAndValueMapping secondToFirstPloopIndices;
99c25b20c0SAlex Zinenko   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
100c25b20c0SAlex Zinenko                                 firstPloop.getBody()->getArguments());
101c25b20c0SAlex Zinenko   return success(haveNoReadsAfterWriteExceptSameIndex(
102c25b20c0SAlex Zinenko       secondPloop, firstPloop, secondToFirstPloopIndices));
103c25b20c0SAlex Zinenko }
104c25b20c0SAlex Zinenko 
105c25b20c0SAlex Zinenko static bool
106c25b20c0SAlex Zinenko isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
107c25b20c0SAlex Zinenko               const BlockAndValueMapping &firstToSecondPloopIndices) {
108c25b20c0SAlex Zinenko   return !hasNestedParallelOp(firstPloop) &&
109c25b20c0SAlex Zinenko          !hasNestedParallelOp(secondPloop) &&
110c25b20c0SAlex Zinenko          equalIterationSpaces(firstPloop, secondPloop) &&
111c25b20c0SAlex Zinenko          succeeded(verifyDependencies(firstPloop, secondPloop,
112c25b20c0SAlex Zinenko                                       firstToSecondPloopIndices));
113c25b20c0SAlex Zinenko }
114c25b20c0SAlex Zinenko 
115c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body.
116c25b20c0SAlex Zinenko static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
117c25b20c0SAlex Zinenko                         OpBuilder b) {
118c25b20c0SAlex Zinenko   BlockAndValueMapping firstToSecondPloopIndices;
119c25b20c0SAlex Zinenko   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
120c25b20c0SAlex Zinenko                                 secondPloop.getBody()->getArguments());
121c25b20c0SAlex Zinenko 
122c25b20c0SAlex Zinenko   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
123c25b20c0SAlex Zinenko     return;
124c25b20c0SAlex Zinenko 
125c25b20c0SAlex Zinenko   b.setInsertionPointToStart(secondPloop.getBody());
126c25b20c0SAlex Zinenko   for (auto &op : firstPloop.getBody()->without_terminator())
127c25b20c0SAlex Zinenko     b.clone(op, firstToSecondPloopIndices);
128c25b20c0SAlex Zinenko   firstPloop.erase();
129c25b20c0SAlex Zinenko }
130c25b20c0SAlex Zinenko 
131c25b20c0SAlex Zinenko void mlir::scf::naivelyFuseParallelOps(Region &region) {
132c25b20c0SAlex Zinenko   OpBuilder b(region);
133c25b20c0SAlex Zinenko   // Consider every single block and attempt to fuse adjacent loops.
134c25b20c0SAlex Zinenko   for (auto &block : region) {
135c25b20c0SAlex Zinenko     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
136c25b20c0SAlex Zinenko     // Not using `walk()` to traverse only top-level parallel loops and also
137c25b20c0SAlex Zinenko     // make sure that there are no side-effecting ops between the parallel
138c25b20c0SAlex Zinenko     // loops.
139c25b20c0SAlex Zinenko     bool noSideEffects = true;
140c25b20c0SAlex Zinenko     for (auto &op : block) {
141c25b20c0SAlex Zinenko       if (auto ploop = dyn_cast<ParallelOp>(op)) {
142c25b20c0SAlex Zinenko         if (noSideEffects) {
143c25b20c0SAlex Zinenko           ploopChains.back().push_back(ploop);
144c25b20c0SAlex Zinenko         } else {
145c25b20c0SAlex Zinenko           ploopChains.push_back({ploop});
146c25b20c0SAlex Zinenko           noSideEffects = true;
147c25b20c0SAlex Zinenko         }
148c25b20c0SAlex Zinenko         continue;
149c25b20c0SAlex Zinenko       }
150c25b20c0SAlex Zinenko       // TODO: Handle region side effects properly.
151c25b20c0SAlex Zinenko       noSideEffects &=
152c25b20c0SAlex Zinenko           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
153c25b20c0SAlex Zinenko     }
154c25b20c0SAlex Zinenko     for (ArrayRef<ParallelOp> ploops : ploopChains) {
155c25b20c0SAlex Zinenko       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
156c25b20c0SAlex Zinenko         fuseIfLegal(ploops[i], ploops[i + 1], b);
157c25b20c0SAlex Zinenko     }
158c25b20c0SAlex Zinenko   }
159c25b20c0SAlex Zinenko }
160c25b20c0SAlex Zinenko 
161c25b20c0SAlex Zinenko namespace {
162c25b20c0SAlex Zinenko struct ParallelLoopFusion
163*4bcd08ebSStephan Herhut     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
164c25b20c0SAlex Zinenko   void runOnOperation() override {
165c25b20c0SAlex Zinenko     getOperation()->walk([&](Operation *child) {
166c25b20c0SAlex Zinenko       for (Region &region : child->getRegions())
167c25b20c0SAlex Zinenko         naivelyFuseParallelOps(region);
168c25b20c0SAlex Zinenko     });
169c25b20c0SAlex Zinenko   }
170c25b20c0SAlex Zinenko };
171c25b20c0SAlex Zinenko } // namespace
172c25b20c0SAlex Zinenko 
173c25b20c0SAlex Zinenko std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
174c25b20c0SAlex Zinenko   return std::make_unique<ParallelLoopFusion>();
175c25b20c0SAlex Zinenko }
176