1*47f75930SValentin Clement //===-- ArrayValueCopy.cpp ------------------------------------------------===//
2*47f75930SValentin Clement //
3*47f75930SValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*47f75930SValentin Clement // See https://llvm.org/LICENSE.txt for license information.
5*47f75930SValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*47f75930SValentin Clement //
7*47f75930SValentin Clement //===----------------------------------------------------------------------===//
8*47f75930SValentin Clement 
9*47f75930SValentin Clement #include "PassDetail.h"
10*47f75930SValentin Clement #include "flang/Optimizer/Builder/BoxValue.h"
11*47f75930SValentin Clement #include "flang/Optimizer/Builder/FIRBuilder.h"
12*47f75930SValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h"
13*47f75930SValentin Clement #include "flang/Optimizer/Support/FIRContext.h"
14*47f75930SValentin Clement #include "flang/Optimizer/Transforms/Factory.h"
15*47f75930SValentin Clement #include "flang/Optimizer/Transforms/Passes.h"
16*47f75930SValentin Clement #include "mlir/Dialect/SCF/SCF.h"
17*47f75930SValentin Clement #include "mlir/Transforms/DialectConversion.h"
18*47f75930SValentin Clement #include "llvm/Support/Debug.h"
19*47f75930SValentin Clement 
20*47f75930SValentin Clement #define DEBUG_TYPE "flang-array-value-copy"
21*47f75930SValentin Clement 
22*47f75930SValentin Clement using namespace fir;
23*47f75930SValentin Clement 
24*47f75930SValentin Clement using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
25*47f75930SValentin Clement 
26*47f75930SValentin Clement namespace {
27*47f75930SValentin Clement 
28*47f75930SValentin Clement /// Array copy analysis.
29*47f75930SValentin Clement /// Perform an interference analysis between array values.
30*47f75930SValentin Clement ///
31*47f75930SValentin Clement /// Lowering will generate a sequence of the following form.
32*47f75930SValentin Clement /// ```mlir
33*47f75930SValentin Clement ///   %a_1 = fir.array_load %array_1(%shape) : ...
34*47f75930SValentin Clement ///   ...
35*47f75930SValentin Clement ///   %a_j = fir.array_load %array_j(%shape) : ...
36*47f75930SValentin Clement ///   ...
37*47f75930SValentin Clement ///   %a_n = fir.array_load %array_n(%shape) : ...
38*47f75930SValentin Clement ///     ...
39*47f75930SValentin Clement ///     %v_i = fir.array_fetch %a_i, ...
40*47f75930SValentin Clement ///     %a_j1 = fir.array_update %a_j, ...
41*47f75930SValentin Clement ///     ...
42*47f75930SValentin Clement ///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
43*47f75930SValentin Clement /// ```
44*47f75930SValentin Clement ///
45*47f75930SValentin Clement /// The analysis is to determine if there are any conflicts. A conflict is when
46*47f75930SValentin Clement /// one the following cases occurs.
47*47f75930SValentin Clement ///
48*47f75930SValentin Clement /// 1. There is an `array_update` to an array value, a_j, such that a_j was
49*47f75930SValentin Clement /// loaded from the same array memory reference (array_j) but with a different
50*47f75930SValentin Clement /// shape as the other array values a_i, where i != j. [Possible overlapping
51*47f75930SValentin Clement /// arrays.]
52*47f75930SValentin Clement ///
53*47f75930SValentin Clement /// 2. There is either an array_fetch or array_update of a_j with a different
54*47f75930SValentin Clement /// set of index values. [Possible loop-carried dependence.]
55*47f75930SValentin Clement ///
56*47f75930SValentin Clement /// If none of the array values overlap in storage and the accesses are not
57*47f75930SValentin Clement /// loop-carried, then the arrays are conflict-free and no copies are required.
58*47f75930SValentin Clement class ArrayCopyAnalysis {
59*47f75930SValentin Clement public:
60*47f75930SValentin Clement   using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
61*47f75930SValentin Clement   using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
62*47f75930SValentin Clement   using LoadMapSetsT =
63*47f75930SValentin Clement       llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>;
64*47f75930SValentin Clement 
65*47f75930SValentin Clement   ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
66*47f75930SValentin Clement 
67*47f75930SValentin Clement   mlir::Operation *getOperation() const { return operation; }
68*47f75930SValentin Clement 
69*47f75930SValentin Clement   /// Return true iff the `array_merge_store` has potential conflicts.
70*47f75930SValentin Clement   bool hasPotentialConflict(mlir::Operation *op) const {
71*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs()
72*47f75930SValentin Clement                << "looking for a conflict on " << *op
73*47f75930SValentin Clement                << " and the set has a total of " << conflicts.size() << '\n');
74*47f75930SValentin Clement     return conflicts.contains(op);
75*47f75930SValentin Clement   }
76*47f75930SValentin Clement 
77*47f75930SValentin Clement   /// Return the use map. The use map maps array fetch and update operations
78*47f75930SValentin Clement   /// back to the array load that is the original source of the array value.
79*47f75930SValentin Clement   const OperationUseMapT &getUseMap() const { return useMap; }
80*47f75930SValentin Clement 
81*47f75930SValentin Clement   /// Find all the array operations that access the array value that is loaded
82*47f75930SValentin Clement   /// by the array load operation, `load`.
83*47f75930SValentin Clement   const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load);
84*47f75930SValentin Clement 
85*47f75930SValentin Clement private:
86*47f75930SValentin Clement   void construct(mlir::Operation *topLevelOp);
87*47f75930SValentin Clement 
88*47f75930SValentin Clement   mlir::Operation *operation; // operation that analysis ran upon
89*47f75930SValentin Clement   ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
90*47f75930SValentin Clement   OperationUseMapT useMap;
91*47f75930SValentin Clement   LoadMapSetsT loadMapSets;
92*47f75930SValentin Clement };
93*47f75930SValentin Clement } // namespace
94*47f75930SValentin Clement 
95*47f75930SValentin Clement namespace {
96*47f75930SValentin Clement /// Helper class to collect all array operations that produced an array value.
97*47f75930SValentin Clement class ReachCollector {
98*47f75930SValentin Clement private:
99*47f75930SValentin Clement   // If provided, the `loopRegion` is the body of a loop that produces the array
100*47f75930SValentin Clement   // of interest.
101*47f75930SValentin Clement   ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
102*47f75930SValentin Clement                  mlir::Region *loopRegion)
103*47f75930SValentin Clement       : reach{reach}, loopRegion{loopRegion} {}
104*47f75930SValentin Clement 
105*47f75930SValentin Clement   void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) {
106*47f75930SValentin Clement     llvm::errs() << "COLLECT " << *op << "\n";
107*47f75930SValentin Clement     if (range.empty()) {
108*47f75930SValentin Clement       collectArrayAccessFrom(op, mlir::Value{});
109*47f75930SValentin Clement       return;
110*47f75930SValentin Clement     }
111*47f75930SValentin Clement     for (mlir::Value v : range)
112*47f75930SValentin Clement       collectArrayAccessFrom(v);
113*47f75930SValentin Clement   }
114*47f75930SValentin Clement 
115*47f75930SValentin Clement   // TODO: Replace recursive algorithm on def-use chain with an iterative one
116*47f75930SValentin Clement   // with an explicit stack.
117*47f75930SValentin Clement   void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) {
118*47f75930SValentin Clement     // `val` is defined by an Op, process the defining Op.
119*47f75930SValentin Clement     // If `val` is defined by a region containing Op, we want to drill down
120*47f75930SValentin Clement     // and through that Op's region(s).
121*47f75930SValentin Clement     llvm::errs() << "COLLECT " << *op << "\n";
122*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
123*47f75930SValentin Clement     auto popFn = [&](auto rop) {
124*47f75930SValentin Clement       assert(val && "op must have a result value");
125*47f75930SValentin Clement       auto resNum = val.cast<mlir::OpResult>().getResultNumber();
126*47f75930SValentin Clement       llvm::SmallVector<mlir::Value> results;
127*47f75930SValentin Clement       rop.resultToSourceOps(results, resNum);
128*47f75930SValentin Clement       for (auto u : results)
129*47f75930SValentin Clement         collectArrayAccessFrom(u);
130*47f75930SValentin Clement     };
131*47f75930SValentin Clement     if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) {
132*47f75930SValentin Clement       popFn(rop);
133*47f75930SValentin Clement       return;
134*47f75930SValentin Clement     }
135*47f75930SValentin Clement     if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
136*47f75930SValentin Clement       popFn(rop);
137*47f75930SValentin Clement       return;
138*47f75930SValentin Clement     }
139*47f75930SValentin Clement     if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
140*47f75930SValentin Clement       if (opIsInsideLoops(mergeStore))
141*47f75930SValentin Clement         collectArrayAccessFrom(mergeStore.sequence());
142*47f75930SValentin Clement       return;
143*47f75930SValentin Clement     }
144*47f75930SValentin Clement 
145*47f75930SValentin Clement     if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
146*47f75930SValentin Clement       // Look for any stores inside the loops, and collect an array operation
147*47f75930SValentin Clement       // that produced the value being stored to it.
148*47f75930SValentin Clement       for (mlir::Operation *user : op->getUsers())
149*47f75930SValentin Clement         if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
150*47f75930SValentin Clement           if (opIsInsideLoops(store))
151*47f75930SValentin Clement             collectArrayAccessFrom(store.value());
152*47f75930SValentin Clement       return;
153*47f75930SValentin Clement     }
154*47f75930SValentin Clement 
155*47f75930SValentin Clement     // Otherwise, Op does not contain a region so just chase its operands.
156*47f75930SValentin Clement     if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>(
157*47f75930SValentin Clement             op)) {
158*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
159*47f75930SValentin Clement       reach.emplace_back(op);
160*47f75930SValentin Clement     }
161*47f75930SValentin Clement     // Array modify assignment is performed on the result. So the analysis
162*47f75930SValentin Clement     // must look at the what is done with the result.
163*47f75930SValentin Clement     if (mlir::isa<ArrayModifyOp>(op))
164*47f75930SValentin Clement       for (mlir::Operation *user : op->getResult(0).getUsers())
165*47f75930SValentin Clement         followUsers(user);
166*47f75930SValentin Clement 
167*47f75930SValentin Clement     for (auto u : op->getOperands())
168*47f75930SValentin Clement       collectArrayAccessFrom(u);
169*47f75930SValentin Clement   }
170*47f75930SValentin Clement 
171*47f75930SValentin Clement   void collectArrayAccessFrom(mlir::BlockArgument ba) {
172*47f75930SValentin Clement     auto *parent = ba.getOwner()->getParentOp();
173*47f75930SValentin Clement     // If inside an Op holding a region, the block argument corresponds to an
174*47f75930SValentin Clement     // argument passed to the containing Op.
175*47f75930SValentin Clement     auto popFn = [&](auto rop) {
176*47f75930SValentin Clement       collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
177*47f75930SValentin Clement     };
178*47f75930SValentin Clement     if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
179*47f75930SValentin Clement       popFn(rop);
180*47f75930SValentin Clement       return;
181*47f75930SValentin Clement     }
182*47f75930SValentin Clement     if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
183*47f75930SValentin Clement       popFn(rop);
184*47f75930SValentin Clement       return;
185*47f75930SValentin Clement     }
186*47f75930SValentin Clement     // Otherwise, a block argument is provided via the pred blocks.
187*47f75930SValentin Clement     for (auto *pred : ba.getOwner()->getPredecessors()) {
188*47f75930SValentin Clement       auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
189*47f75930SValentin Clement       collectArrayAccessFrom(u);
190*47f75930SValentin Clement     }
191*47f75930SValentin Clement   }
192*47f75930SValentin Clement 
193*47f75930SValentin Clement   // Recursively trace operands to find all array operations relating to the
194*47f75930SValentin Clement   // values merged.
195*47f75930SValentin Clement   void collectArrayAccessFrom(mlir::Value val) {
196*47f75930SValentin Clement     if (!val || visited.contains(val))
197*47f75930SValentin Clement       return;
198*47f75930SValentin Clement     visited.insert(val);
199*47f75930SValentin Clement 
200*47f75930SValentin Clement     // Process a block argument.
201*47f75930SValentin Clement     if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
202*47f75930SValentin Clement       collectArrayAccessFrom(ba);
203*47f75930SValentin Clement       return;
204*47f75930SValentin Clement     }
205*47f75930SValentin Clement 
206*47f75930SValentin Clement     // Process an Op.
207*47f75930SValentin Clement     if (auto *op = val.getDefiningOp()) {
208*47f75930SValentin Clement       collectArrayAccessFrom(op, val);
209*47f75930SValentin Clement       return;
210*47f75930SValentin Clement     }
211*47f75930SValentin Clement 
212*47f75930SValentin Clement     fir::emitFatalError(val.getLoc(), "unhandled value");
213*47f75930SValentin Clement   }
214*47f75930SValentin Clement 
215*47f75930SValentin Clement   /// Is \op inside the loop nest region ?
216*47f75930SValentin Clement   bool opIsInsideLoops(mlir::Operation *op) const {
217*47f75930SValentin Clement     return loopRegion && loopRegion->isAncestor(op->getParentRegion());
218*47f75930SValentin Clement   }
219*47f75930SValentin Clement 
220*47f75930SValentin Clement   /// Recursively trace the use of an operation results, calling
221*47f75930SValentin Clement   /// collectArrayAccessFrom on the direct and indirect user operands.
222*47f75930SValentin Clement   /// TODO: Replace recursive algorithm on def-use chain with an iterative one
223*47f75930SValentin Clement   /// with an explicit stack.
224*47f75930SValentin Clement   void followUsers(mlir::Operation *op) {
225*47f75930SValentin Clement     for (auto userOperand : op->getOperands())
226*47f75930SValentin Clement       collectArrayAccessFrom(userOperand);
227*47f75930SValentin Clement     // Go through potential converts/coordinate_op.
228*47f75930SValentin Clement     for (mlir::Operation *indirectUser : op->getUsers())
229*47f75930SValentin Clement       followUsers(indirectUser);
230*47f75930SValentin Clement   }
231*47f75930SValentin Clement 
232*47f75930SValentin Clement   llvm::SmallVectorImpl<mlir::Operation *> &reach;
233*47f75930SValentin Clement   llvm::SmallPtrSet<mlir::Value, 16> visited;
234*47f75930SValentin Clement   /// Region of the loops nest that produced the array value.
235*47f75930SValentin Clement   mlir::Region *loopRegion;
236*47f75930SValentin Clement 
237*47f75930SValentin Clement public:
238*47f75930SValentin Clement   /// Return all ops that produce the array value that is stored into the
239*47f75930SValentin Clement   /// `array_merge_store`.
240*47f75930SValentin Clement   static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
241*47f75930SValentin Clement                              mlir::Value seq) {
242*47f75930SValentin Clement     reach.clear();
243*47f75930SValentin Clement     mlir::Region *loopRegion = nullptr;
244*47f75930SValentin Clement     // Only `DoLoopOp` is tested here since array operations are currently only
245*47f75930SValentin Clement     // associated with this kind of loop.
246*47f75930SValentin Clement     if (auto doLoop =
247*47f75930SValentin Clement             mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp()))
248*47f75930SValentin Clement       loopRegion = &doLoop->getRegion(0);
249*47f75930SValentin Clement     ReachCollector collector(reach, loopRegion);
250*47f75930SValentin Clement     collector.collectArrayAccessFrom(seq);
251*47f75930SValentin Clement   }
252*47f75930SValentin Clement };
253*47f75930SValentin Clement } // namespace
254*47f75930SValentin Clement 
255*47f75930SValentin Clement /// Find all the array operations that access the array value that is loaded by
256*47f75930SValentin Clement /// the array load operation, `load`.
257*47f75930SValentin Clement const llvm::SmallVector<mlir::Operation *> &
258*47f75930SValentin Clement ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) {
259*47f75930SValentin Clement   auto lmIter = loadMapSets.find(load);
260*47f75930SValentin Clement   if (lmIter != loadMapSets.end())
261*47f75930SValentin Clement     return lmIter->getSecond();
262*47f75930SValentin Clement 
263*47f75930SValentin Clement   llvm::SmallVector<mlir::Operation *> accesses;
264*47f75930SValentin Clement   UseSetT visited;
265*47f75930SValentin Clement   llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
266*47f75930SValentin Clement 
267*47f75930SValentin Clement   auto appendToQueue = [&](mlir::Value val) {
268*47f75930SValentin Clement     for (mlir::OpOperand &use : val.getUses())
269*47f75930SValentin Clement       if (!visited.count(&use)) {
270*47f75930SValentin Clement         visited.insert(&use);
271*47f75930SValentin Clement         queue.push_back(&use);
272*47f75930SValentin Clement       }
273*47f75930SValentin Clement   };
274*47f75930SValentin Clement 
275*47f75930SValentin Clement   // Build the set of uses of `original`.
276*47f75930SValentin Clement   // let USES = { uses of original fir.load }
277*47f75930SValentin Clement   appendToQueue(load);
278*47f75930SValentin Clement 
279*47f75930SValentin Clement   // Process the worklist until done.
280*47f75930SValentin Clement   while (!queue.empty()) {
281*47f75930SValentin Clement     mlir::OpOperand *operand = queue.pop_back_val();
282*47f75930SValentin Clement     mlir::Operation *owner = operand->getOwner();
283*47f75930SValentin Clement 
284*47f75930SValentin Clement     auto structuredLoop = [&](auto ro) {
285*47f75930SValentin Clement       if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
286*47f75930SValentin Clement         int64_t arg = blockArg.getArgNumber();
287*47f75930SValentin Clement         mlir::Value output = ro.getResult(ro.finalValue() ? arg : arg - 1);
288*47f75930SValentin Clement         appendToQueue(output);
289*47f75930SValentin Clement         appendToQueue(blockArg);
290*47f75930SValentin Clement       }
291*47f75930SValentin Clement     };
292*47f75930SValentin Clement     // TODO: this need to be updated to use the control-flow interface.
293*47f75930SValentin Clement     auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
294*47f75930SValentin Clement       if (operands.empty())
295*47f75930SValentin Clement         return;
296*47f75930SValentin Clement 
297*47f75930SValentin Clement       // Check if this operand is within the range.
298*47f75930SValentin Clement       unsigned operandIndex = operand->getOperandNumber();
299*47f75930SValentin Clement       unsigned operandsStart = operands.getBeginOperandIndex();
300*47f75930SValentin Clement       if (operandIndex < operandsStart ||
301*47f75930SValentin Clement           operandIndex >= (operandsStart + operands.size()))
302*47f75930SValentin Clement         return;
303*47f75930SValentin Clement 
304*47f75930SValentin Clement       // Index the successor.
305*47f75930SValentin Clement       unsigned argIndex = operandIndex - operandsStart;
306*47f75930SValentin Clement       appendToQueue(dest->getArgument(argIndex));
307*47f75930SValentin Clement     };
308*47f75930SValentin Clement     // Thread uses into structured loop bodies and return value uses.
309*47f75930SValentin Clement     if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
310*47f75930SValentin Clement       structuredLoop(ro);
311*47f75930SValentin Clement     } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
312*47f75930SValentin Clement       structuredLoop(ro);
313*47f75930SValentin Clement     } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
314*47f75930SValentin Clement       // Thread any uses of fir.if that return the marked array value.
315*47f75930SValentin Clement       if (auto ifOp = rs->getParentOfType<fir::IfOp>())
316*47f75930SValentin Clement         appendToQueue(ifOp.getResult(operand->getOperandNumber()));
317*47f75930SValentin Clement     } else if (mlir::isa<ArrayFetchOp>(owner)) {
318*47f75930SValentin Clement       // Keep track of array value fetches.
319*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs()
320*47f75930SValentin Clement                  << "add fetch {" << *owner << "} to array value set\n");
321*47f75930SValentin Clement       accesses.push_back(owner);
322*47f75930SValentin Clement     } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
323*47f75930SValentin Clement       // Keep track of array value updates and thread the return value uses.
324*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs()
325*47f75930SValentin Clement                  << "add update {" << *owner << "} to array value set\n");
326*47f75930SValentin Clement       accesses.push_back(owner);
327*47f75930SValentin Clement       appendToQueue(update.getResult());
328*47f75930SValentin Clement     } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
329*47f75930SValentin Clement       // Keep track of array value modification and thread the return value
330*47f75930SValentin Clement       // uses.
331*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs()
332*47f75930SValentin Clement                  << "add modify {" << *owner << "} to array value set\n");
333*47f75930SValentin Clement       accesses.push_back(owner);
334*47f75930SValentin Clement       appendToQueue(update.getResult(1));
335*47f75930SValentin Clement     } else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) {
336*47f75930SValentin Clement       branchOp(br.getDest(), br.destOperands());
337*47f75930SValentin Clement     } else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) {
338*47f75930SValentin Clement       branchOp(br.getTrueDest(), br.getTrueOperands());
339*47f75930SValentin Clement       branchOp(br.getFalseDest(), br.getFalseOperands());
340*47f75930SValentin Clement     } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
341*47f75930SValentin Clement       // do nothing
342*47f75930SValentin Clement     } else {
343*47f75930SValentin Clement       llvm::report_fatal_error("array value reached unexpected op");
344*47f75930SValentin Clement     }
345*47f75930SValentin Clement   }
346*47f75930SValentin Clement   return loadMapSets.insert({load, accesses}).first->getSecond();
347*47f75930SValentin Clement }
348*47f75930SValentin Clement 
349*47f75930SValentin Clement /// Is there a conflict between the array value that was updated and to be
350*47f75930SValentin Clement /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
351*47f75930SValentin Clement /// the updated value?
352*47f75930SValentin Clement static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
353*47f75930SValentin Clement                            ArrayMergeStoreOp st) {
354*47f75930SValentin Clement   mlir::Value load;
355*47f75930SValentin Clement   mlir::Value addr = st.memref();
356*47f75930SValentin Clement   auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType());
357*47f75930SValentin Clement   for (auto *op : reach) {
358*47f75930SValentin Clement     auto ld = mlir::dyn_cast<ArrayLoadOp>(op);
359*47f75930SValentin Clement     if (!ld)
360*47f75930SValentin Clement       continue;
361*47f75930SValentin Clement     mlir::Type ldTy = ld.memref().getType();
362*47f75930SValentin Clement     if (auto boxTy = ldTy.dyn_cast<fir::BoxType>())
363*47f75930SValentin Clement       ldTy = boxTy.getEleTy();
364*47f75930SValentin Clement     if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy))
365*47f75930SValentin Clement       return true;
366*47f75930SValentin Clement     if (ld.memref() == addr) {
367*47f75930SValentin Clement       if (ld.getResult() != st.original())
368*47f75930SValentin Clement         return true;
369*47f75930SValentin Clement       if (load)
370*47f75930SValentin Clement         return true;
371*47f75930SValentin Clement       load = ld;
372*47f75930SValentin Clement     }
373*47f75930SValentin Clement   }
374*47f75930SValentin Clement   return false;
375*47f75930SValentin Clement }
376*47f75930SValentin Clement 
377*47f75930SValentin Clement /// Check if there is any potential conflict in the chained update operations
378*47f75930SValentin Clement /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the
379*47f75930SValentin Clement /// array. A potential conflict is detected if two operations work on the same
380*47f75930SValentin Clement /// indices.
381*47f75930SValentin Clement static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) {
382*47f75930SValentin Clement   if (accesses.size() < 2)
383*47f75930SValentin Clement     return false;
384*47f75930SValentin Clement   llvm::SmallVector<mlir::Value> indices;
385*47f75930SValentin Clement   LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size()
386*47f75930SValentin Clement                           << " accesses on the list\n");
387*47f75930SValentin Clement   for (auto *op : accesses) {
388*47f75930SValentin Clement     assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) &&
389*47f75930SValentin Clement            "unexpected operation in analysis");
390*47f75930SValentin Clement     llvm::SmallVector<mlir::Value> compareVector;
391*47f75930SValentin Clement     if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
392*47f75930SValentin Clement       if (indices.empty()) {
393*47f75930SValentin Clement         indices = u.indices();
394*47f75930SValentin Clement         continue;
395*47f75930SValentin Clement       }
396*47f75930SValentin Clement       compareVector = u.indices();
397*47f75930SValentin Clement     } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
398*47f75930SValentin Clement       if (indices.empty()) {
399*47f75930SValentin Clement         indices = f.indices();
400*47f75930SValentin Clement         continue;
401*47f75930SValentin Clement       }
402*47f75930SValentin Clement       compareVector = f.indices();
403*47f75930SValentin Clement     } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
404*47f75930SValentin Clement       if (indices.empty()) {
405*47f75930SValentin Clement         indices = f.indices();
406*47f75930SValentin Clement         continue;
407*47f75930SValentin Clement       }
408*47f75930SValentin Clement       compareVector = f.indices();
409*47f75930SValentin Clement     }
410*47f75930SValentin Clement     if (compareVector != indices)
411*47f75930SValentin Clement       return true;
412*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
413*47f75930SValentin Clement   }
414*47f75930SValentin Clement   return false;
415*47f75930SValentin Clement }
416*47f75930SValentin Clement 
417*47f75930SValentin Clement // Are either of types of conflicts present?
418*47f75930SValentin Clement inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
419*47f75930SValentin Clement                              llvm::ArrayRef<mlir::Operation *> accesses,
420*47f75930SValentin Clement                              ArrayMergeStoreOp st) {
421*47f75930SValentin Clement   return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
422*47f75930SValentin Clement }
423*47f75930SValentin Clement 
424*47f75930SValentin Clement /// Constructor of the array copy analysis.
425*47f75930SValentin Clement /// This performs the analysis and saves the intermediate results.
426*47f75930SValentin Clement void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
427*47f75930SValentin Clement   topLevelOp->walk([&](Operation *op) {
428*47f75930SValentin Clement     if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
429*47f75930SValentin Clement       llvm::SmallVector<Operation *> values;
430*47f75930SValentin Clement       ReachCollector::reachingValues(values, st.sequence());
431*47f75930SValentin Clement       const llvm::SmallVector<Operation *> &accesses =
432*47f75930SValentin Clement           arrayAccesses(mlir::cast<ArrayLoadOp>(st.original().getDefiningOp()));
433*47f75930SValentin Clement       if (conflictDetected(values, accesses, st)) {
434*47f75930SValentin Clement         LLVM_DEBUG(llvm::dbgs()
435*47f75930SValentin Clement                    << "CONFLICT: copies required for " << st << '\n'
436*47f75930SValentin Clement                    << "   adding conflicts on: " << op << " and "
437*47f75930SValentin Clement                    << st.original() << '\n');
438*47f75930SValentin Clement         conflicts.insert(op);
439*47f75930SValentin Clement         conflicts.insert(st.original().getDefiningOp());
440*47f75930SValentin Clement       }
441*47f75930SValentin Clement       auto *ld = st.original().getDefiningOp();
442*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs()
443*47f75930SValentin Clement                  << "map: adding {" << *ld << " -> " << st << "}\n");
444*47f75930SValentin Clement       useMap.insert({ld, op});
445*47f75930SValentin Clement     } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
446*47f75930SValentin Clement       const llvm::SmallVector<mlir::Operation *> &accesses =
447*47f75930SValentin Clement           arrayAccesses(load);
448*47f75930SValentin Clement       LLVM_DEBUG(llvm::dbgs() << "process load: " << load
449*47f75930SValentin Clement                               << ", accesses: " << accesses.size() << '\n');
450*47f75930SValentin Clement       for (auto *acc : accesses) {
451*47f75930SValentin Clement         LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n');
452*47f75930SValentin Clement         assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc)));
453*47f75930SValentin Clement         if (!useMap.insert({acc, op}).second) {
454*47f75930SValentin Clement           mlir::emitError(
455*47f75930SValentin Clement               load.getLoc(),
456*47f75930SValentin Clement               "The parallel semantics of multiple array_merge_stores per "
457*47f75930SValentin Clement               "array_load are not supported.");
458*47f75930SValentin Clement           return;
459*47f75930SValentin Clement         }
460*47f75930SValentin Clement         LLVM_DEBUG(llvm::dbgs()
461*47f75930SValentin Clement                    << "map: adding {" << *acc << "} -> {" << load << "}\n");
462*47f75930SValentin Clement       }
463*47f75930SValentin Clement     }
464*47f75930SValentin Clement   });
465*47f75930SValentin Clement }
466*47f75930SValentin Clement 
467*47f75930SValentin Clement namespace {
468*47f75930SValentin Clement class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
469*47f75930SValentin Clement public:
470*47f75930SValentin Clement   using OpRewritePattern::OpRewritePattern;
471*47f75930SValentin Clement 
472*47f75930SValentin Clement   mlir::LogicalResult
473*47f75930SValentin Clement   matchAndRewrite(ArrayLoadOp load,
474*47f75930SValentin Clement                   mlir::PatternRewriter &rewriter) const override {
475*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
476*47f75930SValentin Clement     rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
477*47f75930SValentin Clement     return mlir::success();
478*47f75930SValentin Clement   }
479*47f75930SValentin Clement };
480*47f75930SValentin Clement 
481*47f75930SValentin Clement class ArrayMergeStoreConversion
482*47f75930SValentin Clement     : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
483*47f75930SValentin Clement public:
484*47f75930SValentin Clement   using OpRewritePattern::OpRewritePattern;
485*47f75930SValentin Clement 
486*47f75930SValentin Clement   mlir::LogicalResult
487*47f75930SValentin Clement   matchAndRewrite(ArrayMergeStoreOp store,
488*47f75930SValentin Clement                   mlir::PatternRewriter &rewriter) const override {
489*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
490*47f75930SValentin Clement     rewriter.eraseOp(store);
491*47f75930SValentin Clement     return mlir::success();
492*47f75930SValentin Clement   }
493*47f75930SValentin Clement };
494*47f75930SValentin Clement } // namespace
495*47f75930SValentin Clement 
496*47f75930SValentin Clement static mlir::Type getEleTy(mlir::Type ty) {
497*47f75930SValentin Clement   if (auto t = dyn_cast_ptrEleTy(ty))
498*47f75930SValentin Clement     ty = t;
499*47f75930SValentin Clement   if (auto t = ty.dyn_cast<SequenceType>())
500*47f75930SValentin Clement     ty = t.getEleTy();
501*47f75930SValentin Clement   // FIXME: keep ptr/heap/ref information.
502*47f75930SValentin Clement   return ReferenceType::get(ty);
503*47f75930SValentin Clement }
504*47f75930SValentin Clement 
505*47f75930SValentin Clement // Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
506*47f75930SValentin Clement // TODO: getExtents on op should return a ValueRange instead of a vector.
507*47f75930SValentin Clement static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result,
508*47f75930SValentin Clement                        mlir::Value shape) {
509*47f75930SValentin Clement   auto *shapeOp = shape.getDefiningOp();
510*47f75930SValentin Clement   if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) {
511*47f75930SValentin Clement     auto e = s.getExtents();
512*47f75930SValentin Clement     result.insert(result.end(), e.begin(), e.end());
513*47f75930SValentin Clement     return;
514*47f75930SValentin Clement   }
515*47f75930SValentin Clement   if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) {
516*47f75930SValentin Clement     auto e = s.getExtents();
517*47f75930SValentin Clement     result.insert(result.end(), e.begin(), e.end());
518*47f75930SValentin Clement     return;
519*47f75930SValentin Clement   }
520*47f75930SValentin Clement   llvm::report_fatal_error("not a fir.shape/fir.shape_shift op");
521*47f75930SValentin Clement }
522*47f75930SValentin Clement 
523*47f75930SValentin Clement // Place the extents of the array loaded by an ArrayLoadOp into the result
524*47f75930SValentin Clement // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If
525*47f75930SValentin Clement // the ArrayLoadOp is loading a fir.box, code will be generated to read the
526*47f75930SValentin Clement // extents from the fir.box, and a the retunred ShapeOp is built with the read
527*47f75930SValentin Clement // extents.
528*47f75930SValentin Clement // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp
529*47f75930SValentin Clement // argument of the ArrayLoadOp that is returned.
530*47f75930SValentin Clement static mlir::Value
531*47f75930SValentin Clement getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
532*47f75930SValentin Clement                            fir::ArrayLoadOp loadOp,
533*47f75930SValentin Clement                            llvm::SmallVectorImpl<mlir::Value> &result) {
534*47f75930SValentin Clement   assert(result.empty());
535*47f75930SValentin Clement   if (auto boxTy = loadOp.memref().getType().dyn_cast<fir::BoxType>()) {
536*47f75930SValentin Clement     auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy)
537*47f75930SValentin Clement                     .cast<fir::SequenceType>()
538*47f75930SValentin Clement                     .getDimension();
539*47f75930SValentin Clement     auto idxTy = rewriter.getIndexType();
540*47f75930SValentin Clement     for (decltype(rank) dim = 0; dim < rank; ++dim) {
541*47f75930SValentin Clement       auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim);
542*47f75930SValentin Clement       auto dimInfo = rewriter.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
543*47f75930SValentin Clement                                                      loadOp.memref(), dimVal);
544*47f75930SValentin Clement       result.emplace_back(dimInfo.getResult(1));
545*47f75930SValentin Clement     }
546*47f75930SValentin Clement     auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank);
547*47f75930SValentin Clement     return rewriter.create<fir::ShapeOp>(loc, shapeType, result);
548*47f75930SValentin Clement   }
549*47f75930SValentin Clement   getExtents(result, loadOp.shape());
550*47f75930SValentin Clement   return loadOp.shape();
551*47f75930SValentin Clement }
552*47f75930SValentin Clement 
553*47f75930SValentin Clement static mlir::Type toRefType(mlir::Type ty) {
554*47f75930SValentin Clement   if (fir::isa_ref_type(ty))
555*47f75930SValentin Clement     return ty;
556*47f75930SValentin Clement   return fir::ReferenceType::get(ty);
557*47f75930SValentin Clement }
558*47f75930SValentin Clement 
559*47f75930SValentin Clement static mlir::Value
560*47f75930SValentin Clement genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
561*47f75930SValentin Clement           mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
562*47f75930SValentin Clement           mlir::Value slice, mlir::ValueRange indices,
563*47f75930SValentin Clement           mlir::ValueRange typeparams, bool skipOrig = false) {
564*47f75930SValentin Clement   llvm::SmallVector<mlir::Value> originated;
565*47f75930SValentin Clement   if (skipOrig)
566*47f75930SValentin Clement     originated.assign(indices.begin(), indices.end());
567*47f75930SValentin Clement   else
568*47f75930SValentin Clement     originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
569*47f75930SValentin Clement                                                 shape, indices);
570*47f75930SValentin Clement   auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
571*47f75930SValentin Clement   assert(seqTy && seqTy.isa<fir::SequenceType>());
572*47f75930SValentin Clement   const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
573*47f75930SValentin Clement   mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
574*47f75930SValentin Clement       loc, eleTy, alloc, shape, slice,
575*47f75930SValentin Clement       llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
576*47f75930SValentin Clement       typeparams);
577*47f75930SValentin Clement   if (dimension < originated.size())
578*47f75930SValentin Clement     result = rewriter.create<fir::CoordinateOp>(
579*47f75930SValentin Clement         loc, resTy, result,
580*47f75930SValentin Clement         llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
581*47f75930SValentin Clement   return result;
582*47f75930SValentin Clement }
583*47f75930SValentin Clement 
584*47f75930SValentin Clement namespace {
585*47f75930SValentin Clement /// Conversion of fir.array_update and fir.array_modify Ops.
586*47f75930SValentin Clement /// If there is a conflict for the update, then we need to perform a
587*47f75930SValentin Clement /// copy-in/copy-out to preserve the original values of the array. If there is
588*47f75930SValentin Clement /// no conflict, then it is save to eschew making any copies.
589*47f75930SValentin Clement template <typename ArrayOp>
590*47f75930SValentin Clement class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
591*47f75930SValentin Clement public:
592*47f75930SValentin Clement   explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
593*47f75930SValentin Clement                                      const ArrayCopyAnalysis &a,
594*47f75930SValentin Clement                                      const OperationUseMapT &m)
595*47f75930SValentin Clement       : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}
596*47f75930SValentin Clement 
597*47f75930SValentin Clement   void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
598*47f75930SValentin Clement                     mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
599*47f75930SValentin Clement                     mlir::Type arrTy) const {
600*47f75930SValentin Clement     auto insPt = rewriter.saveInsertionPoint();
601*47f75930SValentin Clement     llvm::SmallVector<mlir::Value> indices;
602*47f75930SValentin Clement     llvm::SmallVector<mlir::Value> extents;
603*47f75930SValentin Clement     getExtents(extents, shapeOp);
604*47f75930SValentin Clement     // Build loop nest from column to row.
605*47f75930SValentin Clement     for (auto sh : llvm::reverse(extents)) {
606*47f75930SValentin Clement       auto idxTy = rewriter.getIndexType();
607*47f75930SValentin Clement       auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh);
608*47f75930SValentin Clement       auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
609*47f75930SValentin Clement       auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
610*47f75930SValentin Clement       auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one);
611*47f75930SValentin Clement       auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one);
612*47f75930SValentin Clement       rewriter.setInsertionPointToStart(loop.getBody());
613*47f75930SValentin Clement       indices.push_back(loop.getInductionVar());
614*47f75930SValentin Clement     }
615*47f75930SValentin Clement     // Reverse the indices so they are in column-major order.
616*47f75930SValentin Clement     std::reverse(indices.begin(), indices.end());
617*47f75930SValentin Clement     auto ty = getEleTy(arrTy);
618*47f75930SValentin Clement     auto fromAddr = rewriter.create<fir::ArrayCoorOp>(
619*47f75930SValentin Clement         loc, ty, src, shapeOp, mlir::Value{},
620*47f75930SValentin Clement         fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp,
621*47f75930SValentin Clement                                        indices),
622*47f75930SValentin Clement         mlir::ValueRange{});
623*47f75930SValentin Clement     auto load = rewriter.create<fir::LoadOp>(loc, fromAddr);
624*47f75930SValentin Clement     auto toAddr = rewriter.create<fir::ArrayCoorOp>(
625*47f75930SValentin Clement         loc, ty, dst, shapeOp, mlir::Value{},
626*47f75930SValentin Clement         fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp,
627*47f75930SValentin Clement                                        indices),
628*47f75930SValentin Clement         mlir::ValueRange{});
629*47f75930SValentin Clement     rewriter.create<fir::StoreOp>(loc, load, toAddr);
630*47f75930SValentin Clement     rewriter.restoreInsertionPoint(insPt);
631*47f75930SValentin Clement   }
632*47f75930SValentin Clement 
633*47f75930SValentin Clement   /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
634*47f75930SValentin Clement   /// temp and the LHS if the analysis found potential overlaps between the RHS
635*47f75930SValentin Clement   /// and LHS arrays. The element copy generator must be provided through \p
636*47f75930SValentin Clement   /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
637*47f75930SValentin Clement   /// Returns the address of the LHS element inside the loop and the LHS
638*47f75930SValentin Clement   /// ArrayLoad result.
639*47f75930SValentin Clement   std::pair<mlir::Value, mlir::Value>
640*47f75930SValentin Clement   materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
641*47f75930SValentin Clement                         ArrayOp update,
642*47f75930SValentin Clement                         llvm::function_ref<void(mlir::Value)> assignElement,
643*47f75930SValentin Clement                         mlir::Type lhsEltRefType) const {
644*47f75930SValentin Clement     auto *op = update.getOperation();
645*47f75930SValentin Clement     mlir::Operation *loadOp = useMap.lookup(op);
646*47f75930SValentin Clement     auto load = mlir::cast<ArrayLoadOp>(loadOp);
647*47f75930SValentin Clement     LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
648*47f75930SValentin Clement     if (analysis.hasPotentialConflict(loadOp)) {
649*47f75930SValentin Clement       // If there is a conflict between the arrays, then we copy the lhs array
650*47f75930SValentin Clement       // to a temporary, update the temporary, and copy the temporary back to
651*47f75930SValentin Clement       // the lhs array. This yields Fortran's copy-in copy-out array semantics.
652*47f75930SValentin Clement       LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
653*47f75930SValentin Clement       rewriter.setInsertionPoint(loadOp);
654*47f75930SValentin Clement       // Copy in.
655*47f75930SValentin Clement       llvm::SmallVector<mlir::Value> extents;
656*47f75930SValentin Clement       mlir::Value shapeOp =
657*47f75930SValentin Clement           getOrReadExtentsAndShapeOp(loc, rewriter, load, extents);
658*47f75930SValentin Clement       auto allocmem = rewriter.create<AllocMemOp>(
659*47f75930SValentin Clement           loc, dyn_cast_ptrOrBoxEleTy(load.memref().getType()),
660*47f75930SValentin Clement           load.typeparams(), extents);
661*47f75930SValentin Clement       genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp,
662*47f75930SValentin Clement                    load.getType());
663*47f75930SValentin Clement       rewriter.setInsertionPoint(op);
664*47f75930SValentin Clement       mlir::Value coor = genCoorOp(
665*47f75930SValentin Clement           rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
666*47f75930SValentin Clement           shapeOp, load.slice(), update.indices(), load.typeparams(),
667*47f75930SValentin Clement           update->hasAttr(fir::factory::attrFortranArrayOffsets()));
668*47f75930SValentin Clement       assignElement(coor);
669*47f75930SValentin Clement       mlir::Operation *storeOp = useMap.lookup(loadOp);
670*47f75930SValentin Clement       auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
671*47f75930SValentin Clement       rewriter.setInsertionPoint(storeOp);
672*47f75930SValentin Clement       // Copy out.
673*47f75930SValentin Clement       genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp,
674*47f75930SValentin Clement                    load.getType());
675*47f75930SValentin Clement       rewriter.create<FreeMemOp>(loc, allocmem);
676*47f75930SValentin Clement       return {coor, load.getResult()};
677*47f75930SValentin Clement     }
678*47f75930SValentin Clement     // Otherwise, when there is no conflict (a possible loop-carried
679*47f75930SValentin Clement     // dependence), the lhs array can be updated in place.
680*47f75930SValentin Clement     LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
681*47f75930SValentin Clement     rewriter.setInsertionPoint(op);
682*47f75930SValentin Clement     auto coorTy = getEleTy(load.getType());
683*47f75930SValentin Clement     mlir::Value coor = genCoorOp(
684*47f75930SValentin Clement         rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(),
685*47f75930SValentin Clement         load.slice(), update.indices(), load.typeparams(),
686*47f75930SValentin Clement         update->hasAttr(fir::factory::attrFortranArrayOffsets()));
687*47f75930SValentin Clement     assignElement(coor);
688*47f75930SValentin Clement     return {coor, load.getResult()};
689*47f75930SValentin Clement   }
690*47f75930SValentin Clement 
691*47f75930SValentin Clement private:
692*47f75930SValentin Clement   const ArrayCopyAnalysis &analysis;
693*47f75930SValentin Clement   const OperationUseMapT &useMap;
694*47f75930SValentin Clement };
695*47f75930SValentin Clement 
696*47f75930SValentin Clement class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
697*47f75930SValentin Clement public:
698*47f75930SValentin Clement   explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
699*47f75930SValentin Clement                                  const ArrayCopyAnalysis &a,
700*47f75930SValentin Clement                                  const OperationUseMapT &m)
701*47f75930SValentin Clement       : ArrayUpdateConversionBase{ctx, a, m} {}
702*47f75930SValentin Clement 
703*47f75930SValentin Clement   mlir::LogicalResult
704*47f75930SValentin Clement   matchAndRewrite(ArrayUpdateOp update,
705*47f75930SValentin Clement                   mlir::PatternRewriter &rewriter) const override {
706*47f75930SValentin Clement     auto loc = update.getLoc();
707*47f75930SValentin Clement     auto assignElement = [&](mlir::Value coor) {
708*47f75930SValentin Clement       rewriter.create<fir::StoreOp>(loc, update.merge(), coor);
709*47f75930SValentin Clement     };
710*47f75930SValentin Clement     auto lhsEltRefType = toRefType(update.merge().getType());
711*47f75930SValentin Clement     auto [_, lhsLoadResult] = materializeAssignment(
712*47f75930SValentin Clement         loc, rewriter, update, assignElement, lhsEltRefType);
713*47f75930SValentin Clement     update.replaceAllUsesWith(lhsLoadResult);
714*47f75930SValentin Clement     rewriter.replaceOp(update, lhsLoadResult);
715*47f75930SValentin Clement     return mlir::success();
716*47f75930SValentin Clement   }
717*47f75930SValentin Clement };
718*47f75930SValentin Clement 
719*47f75930SValentin Clement class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
720*47f75930SValentin Clement public:
721*47f75930SValentin Clement   explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
722*47f75930SValentin Clement                                  const ArrayCopyAnalysis &a,
723*47f75930SValentin Clement                                  const OperationUseMapT &m)
724*47f75930SValentin Clement       : ArrayUpdateConversionBase{ctx, a, m} {}
725*47f75930SValentin Clement 
726*47f75930SValentin Clement   mlir::LogicalResult
727*47f75930SValentin Clement   matchAndRewrite(ArrayModifyOp modify,
728*47f75930SValentin Clement                   mlir::PatternRewriter &rewriter) const override {
729*47f75930SValentin Clement     auto loc = modify.getLoc();
730*47f75930SValentin Clement     auto assignElement = [](mlir::Value) {
731*47f75930SValentin Clement       // Assignment already materialized by lowering using lhs element address.
732*47f75930SValentin Clement     };
733*47f75930SValentin Clement     auto lhsEltRefType = modify.getResult(0).getType();
734*47f75930SValentin Clement     auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
735*47f75930SValentin Clement         loc, rewriter, modify, assignElement, lhsEltRefType);
736*47f75930SValentin Clement     modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
737*47f75930SValentin Clement     rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
738*47f75930SValentin Clement     return mlir::success();
739*47f75930SValentin Clement   }
740*47f75930SValentin Clement };
741*47f75930SValentin Clement 
742*47f75930SValentin Clement class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
743*47f75930SValentin Clement public:
744*47f75930SValentin Clement   explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
745*47f75930SValentin Clement                                 const OperationUseMapT &m)
746*47f75930SValentin Clement       : OpRewritePattern{ctx}, useMap{m} {}
747*47f75930SValentin Clement 
748*47f75930SValentin Clement   mlir::LogicalResult
749*47f75930SValentin Clement   matchAndRewrite(ArrayFetchOp fetch,
750*47f75930SValentin Clement                   mlir::PatternRewriter &rewriter) const override {
751*47f75930SValentin Clement     auto *op = fetch.getOperation();
752*47f75930SValentin Clement     rewriter.setInsertionPoint(op);
753*47f75930SValentin Clement     auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
754*47f75930SValentin Clement     auto loc = fetch.getLoc();
755*47f75930SValentin Clement     mlir::Value coor =
756*47f75930SValentin Clement         genCoorOp(rewriter, loc, getEleTy(load.getType()),
757*47f75930SValentin Clement                   toRefType(fetch.getType()), load.memref(), load.shape(),
758*47f75930SValentin Clement                   load.slice(), fetch.indices(), load.typeparams(),
759*47f75930SValentin Clement                   fetch->hasAttr(fir::factory::attrFortranArrayOffsets()));
760*47f75930SValentin Clement     rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
761*47f75930SValentin Clement     return mlir::success();
762*47f75930SValentin Clement   }
763*47f75930SValentin Clement 
764*47f75930SValentin Clement private:
765*47f75930SValentin Clement   const OperationUseMapT &useMap;
766*47f75930SValentin Clement };
767*47f75930SValentin Clement } // namespace
768*47f75930SValentin Clement 
769*47f75930SValentin Clement namespace {
770*47f75930SValentin Clement class ArrayValueCopyConverter
771*47f75930SValentin Clement     : public ArrayValueCopyBase<ArrayValueCopyConverter> {
772*47f75930SValentin Clement public:
773*47f75930SValentin Clement   void runOnFunction() override {
774*47f75930SValentin Clement     auto func = getFunction();
775*47f75930SValentin Clement     LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
776*47f75930SValentin Clement                             << func.getName() << "'\n");
777*47f75930SValentin Clement     auto *context = &getContext();
778*47f75930SValentin Clement 
779*47f75930SValentin Clement     // Perform the conflict analysis.
780*47f75930SValentin Clement     auto &analysis = getAnalysis<ArrayCopyAnalysis>();
781*47f75930SValentin Clement     const auto &useMap = analysis.getUseMap();
782*47f75930SValentin Clement 
783*47f75930SValentin Clement     // Phase 1 is performing a rewrite on the array accesses. Once all the
784*47f75930SValentin Clement     // array accesses are rewritten we can go on phase 2.
785*47f75930SValentin Clement     // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in
786*47f75930SValentin Clement     // /copy-out refers the Fortran copy-in/copy-out semantics on statements.
787*47f75930SValentin Clement     mlir::OwningRewritePatternList patterns1(context);
788*47f75930SValentin Clement     patterns1.insert<ArrayFetchConversion>(context, useMap);
789*47f75930SValentin Clement     patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
790*47f75930SValentin Clement     patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
791*47f75930SValentin Clement     mlir::ConversionTarget target(*context);
792*47f75930SValentin Clement     target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
793*47f75930SValentin Clement                            mlir::arith::ArithmeticDialect,
794*47f75930SValentin Clement                            mlir::StandardOpsDialect>();
795*47f75930SValentin Clement     target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>();
796*47f75930SValentin Clement     // Rewrite the array fetch and array update ops.
797*47f75930SValentin Clement     if (mlir::failed(
798*47f75930SValentin Clement             mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
799*47f75930SValentin Clement       mlir::emitError(mlir::UnknownLoc::get(context),
800*47f75930SValentin Clement                       "failure in array-value-copy pass, phase 1");
801*47f75930SValentin Clement       signalPassFailure();
802*47f75930SValentin Clement     }
803*47f75930SValentin Clement 
804*47f75930SValentin Clement     mlir::OwningRewritePatternList patterns2(context);
805*47f75930SValentin Clement     patterns2.insert<ArrayLoadConversion>(context);
806*47f75930SValentin Clement     patterns2.insert<ArrayMergeStoreConversion>(context);
807*47f75930SValentin Clement     target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
808*47f75930SValentin Clement     if (mlir::failed(
809*47f75930SValentin Clement             mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
810*47f75930SValentin Clement       mlir::emitError(mlir::UnknownLoc::get(context),
811*47f75930SValentin Clement                       "failure in array-value-copy pass, phase 2");
812*47f75930SValentin Clement       signalPassFailure();
813*47f75930SValentin Clement     }
814*47f75930SValentin Clement   }
815*47f75930SValentin Clement };
816*47f75930SValentin Clement } // namespace
817*47f75930SValentin Clement 
818*47f75930SValentin Clement std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
819*47f75930SValentin Clement   return std::make_unique<ArrayValueCopyConverter>();
820*47f75930SValentin Clement }
821