1c25b20c0SAlex Zinenko //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2c25b20c0SAlex Zinenko //
3c25b20c0SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c25b20c0SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5c25b20c0SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c25b20c0SAlex Zinenko //
7c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
8c25b20c0SAlex Zinenko //
9c25b20c0SAlex Zinenko // This file implements loop fusion on parallel loops.
10c25b20c0SAlex Zinenko //
11c25b20c0SAlex Zinenko //===----------------------------------------------------------------------===//
12c25b20c0SAlex Zinenko
13c25b20c0SAlex Zinenko #include "PassDetail.h"
14e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18c25b20c0SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
19c25b20c0SAlex Zinenko #include "mlir/IR/Builders.h"
20c25b20c0SAlex Zinenko #include "mlir/IR/OpDefinition.h"
21c25b20c0SAlex Zinenko
22c25b20c0SAlex Zinenko using namespace mlir;
23c25b20c0SAlex Zinenko using namespace mlir::scf;
24c25b20c0SAlex Zinenko
25c25b20c0SAlex Zinenko /// Verify there are no nested ParallelOps.
hasNestedParallelOp(ParallelOp ploop)26c25b20c0SAlex Zinenko static bool hasNestedParallelOp(ParallelOp ploop) {
27c25b20c0SAlex Zinenko auto walkResult =
28c25b20c0SAlex Zinenko ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
29c25b20c0SAlex Zinenko return walkResult.wasInterrupted();
30c25b20c0SAlex Zinenko }
31c25b20c0SAlex Zinenko
32c25b20c0SAlex Zinenko /// Verify equal iteration spaces.
equalIterationSpaces(ParallelOp firstPloop,ParallelOp secondPloop)33c25b20c0SAlex Zinenko static bool equalIterationSpaces(ParallelOp firstPloop,
34c25b20c0SAlex Zinenko ParallelOp secondPloop) {
35c25b20c0SAlex Zinenko if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
36c25b20c0SAlex Zinenko return false;
37c25b20c0SAlex Zinenko
38c25b20c0SAlex Zinenko auto matchOperands = [&](const OperandRange &lhs,
39c25b20c0SAlex Zinenko const OperandRange &rhs) -> bool {
40c25b20c0SAlex Zinenko // TODO: Extend this to support aliases and equal constants.
41c25b20c0SAlex Zinenko return std::equal(lhs.begin(), lhs.end(), rhs.begin());
42c25b20c0SAlex Zinenko };
43c0342a2dSJacques Pienaar return matchOperands(firstPloop.getLowerBound(),
44c0342a2dSJacques Pienaar secondPloop.getLowerBound()) &&
45c0342a2dSJacques Pienaar matchOperands(firstPloop.getUpperBound(),
46c0342a2dSJacques Pienaar secondPloop.getUpperBound()) &&
47c0342a2dSJacques Pienaar matchOperands(firstPloop.getStep(), secondPloop.getStep());
48c25b20c0SAlex Zinenko }
49c25b20c0SAlex Zinenko
50c25b20c0SAlex Zinenko /// Checks if the parallel loops have mixed access to the same buffers. Returns
51c25b20c0SAlex Zinenko /// `true` if the first parallel loop writes to the same indices that the second
52c25b20c0SAlex Zinenko /// loop reads.
haveNoReadsAfterWriteExceptSameIndex(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)53c25b20c0SAlex Zinenko static bool haveNoReadsAfterWriteExceptSameIndex(
54c25b20c0SAlex Zinenko ParallelOp firstPloop, ParallelOp secondPloop,
55c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) {
56c25b20c0SAlex Zinenko DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
57e2310704SJulian Gross firstPloop.getBody()->walk([&](memref::StoreOp store) {
58*136d746eSJacques Pienaar bufferStores[store.getMemRef()].push_back(store.getIndices());
59c25b20c0SAlex Zinenko });
60e2310704SJulian Gross auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
61c25b20c0SAlex Zinenko // Stop if the memref is defined in secondPloop body. Careful alias analysis
62c25b20c0SAlex Zinenko // is needed.
63c25b20c0SAlex Zinenko auto *memrefDef = load.getMemRef().getDefiningOp();
64c4a04059SChristian Sigg if (memrefDef && memrefDef->getBlock() == load->getBlock())
65c25b20c0SAlex Zinenko return WalkResult::interrupt();
66c25b20c0SAlex Zinenko
67c25b20c0SAlex Zinenko auto write = bufferStores.find(load.getMemRef());
68c25b20c0SAlex Zinenko if (write == bufferStores.end())
69c25b20c0SAlex Zinenko return WalkResult::advance();
70c25b20c0SAlex Zinenko
71c25b20c0SAlex Zinenko // Allow only single write access per buffer.
72c25b20c0SAlex Zinenko if (write->second.size() != 1)
73c25b20c0SAlex Zinenko return WalkResult::interrupt();
74c25b20c0SAlex Zinenko
75c25b20c0SAlex Zinenko // Check that the load indices of secondPloop coincide with store indices of
76c25b20c0SAlex Zinenko // firstPloop for the same memrefs.
77c25b20c0SAlex Zinenko auto storeIndices = write->second.front();
78*136d746eSJacques Pienaar auto loadIndices = load.getIndices();
79c25b20c0SAlex Zinenko if (storeIndices.size() != loadIndices.size())
80c25b20c0SAlex Zinenko return WalkResult::interrupt();
81c25b20c0SAlex Zinenko for (int i = 0, e = storeIndices.size(); i < e; ++i) {
82c25b20c0SAlex Zinenko if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
83c25b20c0SAlex Zinenko loadIndices[i])
84c25b20c0SAlex Zinenko return WalkResult::interrupt();
85c25b20c0SAlex Zinenko }
86c25b20c0SAlex Zinenko return WalkResult::advance();
87c25b20c0SAlex Zinenko });
88c25b20c0SAlex Zinenko return !walkResult.wasInterrupted();
89c25b20c0SAlex Zinenko }
90c25b20c0SAlex Zinenko
91c25b20c0SAlex Zinenko /// Analyzes dependencies in the most primitive way by checking simple read and
92c25b20c0SAlex Zinenko /// write patterns.
93c25b20c0SAlex Zinenko static LogicalResult
verifyDependencies(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)94c25b20c0SAlex Zinenko verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
95c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) {
96c25b20c0SAlex Zinenko if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
97c25b20c0SAlex Zinenko firstToSecondPloopIndices))
98c25b20c0SAlex Zinenko return failure();
99c25b20c0SAlex Zinenko
100c25b20c0SAlex Zinenko BlockAndValueMapping secondToFirstPloopIndices;
101c25b20c0SAlex Zinenko secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
102c25b20c0SAlex Zinenko firstPloop.getBody()->getArguments());
103c25b20c0SAlex Zinenko return success(haveNoReadsAfterWriteExceptSameIndex(
104c25b20c0SAlex Zinenko secondPloop, firstPloop, secondToFirstPloopIndices));
105c25b20c0SAlex Zinenko }
106c25b20c0SAlex Zinenko
107c25b20c0SAlex Zinenko static bool
isFusionLegal(ParallelOp firstPloop,ParallelOp secondPloop,const BlockAndValueMapping & firstToSecondPloopIndices)108c25b20c0SAlex Zinenko isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
109c25b20c0SAlex Zinenko const BlockAndValueMapping &firstToSecondPloopIndices) {
110c25b20c0SAlex Zinenko return !hasNestedParallelOp(firstPloop) &&
111c25b20c0SAlex Zinenko !hasNestedParallelOp(secondPloop) &&
112c25b20c0SAlex Zinenko equalIterationSpaces(firstPloop, secondPloop) &&
113c25b20c0SAlex Zinenko succeeded(verifyDependencies(firstPloop, secondPloop,
114c25b20c0SAlex Zinenko firstToSecondPloopIndices));
115c25b20c0SAlex Zinenko }
116c25b20c0SAlex Zinenko
117c25b20c0SAlex Zinenko /// Prepends operations of firstPloop's body into secondPloop's body.
fuseIfLegal(ParallelOp firstPloop,ParallelOp secondPloop,OpBuilder b)118c25b20c0SAlex Zinenko static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
119c25b20c0SAlex Zinenko OpBuilder b) {
120c25b20c0SAlex Zinenko BlockAndValueMapping firstToSecondPloopIndices;
121c25b20c0SAlex Zinenko firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
122c25b20c0SAlex Zinenko secondPloop.getBody()->getArguments());
123c25b20c0SAlex Zinenko
124c25b20c0SAlex Zinenko if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
125c25b20c0SAlex Zinenko return;
126c25b20c0SAlex Zinenko
127c25b20c0SAlex Zinenko b.setInsertionPointToStart(secondPloop.getBody());
128c25b20c0SAlex Zinenko for (auto &op : firstPloop.getBody()->without_terminator())
129c25b20c0SAlex Zinenko b.clone(op, firstToSecondPloopIndices);
130c25b20c0SAlex Zinenko firstPloop.erase();
131c25b20c0SAlex Zinenko }
132c25b20c0SAlex Zinenko
naivelyFuseParallelOps(Region & region)133c25b20c0SAlex Zinenko void mlir::scf::naivelyFuseParallelOps(Region ®ion) {
134c25b20c0SAlex Zinenko OpBuilder b(region);
135c25b20c0SAlex Zinenko // Consider every single block and attempt to fuse adjacent loops.
136c25b20c0SAlex Zinenko for (auto &block : region) {
137c25b20c0SAlex Zinenko SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
138c25b20c0SAlex Zinenko // Not using `walk()` to traverse only top-level parallel loops and also
139c25b20c0SAlex Zinenko // make sure that there are no side-effecting ops between the parallel
140c25b20c0SAlex Zinenko // loops.
141c25b20c0SAlex Zinenko bool noSideEffects = true;
142c25b20c0SAlex Zinenko for (auto &op : block) {
143c25b20c0SAlex Zinenko if (auto ploop = dyn_cast<ParallelOp>(op)) {
144c25b20c0SAlex Zinenko if (noSideEffects) {
145c25b20c0SAlex Zinenko ploopChains.back().push_back(ploop);
146c25b20c0SAlex Zinenko } else {
147c25b20c0SAlex Zinenko ploopChains.push_back({ploop});
148c25b20c0SAlex Zinenko noSideEffects = true;
149c25b20c0SAlex Zinenko }
150c25b20c0SAlex Zinenko continue;
151c25b20c0SAlex Zinenko }
152c25b20c0SAlex Zinenko // TODO: Handle region side effects properly.
153c25b20c0SAlex Zinenko noSideEffects &=
154c25b20c0SAlex Zinenko MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
155c25b20c0SAlex Zinenko }
156c25b20c0SAlex Zinenko for (ArrayRef<ParallelOp> ploops : ploopChains) {
157c25b20c0SAlex Zinenko for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
158c25b20c0SAlex Zinenko fuseIfLegal(ploops[i], ploops[i + 1], b);
159c25b20c0SAlex Zinenko }
160c25b20c0SAlex Zinenko }
161c25b20c0SAlex Zinenko }
162c25b20c0SAlex Zinenko
163c25b20c0SAlex Zinenko namespace {
164c25b20c0SAlex Zinenko struct ParallelLoopFusion
1654bcd08ebSStephan Herhut : public SCFParallelLoopFusionBase<ParallelLoopFusion> {
runOnOperation__anone73b14020511::ParallelLoopFusion166c25b20c0SAlex Zinenko void runOnOperation() override {
167c25b20c0SAlex Zinenko getOperation()->walk([&](Operation *child) {
168c25b20c0SAlex Zinenko for (Region ®ion : child->getRegions())
169c25b20c0SAlex Zinenko naivelyFuseParallelOps(region);
170c25b20c0SAlex Zinenko });
171c25b20c0SAlex Zinenko }
172c25b20c0SAlex Zinenko };
173c25b20c0SAlex Zinenko } // namespace
174c25b20c0SAlex Zinenko
createParallelLoopFusionPass()175c25b20c0SAlex Zinenko std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
176c25b20c0SAlex Zinenko return std::make_unique<ParallelLoopFusion>();
177c25b20c0SAlex Zinenko }
178