1 //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion on parallel loops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Passes.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/OpDefinition.h"
21
22 using namespace mlir;
23 using namespace mlir::scf;
24
25 /// Verify there are no nested ParallelOps.
hasNestedParallelOp(ParallelOp ploop)26 static bool hasNestedParallelOp(ParallelOp ploop) {
27 auto walkResult =
28 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29 return walkResult.wasInterrupted();
30 }
31
32 /// Verify equal iteration spaces.
equalIterationSpaces(ParallelOp firstPloop,ParallelOp secondPloop)33 static bool equalIterationSpaces(ParallelOp firstPloop,
34 ParallelOp secondPloop) {
35 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36 return false;
37
38 auto matchOperands = [&](const OperandRange &lhs,
39 const OperandRange &rhs) -> bool {
40 // TODO: Extend this to support aliases and equal constants.
41 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42 };
43 return matchOperands(firstPloop.getLowerBound(),
44 secondPloop.getLowerBound()) &&
45 matchOperands(firstPloop.getUpperBound(),
46 secondPloop.getUpperBound()) &&
47 matchOperands(firstPloop.getStep(), secondPloop.getStep());
48 }
49
50 /// Checks if the parallel loops have mixed access to the same buffers. Returns
51 /// `true` if the first parallel loop writes to the same indices that the second
52 /// loop reads.
haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)53 static bool haveNoReadsAfterWriteExceptSameIndex(
54 ParallelOp firstPloop, ParallelOp secondPloop,
55 const BlockAndValueMapping &firstToSecondPloopIndices) {
56 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
57 firstPloop.getBody()->walk([&](memref::StoreOp store) {
58 bufferStores[store.getMemRef()].push_back(store.getIndices());
59 });
60 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
61 // Stop if the memref is defined in secondPloop body. Careful alias analysis
62 // is needed.
63 auto *memrefDef = load.getMemRef().getDefiningOp();
64 if (memrefDef && memrefDef->getBlock() == load->getBlock())
65 return WalkResult::interrupt();
66
67 auto write = bufferStores.find(load.getMemRef());
68 if (write == bufferStores.end())
69 return WalkResult::advance();
70
71 // Allow only single write access per buffer.
72 if (write->second.size() != 1)
73 return WalkResult::interrupt();
74
75 // Check that the load indices of secondPloop coincide with store indices of
76 // firstPloop for the same memrefs.
77 auto storeIndices = write->second.front();
78 auto loadIndices = load.getIndices();
79 if (storeIndices.size() != loadIndices.size())
80 return WalkResult::interrupt();
81 for (int i = 0, e = storeIndices.size(); i < e; ++i) {
82 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
83 loadIndices[i])
84 return WalkResult::interrupt();
85 }
86 return WalkResult::advance();
87 });
88 return !walkResult.wasInterrupted();
89 }
90
91 /// Analyzes dependencies in the most primitive way by checking simple read and
92 /// write patterns.
93 static LogicalResult
verifyDependencies(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)94 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
95 const BlockAndValueMapping &firstToSecondPloopIndices) {
96 if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
97 firstToSecondPloopIndices))
98 return failure();
99
100 BlockAndValueMapping secondToFirstPloopIndices;
101 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
102 firstPloop.getBody()->getArguments());
103 return success(haveNoReadsAfterWriteExceptSameIndex(
104 secondPloop, firstPloop, secondToFirstPloopIndices));
105 }
106
107 static bool
isFusionLegal(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)108 isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
109 const BlockAndValueMapping &firstToSecondPloopIndices) {
110 return !hasNestedParallelOp(firstPloop) &&
111 !hasNestedParallelOp(secondPloop) &&
112 equalIterationSpaces(firstPloop, secondPloop) &&
113 succeeded(verifyDependencies(firstPloop, secondPloop,
114 firstToSecondPloopIndices));
115 }
116
117 /// Prepends operations of firstPloop's body into secondPloop's body.
fuseIfLegal(ParallelOp firstPloop,ParallelOp secondPloop,OpBuilder b)118 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
119 OpBuilder b) {
120 BlockAndValueMapping firstToSecondPloopIndices;
121 firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
122 secondPloop.getBody()->getArguments());
123
124 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
125 return;
126
127 b.setInsertionPointToStart(secondPloop.getBody());
128 for (auto &op : firstPloop.getBody()->without_terminator())
129 b.clone(op, firstToSecondPloopIndices);
130 firstPloop.erase();
131 }
132
naivelyFuseParallelOps(Region & region)133 void mlir::scf::naivelyFuseParallelOps(Region ®ion) {
134 OpBuilder b(region);
135 // Consider every single block and attempt to fuse adjacent loops.
136 for (auto &block : region) {
137 SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
138 // Not using `walk()` to traverse only top-level parallel loops and also
139 // make sure that there are no side-effecting ops between the parallel
140 // loops.
141 bool noSideEffects = true;
142 for (auto &op : block) {
143 if (auto ploop = dyn_cast<ParallelOp>(op)) {
144 if (noSideEffects) {
145 ploopChains.back().push_back(ploop);
146 } else {
147 ploopChains.push_back({ploop});
148 noSideEffects = true;
149 }
150 continue;
151 }
152 // TODO: Handle region side effects properly.
153 noSideEffects &=
154 MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
155 }
156 for (ArrayRef<ParallelOp> ploops : ploopChains) {
157 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
158 fuseIfLegal(ploops[i], ploops[i + 1], b);
159 }
160 }
161 }
162
163 namespace {
164 struct ParallelLoopFusion
165 : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
runOnOperation__anone73b14020511::ParallelLoopFusion166 void runOnOperation() override {
167 getOperation()->walk([&](Operation *child) {
168 for (Region ®ion : child->getRegions())
169 naivelyFuseParallelOps(region);
170 });
171 }
172 };
173 } // namespace
174
createParallelLoopFusionPass()175 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
176 return std::make_unique<ParallelLoopFusion>();
177 }
178