//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion on parallel loops.
//
//===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko 
13c25b20c0SAlex Zinenko #include "PassDetail.h"
14*e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
15c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Passes.h"
16c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/SCF.h"
17c25b20c0SAlex Zinenko #include "mlir/Dialect/SCF/Transforms.h"
18c25b20c0SAlex Zinenko #include "mlir/Dialect/StandardOps/IR/Ops.h"
19c25b20c0SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
20c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h"
21c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h"
22c25b20c0SAlex Zinenko 
23c25b20c0SAlex Zinenko using namespace mlir;
24c25b20c0SAlex Zinenko using namespace mlir::scf;
25c25b20c0SAlex Zinenko 
26c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps.
27c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) {
28c25b20c0SAlex Zinenko   auto walkResult =
29c25b20c0SAlex Zinenko       ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
30c25b20c0SAlex Zinenko   return walkResult.wasInterrupted();
31c25b20c0SAlex Zinenko }
32c25b20c0SAlex Zinenko 
33c25b20c0SAlex Zinenko /// Verify equal iteration spaces.
34c25b20c0SAlex Zinenko static bool equalIterationSpaces(ParallelOp firstPloop,
35c25b20c0SAlex Zinenko                                  ParallelOp secondPloop) {
36c25b20c0SAlex Zinenko   if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
37c25b20c0SAlex Zinenko     return false;
38c25b20c0SAlex Zinenko 
39c25b20c0SAlex Zinenko   auto matchOperands = [&](const OperandRange &lhs,
40c25b20c0SAlex Zinenko                            const OperandRange &rhs) -> bool {
41c25b20c0SAlex Zinenko     // TODO: Extend this to support aliases and equal constants.
42c25b20c0SAlex Zinenko     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
43c25b20c0SAlex Zinenko   };
44c25b20c0SAlex Zinenko   return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
45c25b20c0SAlex Zinenko          matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
46c25b20c0SAlex Zinenko          matchOperands(firstPloop.step(), secondPloop.step());
47c25b20c0SAlex Zinenko }
48c25b20c0SAlex Zinenko 
49c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns
50c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second
51c25b20c0SAlex Zinenko /// loop reads.
52c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex(
53c25b20c0SAlex Zinenko     ParallelOp firstPloop, ParallelOp secondPloop,
54c25b20c0SAlex Zinenko     const BlockAndValueMapping &firstToSecondPloopIndices) {
55c25b20c0SAlex Zinenko   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
56*e2310704SJulian Gross   firstPloop.getBody()->walk([&](memref::StoreOp store) {
57c25b20c0SAlex Zinenko     bufferStores[store.getMemRef()].push_back(store.indices());
58c25b20c0SAlex Zinenko   });
59*e2310704SJulian Gross   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
60c25b20c0SAlex Zinenko     // Stop if the memref is defined in secondPloop body. Careful alias analysis
61c25b20c0SAlex Zinenko     // is needed.
62c25b20c0SAlex Zinenko     auto *memrefDef = load.getMemRef().getDefiningOp();
63c4a04059SChristian Sigg     if (memrefDef && memrefDef->getBlock() == load->getBlock())
64c25b20c0SAlex Zinenko       return WalkResult::interrupt();
65c25b20c0SAlex Zinenko 
66c25b20c0SAlex Zinenko     auto write = bufferStores.find(load.getMemRef());
67c25b20c0SAlex Zinenko     if (write == bufferStores.end())
68c25b20c0SAlex Zinenko       return WalkResult::advance();
69c25b20c0SAlex Zinenko 
70c25b20c0SAlex Zinenko     // Allow only single write access per buffer.
71c25b20c0SAlex Zinenko     if (write->second.size() != 1)
72c25b20c0SAlex Zinenko       return WalkResult::interrupt();
73c25b20c0SAlex Zinenko 
74c25b20c0SAlex Zinenko     // Check that the load indices of secondPloop coincide with store indices of
75c25b20c0SAlex Zinenko     // firstPloop for the same memrefs.
76c25b20c0SAlex Zinenko     auto storeIndices = write->second.front();
77c25b20c0SAlex Zinenko     auto loadIndices = load.indices();
78c25b20c0SAlex Zinenko     if (storeIndices.size() != loadIndices.size())
79c25b20c0SAlex Zinenko       return WalkResult::interrupt();
80c25b20c0SAlex Zinenko     for (int i = 0, e = storeIndices.size(); i < e; ++i) {
81c25b20c0SAlex Zinenko       if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
82c25b20c0SAlex Zinenko           loadIndices[i])
83c25b20c0SAlex Zinenko         return WalkResult::interrupt();
84c25b20c0SAlex Zinenko     }
85c25b20c0SAlex Zinenko     return WalkResult::advance();
86c25b20c0SAlex Zinenko   });
87c25b20c0SAlex Zinenko   return !walkResult.wasInterrupted();
88c25b20c0SAlex Zinenko }
89c25b20c0SAlex Zinenko 
90c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and
91c25b20c0SAlex Zinenko /// write patterns.
92c25b20c0SAlex Zinenko static LogicalResult
93c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
94c25b20c0SAlex Zinenko                    const BlockAndValueMapping &firstToSecondPloopIndices) {
95c25b20c0SAlex Zinenko   if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
96c25b20c0SAlex Zinenko                                             firstToSecondPloopIndices))
97c25b20c0SAlex Zinenko     return failure();
98c25b20c0SAlex Zinenko 
99c25b20c0SAlex Zinenko   BlockAndValueMapping secondToFirstPloopIndices;
100c25b20c0SAlex Zinenko   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
101c25b20c0SAlex Zinenko                                 firstPloop.getBody()->getArguments());
102c25b20c0SAlex Zinenko   return success(haveNoReadsAfterWriteExceptSameIndex(
103c25b20c0SAlex Zinenko       secondPloop, firstPloop, secondToFirstPloopIndices));
104c25b20c0SAlex Zinenko }
105c25b20c0SAlex Zinenko 
106c25b20c0SAlex Zinenko static bool
107c25b20c0SAlex Zinenko isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
108c25b20c0SAlex Zinenko               const BlockAndValueMapping &firstToSecondPloopIndices) {
109c25b20c0SAlex Zinenko   return !hasNestedParallelOp(firstPloop) &&
110c25b20c0SAlex Zinenko          !hasNestedParallelOp(secondPloop) &&
111c25b20c0SAlex Zinenko          equalIterationSpaces(firstPloop, secondPloop) &&
112c25b20c0SAlex Zinenko          succeeded(verifyDependencies(firstPloop, secondPloop,
113c25b20c0SAlex Zinenko                                       firstToSecondPloopIndices));
114c25b20c0SAlex Zinenko }
115c25b20c0SAlex Zinenko 
116c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body.
117c25b20c0SAlex Zinenko static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
118c25b20c0SAlex Zinenko                         OpBuilder b) {
119c25b20c0SAlex Zinenko   BlockAndValueMapping firstToSecondPloopIndices;
120c25b20c0SAlex Zinenko   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
121c25b20c0SAlex Zinenko                                 secondPloop.getBody()->getArguments());
122c25b20c0SAlex Zinenko 
123c25b20c0SAlex Zinenko   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
124c25b20c0SAlex Zinenko     return;
125c25b20c0SAlex Zinenko 
126c25b20c0SAlex Zinenko   b.setInsertionPointToStart(secondPloop.getBody());
127c25b20c0SAlex Zinenko   for (auto &op : firstPloop.getBody()->without_terminator())
128c25b20c0SAlex Zinenko     b.clone(op, firstToSecondPloopIndices);
129c25b20c0SAlex Zinenko   firstPloop.erase();
130c25b20c0SAlex Zinenko }
131c25b20c0SAlex Zinenko 
132c25b20c0SAlex Zinenko void mlir::scf::naivelyFuseParallelOps(Region &region) {
133c25b20c0SAlex Zinenko   OpBuilder b(region);
134c25b20c0SAlex Zinenko   // Consider every single block and attempt to fuse adjacent loops.
135c25b20c0SAlex Zinenko   for (auto &block : region) {
136c25b20c0SAlex Zinenko     SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
137c25b20c0SAlex Zinenko     // Not using `walk()` to traverse only top-level parallel loops and also
138c25b20c0SAlex Zinenko     // make sure that there are no side-effecting ops between the parallel
139c25b20c0SAlex Zinenko     // loops.
140c25b20c0SAlex Zinenko     bool noSideEffects = true;
141c25b20c0SAlex Zinenko     for (auto &op : block) {
142c25b20c0SAlex Zinenko       if (auto ploop = dyn_cast<ParallelOp>(op)) {
143c25b20c0SAlex Zinenko         if (noSideEffects) {
144c25b20c0SAlex Zinenko           ploopChains.back().push_back(ploop);
145c25b20c0SAlex Zinenko         } else {
146c25b20c0SAlex Zinenko           ploopChains.push_back({ploop});
147c25b20c0SAlex Zinenko           noSideEffects = true;
148c25b20c0SAlex Zinenko         }
149c25b20c0SAlex Zinenko         continue;
150c25b20c0SAlex Zinenko       }
151c25b20c0SAlex Zinenko       // TODO: Handle region side effects properly.
152c25b20c0SAlex Zinenko       noSideEffects &=
153c25b20c0SAlex Zinenko           MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
154c25b20c0SAlex Zinenko     }
155c25b20c0SAlex Zinenko     for (ArrayRef<ParallelOp> ploops : ploopChains) {
156c25b20c0SAlex Zinenko       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
157c25b20c0SAlex Zinenko         fuseIfLegal(ploops[i], ploops[i + 1], b);
158c25b20c0SAlex Zinenko     }
159c25b20c0SAlex Zinenko   }
160c25b20c0SAlex Zinenko }
161c25b20c0SAlex Zinenko 
162c25b20c0SAlex Zinenko namespace {
163c25b20c0SAlex Zinenko struct ParallelLoopFusion
1644bcd08ebSStephan Herhut     : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
165c25b20c0SAlex Zinenko   void runOnOperation() override {
166c25b20c0SAlex Zinenko     getOperation()->walk([&](Operation *child) {
167c25b20c0SAlex Zinenko       for (Region &region : child->getRegions())
168c25b20c0SAlex Zinenko         naivelyFuseParallelOps(region);
169c25b20c0SAlex Zinenko     });
170c25b20c0SAlex Zinenko   }
171c25b20c0SAlex Zinenko };
172c25b20c0SAlex Zinenko } // namespace
173c25b20c0SAlex Zinenko 
174c25b20c0SAlex Zinenko std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
175c25b20c0SAlex Zinenko   return std::make_unique<ParallelLoopFusion>();
176c25b20c0SAlex Zinenko }
177