1 //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "flang/Optimizer/Transforms/Passes.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/IR/Dominance.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Transforms/Passes.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 
22 #define DEBUG_TYPE "fir-memref-dataflow-opt"
23 
24 namespace {
25 
26 template <typename OpT>
27 static std::vector<OpT> getSpecificUsers(mlir::Value v) {
28   std::vector<OpT> ops;
29   for (mlir::Operation *user : v.getUsers())
30     if (auto op = dyn_cast<OpT>(user))
31       ops.push_back(op);
32   return ops;
33 }
34 
35 /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead
36 /// and AffineWrite interface
37 template <typename ReadOp, typename WriteOp>
38 class LoadStoreForwarding {
39 public:
40   LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
41 
42   // FIXME: This algorithm has a bug. It ignores escaping references between a
43   // store and a load.
44   llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
45                                              std::vector<WriteOp> &&storeOps) {
46     llvm::SmallVector<WriteOp> candidateSet;
47 
48     for (auto storeOp : storeOps)
49       if (domInfo->dominates(storeOp, loadOp))
50         candidateSet.push_back(storeOp);
51 
52     if (candidateSet.empty())
53       return {};
54 
55     llvm::Optional<WriteOp> nearestStore;
56     for (auto candidate : candidateSet) {
57       auto nearerThan = [&](WriteOp otherStore) {
58         if (candidate == otherStore)
59           return false;
60         bool rv = domInfo->properlyDominates(candidate, otherStore);
61         if (rv) {
62           LLVM_DEBUG(llvm::dbgs()
63                      << "candidate " << candidate << " is not the nearest to "
64                      << loadOp << " because " << otherStore << " is closer\n");
65         }
66         return rv;
67       };
68       if (!llvm::any_of(candidateSet, nearerThan)) {
69         nearestStore = mlir::cast<WriteOp>(candidate);
70         break;
71       }
72     }
73     if (!nearestStore) {
74       LLVM_DEBUG(
75           llvm::dbgs()
76           << "load " << loadOp << " has " << candidateSet.size()
77           << " store candidates, but this algorithm can't find a best.\n");
78     }
79     return nearestStore;
80   }
81 
82   llvm::Optional<ReadOp> findReadForWrite(WriteOp storeOp,
83                                           std::vector<ReadOp> &&loadOps) {
84     for (auto &loadOp : loadOps) {
85       if (domInfo->dominates(storeOp, loadOp))
86         return loadOp;
87     }
88     return {};
89   }
90 
91 private:
92   mlir::DominanceInfo *domInfo;
93 };
94 
95 class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
96 public:
97   void runOnOperation() override {
98     mlir::FuncOp f = getOperation();
99 
100     auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
101     LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
102     f.walk([&](fir::LoadOp loadOp) {
103       auto maybeStore = lsf.findStoreToForward(
104           loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref()));
105       if (maybeStore) {
106         auto storeOp = maybeStore.getValue();
107         LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
108                                 << " erasing load " << loadOp
109                                 << " with value from " << storeOp << '\n');
110         loadOp.getResult().replaceAllUsesWith(storeOp.getValue());
111         loadOp.erase();
112       }
113     });
114     f.walk([&](fir::AllocaOp alloca) {
115       for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) {
116         if (!lsf.findReadForWrite(
117                 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) {
118           LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
119                                   << " erasing store " << storeOp << '\n');
120           storeOp.erase();
121         }
122       }
123     });
124   }
125 };
126 } // namespace
127 
128 std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
129   return std::make_unique<MemDataFlowOpt>();
130 }
131