1 //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Dominance.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Transforms/Passes.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21
22 #define DEBUG_TYPE "fir-memref-dataflow-opt"
23
24 using namespace mlir;
25
26 namespace {
27
28 template <typename OpT>
getSpecificUsers(mlir::Value v)29 static std::vector<OpT> getSpecificUsers(mlir::Value v) {
30 std::vector<OpT> ops;
31 for (mlir::Operation *user : v.getUsers())
32 if (auto op = dyn_cast<OpT>(user))
33 ops.push_back(op);
34 return ops;
35 }
36
37 /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
38 /// and AffineWrite interface
39 template <typename ReadOp, typename WriteOp>
40 class LoadStoreForwarding {
41 public:
LoadStoreForwarding(mlir::DominanceInfo * di)42 LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
43
44 // FIXME: This algorithm has a bug. It ignores escaping references between a
45 // store and a load.
findStoreToForward(ReadOp loadOp,std::vector<WriteOp> && storeOps)46 llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
47 std::vector<WriteOp> &&storeOps) {
48 llvm::SmallVector<WriteOp> candidateSet;
49
50 for (auto storeOp : storeOps)
51 if (domInfo->dominates(storeOp, loadOp))
52 candidateSet.push_back(storeOp);
53
54 if (candidateSet.empty())
55 return {};
56
57 llvm::Optional<WriteOp> nearestStore;
58 for (auto candidate : candidateSet) {
59 auto nearerThan = [&](WriteOp otherStore) {
60 if (candidate == otherStore)
61 return false;
62 bool rv = domInfo->properlyDominates(candidate, otherStore);
63 if (rv) {
64 LLVM_DEBUG(llvm::dbgs()
65 << "candidate " << candidate << " is not the nearest to "
66 << loadOp << " because " << otherStore << " is closer\n");
67 }
68 return rv;
69 };
70 if (!llvm::any_of(candidateSet, nearerThan)) {
71 nearestStore = mlir::cast<WriteOp>(candidate);
72 break;
73 }
74 }
75 if (!nearestStore) {
76 LLVM_DEBUG(
77 llvm::dbgs()
78 << "load " << loadOp << " has " << candidateSet.size()
79 << " store candidates, but this algorithm can't find a best.\n");
80 }
81 return nearestStore;
82 }
83
findReadForWrite(WriteOp storeOp,std::vector<ReadOp> && loadOps)84 llvm::Optional<ReadOp> findReadForWrite(WriteOp storeOp,
85 std::vector<ReadOp> &&loadOps) {
86 for (auto &loadOp : loadOps) {
87 if (domInfo->dominates(storeOp, loadOp))
88 return loadOp;
89 }
90 return {};
91 }
92
93 private:
94 mlir::DominanceInfo *domInfo;
95 };
96
97 class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
98 public:
runOnOperation()99 void runOnOperation() override {
100 mlir::func::FuncOp f = getOperation();
101
102 auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
103 LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
104 f.walk([&](fir::LoadOp loadOp) {
105 auto maybeStore = lsf.findStoreToForward(
106 loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref()));
107 if (maybeStore) {
108 auto storeOp = *maybeStore;
109 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
110 << " erasing load " << loadOp
111 << " with value from " << storeOp << '\n');
112 loadOp.getResult().replaceAllUsesWith(storeOp.getValue());
113 loadOp.erase();
114 }
115 });
116 f.walk([&](fir::AllocaOp alloca) {
117 for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
118 if (!lsf.findReadForWrite(
119 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) {
120 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
121 << " erasing store " << storeOp << '\n');
122 storeOp.erase();
123 }
124 }
125 });
126 }
127 };
128 } // namespace
129
createMemDataFlowOptPass()130 std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
131 return std::make_unique<MemDataFlowOpt>();
132 }
133