1 //===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/IR/Dominance.h" 16 #include "mlir/IR/Operation.h" 17 #include "mlir/Transforms/Passes.h" 18 #include "llvm/ADT/Optional.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/SmallVector.h" 21 22 #define DEBUG_TYPE "fir-memref-dataflow-opt" 23 24 using namespace mlir; 25 26 namespace { 27 28 template <typename OpT> 29 static std::vector<OpT> getSpecificUsers(mlir::Value v) { 30 std::vector<OpT> ops; 31 for (mlir::Operation *user : v.getUsers()) 32 if (auto op = dyn_cast<OpT>(user)) 33 ops.push_back(op); 34 return ops; 35 } 36 37 /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead 38 /// and AffineWrite interface 39 template <typename ReadOp, typename WriteOp> 40 class LoadStoreForwarding { 41 public: 42 LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {} 43 44 // FIXME: This algorithm has a bug. It ignores escaping references between a 45 // store and a load. 46 llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp, 47 std::vector<WriteOp> &&storeOps) { 48 llvm::SmallVector<WriteOp> candidateSet; 49 50 for (auto storeOp : storeOps) 51 if (domInfo->dominates(storeOp, loadOp)) 52 candidateSet.push_back(storeOp); 53 54 if (candidateSet.empty()) 55 return {}; 56 57 llvm::Optional<WriteOp> nearestStore; 58 for (auto candidate : candidateSet) { 59 auto nearerThan = [&](WriteOp otherStore) { 60 if (candidate == otherStore) 61 return false; 62 bool rv = domInfo->properlyDominates(candidate, otherStore); 63 if (rv) { 64 LLVM_DEBUG(llvm::dbgs() 65 << "candidate " << candidate << " is not the nearest to " 66 << loadOp << " because " << otherStore << " is closer\n"); 67 } 68 return rv; 69 }; 70 if (!llvm::any_of(candidateSet, nearerThan)) { 71 nearestStore = mlir::cast<WriteOp>(candidate); 72 break; 73 } 74 } 75 if (!nearestStore) { 76 LLVM_DEBUG( 77 llvm::dbgs() 78 << "load " << loadOp << " has " << candidateSet.size() 79 << " store candidates, but this algorithm can't find a best.\n"); 80 } 81 return nearestStore; 82 } 83 84 llvm::Optional<ReadOp> findReadForWrite(WriteOp storeOp, 85 std::vector<ReadOp> &&loadOps) { 86 for (auto &loadOp : loadOps) { 87 if (domInfo->dominates(storeOp, loadOp)) 88 return loadOp; 89 } 90 return {}; 91 } 92 93 private: 94 mlir::DominanceInfo *domInfo; 95 }; 96 97 class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> { 98 public: 99 void runOnOperation() override { 100 mlir::func::FuncOp f = getOperation(); 101 102 auto *domInfo = &getAnalysis<mlir::DominanceInfo>(); 103 LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo); 104 f.walk([&](fir::LoadOp loadOp) { 105 auto maybeStore = lsf.findStoreToForward( 106 loadOp, getSpecificUsers<fir::StoreOp>(loadOp.getMemref())); 107 if (maybeStore) { 108 auto storeOp = *maybeStore; 109 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() 110 << " erasing load " << loadOp 111 << " with value from " << storeOp << '\n'); 112 loadOp.getResult().replaceAllUsesWith(storeOp.getValue()); 113 loadOp.erase(); 114 } 115 }); 116 f.walk([&](fir::AllocaOp alloca) { 117 for (auto &storeOp : getSpecificUsers<fir::StoreOp>(alloca.getResult())) { 118 if (!lsf.findReadForWrite( 119 storeOp, getSpecificUsers<fir::LoadOp>(storeOp.getMemref()))) { 120 LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() 121 << " erasing store " << storeOp << '\n'); 122 storeOp.erase(); 123 } 124 } 125 }); 126 } 127 }; 128 } // namespace 129 130 std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() { 131 return std::make_unique<MemDataFlowOpt>(); 132 } 133