//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Array.h"
11 #include "flang/Optimizer/Builder/BoxValue.h"
12 #include "flang/Optimizer/Builder/FIRBuilder.h"
13 #include "flang/Optimizer/Builder/Factory.h"
14 #include "flang/Optimizer/Builder/Runtime/Derived.h"
15 #include "flang/Optimizer/Builder/Todo.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
18 #include "flang/Optimizer/Support/FIRContext.h"
19 #include "flang/Optimizer/Transforms/Passes.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24
25 #define DEBUG_TYPE "flang-array-value-copy"
26
27 using namespace fir;
28 using namespace mlir;
29
30 using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;
31
32 namespace {
33
34 /// Array copy analysis.
35 /// Perform an interference analysis between array values.
36 ///
37 /// Lowering will generate a sequence of the following form.
38 /// ```mlir
39 /// %a_1 = fir.array_load %array_1(%shape) : ...
40 /// ...
41 /// %a_j = fir.array_load %array_j(%shape) : ...
42 /// ...
43 /// %a_n = fir.array_load %array_n(%shape) : ...
44 /// ...
45 /// %v_i = fir.array_fetch %a_i, ...
46 /// %a_j1 = fir.array_update %a_j, ...
47 /// ...
48 /// fir.array_merge_store %a_j, %a_jn to %array_j : ...
49 /// ```
50 ///
51 /// The analysis is to determine if there are any conflicts. A conflict is when
52 /// one the following cases occurs.
53 ///
54 /// 1. There is an `array_update` to an array value, a_j, such that a_j was
55 /// loaded from the same array memory reference (array_j) but with a different
56 /// shape as the other array values a_i, where i != j. [Possible overlapping
57 /// arrays.]
58 ///
59 /// 2. There is either an array_fetch or array_update of a_j with a different
60 /// set of index values. [Possible loop-carried dependence.]
61 ///
62 /// If none of the array values overlap in storage and the accesses are not
63 /// loop-carried, then the arrays are conflict-free and no copies are required.
64 class ArrayCopyAnalysis {
65 public:
66 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ArrayCopyAnalysis)
67
68 using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
69 using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
70 using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
71 using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;
72
ArrayCopyAnalysis(mlir::Operation * op)73 ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }
74
getOperation() const75 mlir::Operation *getOperation() const { return operation; }
76
77 /// Return true iff the `array_merge_store` has potential conflicts.
hasPotentialConflict(mlir::Operation * op) const78 bool hasPotentialConflict(mlir::Operation *op) const {
79 LLVM_DEBUG(llvm::dbgs()
80 << "looking for a conflict on " << *op
81 << " and the set has a total of " << conflicts.size() << '\n');
82 return conflicts.contains(op);
83 }
84
85 /// Return the use map.
86 /// The use map maps array access, amend, fetch and update operations back to
87 /// the array load that is the original source of the array value.
88 /// It maps an array_load to an array_merge_store, if and only if the loaded
89 /// array value has pending modifications to be merged.
getUseMap() const90 const OperationUseMapT &getUseMap() const { return useMap; }
91
92 /// Return the set of array_access ops directly associated with array_amend
93 /// ops.
inAmendAccessSet(mlir::Operation * op) const94 bool inAmendAccessSet(mlir::Operation *op) const {
95 return amendAccesses.count(op);
96 }
97
98 /// For ArrayLoad `load`, return the transitive set of all OpOperands.
getLoadUseSet(mlir::Operation * load) const99 UseSetT getLoadUseSet(mlir::Operation *load) const {
100 assert(loadMapSets.count(load) && "analysis missed an array load?");
101 return loadMapSets.lookup(load);
102 }
103
104 void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
105 ArrayLoadOp load);
106
107 private:
108 void construct(mlir::Operation *topLevelOp);
109
110 mlir::Operation *operation; // operation that analysis ran upon
111 ConflictSetT conflicts; // set of conflicts (loads and merge stores)
112 OperationUseMapT useMap;
113 LoadMapSetsT loadMapSets;
114 // Set of array_access ops associated with array_amend ops.
115 AmendAccessSetT amendAccesses;
116 };
117 } // namespace
118
119 namespace {
120 /// Helper class to collect all array operations that produced an array value.
121 class ReachCollector {
122 public:
ReachCollector(llvm::SmallVectorImpl<mlir::Operation * > & reach,mlir::Region * loopRegion)123 ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
124 mlir::Region *loopRegion)
125 : reach{reach}, loopRegion{loopRegion} {}
126
collectArrayMentionFrom(mlir::Operation * op,mlir::ValueRange range)127 void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
128 if (range.empty()) {
129 collectArrayMentionFrom(op, mlir::Value{});
130 return;
131 }
132 for (mlir::Value v : range)
133 collectArrayMentionFrom(v);
134 }
135
136 // Collect all the array_access ops in `block`. This recursively looks into
137 // blocks in ops with regions.
138 // FIXME: This is temporarily relying on the array_amend appearing in a
139 // do_loop Region. This phase ordering assumption can be eliminated by using
140 // dominance information to find the array_access ops or by scanning the
141 // transitive closure of the amending array_access's users and the defs that
142 // reach them.
collectAccesses(llvm::SmallVector<ArrayAccessOp> & result,mlir::Block * block)143 void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
144 mlir::Block *block) {
145 for (auto &op : *block) {
146 if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
147 LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
148 result.push_back(access);
149 continue;
150 }
151 for (auto ®ion : op.getRegions())
152 for (auto &bb : region.getBlocks())
153 collectAccesses(result, &bb);
154 }
155 }
156
collectArrayMentionFrom(mlir::Operation * op,mlir::Value val)157 void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
158 // `val` is defined by an Op, process the defining Op.
159 // If `val` is defined by a region containing Op, we want to drill down
160 // and through that Op's region(s).
161 LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
162 auto popFn = [&](auto rop) {
163 assert(val && "op must have a result value");
164 auto resNum = val.cast<mlir::OpResult>().getResultNumber();
165 llvm::SmallVector<mlir::Value> results;
166 rop.resultToSourceOps(results, resNum);
167 for (auto u : results)
168 collectArrayMentionFrom(u);
169 };
170 if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
171 popFn(rop);
172 return;
173 }
174 if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
175 popFn(rop);
176 return;
177 }
178 if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
179 popFn(rop);
180 return;
181 }
182 if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
183 for (auto *user : box.getMemref().getUsers())
184 if (user != op)
185 collectArrayMentionFrom(user, user->getResults());
186 return;
187 }
188 if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
189 if (opIsInsideLoops(mergeStore))
190 collectArrayMentionFrom(mergeStore.getSequence());
191 return;
192 }
193
194 if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
195 // Look for any stores inside the loops, and collect an array operation
196 // that produced the value being stored to it.
197 for (auto *user : op->getUsers())
198 if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
199 if (opIsInsideLoops(store))
200 collectArrayMentionFrom(store.getValue());
201 return;
202 }
203
204 // Scan the uses of amend's memref
205 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
206 reach.push_back(op);
207 llvm::SmallVector<ArrayAccessOp> accesses;
208 collectAccesses(accesses, op->getBlock());
209 for (auto access : accesses)
210 collectArrayMentionFrom(access.getResult());
211 }
212
213 // Otherwise, Op does not contain a region so just chase its operands.
214 if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
215 ArrayFetchOp>(op)) {
216 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
217 reach.push_back(op);
218 }
219
220 // Include all array_access ops using an array_load.
221 if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
222 for (auto *user : arrLd.getResult().getUsers())
223 if (mlir::isa<ArrayAccessOp>(user)) {
224 LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
225 reach.push_back(user);
226 }
227
228 // Array modify assignment is performed on the result. So the analysis must
229 // look at the what is done with the result.
230 if (mlir::isa<ArrayModifyOp>(op))
231 for (auto *user : op->getResult(0).getUsers())
232 followUsers(user);
233
234 if (mlir::isa<fir::CallOp>(op)) {
235 LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
236 reach.push_back(op);
237 }
238
239 for (auto u : op->getOperands())
240 collectArrayMentionFrom(u);
241 }
242
collectArrayMentionFrom(mlir::BlockArgument ba)243 void collectArrayMentionFrom(mlir::BlockArgument ba) {
244 auto *parent = ba.getOwner()->getParentOp();
245 // If inside an Op holding a region, the block argument corresponds to an
246 // argument passed to the containing Op.
247 auto popFn = [&](auto rop) {
248 collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
249 };
250 if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
251 popFn(rop);
252 return;
253 }
254 if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
255 popFn(rop);
256 return;
257 }
258 // Otherwise, a block argument is provided via the pred blocks.
259 for (auto *pred : ba.getOwner()->getPredecessors()) {
260 auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
261 collectArrayMentionFrom(u);
262 }
263 }
264
265 // Recursively trace operands to find all array operations relating to the
266 // values merged.
collectArrayMentionFrom(mlir::Value val)267 void collectArrayMentionFrom(mlir::Value val) {
268 if (!val || visited.contains(val))
269 return;
270 visited.insert(val);
271
272 // Process a block argument.
273 if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
274 collectArrayMentionFrom(ba);
275 return;
276 }
277
278 // Process an Op.
279 if (auto *op = val.getDefiningOp()) {
280 collectArrayMentionFrom(op, val);
281 return;
282 }
283
284 emitFatalError(val.getLoc(), "unhandled value");
285 }
286
287 /// Return all ops that produce the array value that is stored into the
288 /// `array_merge_store`.
reachingValues(llvm::SmallVectorImpl<mlir::Operation * > & reach,mlir::Value seq)289 static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
290 mlir::Value seq) {
291 reach.clear();
292 mlir::Region *loopRegion = nullptr;
293 if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
294 loopRegion = &doLoop->getRegion(0);
295 ReachCollector collector(reach, loopRegion);
296 collector.collectArrayMentionFrom(seq);
297 }
298
299 private:
300 /// Is \op inside the loop nest region ?
301 /// FIXME: replace this structural dependence with graph properties.
opIsInsideLoops(mlir::Operation * op) const302 bool opIsInsideLoops(mlir::Operation *op) const {
303 auto *region = op->getParentRegion();
304 while (region) {
305 if (region == loopRegion)
306 return true;
307 region = region->getParentRegion();
308 }
309 return false;
310 }
311
312 /// Recursively trace the use of an operation results, calling
313 /// collectArrayMentionFrom on the direct and indirect user operands.
followUsers(mlir::Operation * op)314 void followUsers(mlir::Operation *op) {
315 for (auto userOperand : op->getOperands())
316 collectArrayMentionFrom(userOperand);
317 // Go through potential converts/coordinate_op.
318 for (auto indirectUser : op->getUsers())
319 followUsers(indirectUser);
320 }
321
322 llvm::SmallVectorImpl<mlir::Operation *> &reach;
323 llvm::SmallPtrSet<mlir::Value, 16> visited;
324 /// Region of the loops nest that produced the array value.
325 mlir::Region *loopRegion;
326 };
327 } // namespace
328
329 /// Find all the array operations that access the array value that is loaded by
330 /// the array load operation, `load`.
arrayMentions(llvm::SmallVectorImpl<mlir::Operation * > & mentions,ArrayLoadOp load)331 void ArrayCopyAnalysis::arrayMentions(
332 llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
333 mentions.clear();
334 auto lmIter = loadMapSets.find(load);
335 if (lmIter != loadMapSets.end()) {
336 for (auto *opnd : lmIter->second) {
337 auto *owner = opnd->getOwner();
338 if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
339 ArrayModifyOp>(owner))
340 mentions.push_back(owner);
341 }
342 return;
343 }
344
345 UseSetT visited;
346 llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]
347
348 auto appendToQueue = [&](mlir::Value val) {
349 for (auto &use : val.getUses())
350 if (!visited.count(&use)) {
351 visited.insert(&use);
352 queue.push_back(&use);
353 }
354 };
355
356 // Build the set of uses of `original`.
357 // let USES = { uses of original fir.load }
358 appendToQueue(load);
359
360 // Process the worklist until done.
361 while (!queue.empty()) {
362 mlir::OpOperand *operand = queue.pop_back_val();
363 mlir::Operation *owner = operand->getOwner();
364 if (!owner)
365 continue;
366 auto structuredLoop = [&](auto ro) {
367 if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
368 int64_t arg = blockArg.getArgNumber();
369 mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
370 appendToQueue(output);
371 appendToQueue(blockArg);
372 }
373 };
374 // TODO: this need to be updated to use the control-flow interface.
375 auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
376 if (operands.empty())
377 return;
378
379 // Check if this operand is within the range.
380 unsigned operandIndex = operand->getOperandNumber();
381 unsigned operandsStart = operands.getBeginOperandIndex();
382 if (operandIndex < operandsStart ||
383 operandIndex >= (operandsStart + operands.size()))
384 return;
385
386 // Index the successor.
387 unsigned argIndex = operandIndex - operandsStart;
388 appendToQueue(dest->getArgument(argIndex));
389 };
390 // Thread uses into structured loop bodies and return value uses.
391 if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
392 structuredLoop(ro);
393 } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
394 structuredLoop(ro);
395 } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
396 // Thread any uses of fir.if that return the marked array value.
397 mlir::Operation *parent = rs->getParentRegion()->getParentOp();
398 if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
399 appendToQueue(ifOp.getResult(operand->getOperandNumber()));
400 } else if (mlir::isa<ArrayFetchOp>(owner)) {
401 // Keep track of array value fetches.
402 LLVM_DEBUG(llvm::dbgs()
403 << "add fetch {" << *owner << "} to array value set\n");
404 mentions.push_back(owner);
405 } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
406 // Keep track of array value updates and thread the return value uses.
407 LLVM_DEBUG(llvm::dbgs()
408 << "add update {" << *owner << "} to array value set\n");
409 mentions.push_back(owner);
410 appendToQueue(update.getResult());
411 } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
412 // Keep track of array value modification and thread the return value
413 // uses.
414 LLVM_DEBUG(llvm::dbgs()
415 << "add modify {" << *owner << "} to array value set\n");
416 mentions.push_back(owner);
417 appendToQueue(update.getResult(1));
418 } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
419 mentions.push_back(owner);
420 } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
421 mentions.push_back(owner);
422 appendToQueue(amend.getResult());
423 } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
424 branchOp(br.getDest(), br.getDestOperands());
425 } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
426 branchOp(br.getTrueDest(), br.getTrueOperands());
427 branchOp(br.getFalseDest(), br.getFalseOperands());
428 } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
429 // do nothing
430 } else {
431 llvm::report_fatal_error("array value reached unexpected op");
432 }
433 }
434 loadMapSets.insert({load, visited});
435 }
436
hasPointerType(mlir::Type type)437 static bool hasPointerType(mlir::Type type) {
438 if (auto boxTy = type.dyn_cast<BoxType>())
439 type = boxTy.getEleTy();
440 return type.isa<fir::PointerType>();
441 }
442
443 // This is a NF performance hack. It makes a simple test that the slices of the
444 // load, \p ld, and the merge store, \p st, are trivially mutually exclusive.
mutuallyExclusiveSliceRange(ArrayLoadOp ld,ArrayMergeStoreOp st)445 static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
446 // If the same array_load, then no further testing is warranted.
447 if (ld.getResult() == st.getOriginal())
448 return false;
449
450 auto getSliceOp = [](mlir::Value val) -> SliceOp {
451 if (!val)
452 return {};
453 auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
454 if (!sliceOp)
455 return {};
456 return sliceOp;
457 };
458
459 auto ldSlice = getSliceOp(ld.getSlice());
460 auto stSlice = getSliceOp(st.getSlice());
461 if (!ldSlice || !stSlice)
462 return false;
463
464 // Resign on subobject slices.
465 if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
466 !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
467 return false;
468
469 // Crudely test that the two slices do not overlap by looking for the
470 // following general condition. If the slices look like (i:j) and (j+1:k) then
471 // these ranges do not overlap. The addend must be a constant.
472 auto ldTriples = ldSlice.getTriples();
473 auto stTriples = stSlice.getTriples();
474 const auto size = ldTriples.size();
475 if (size != stTriples.size())
476 return false;
477
478 auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
479 auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
480 auto *op = v.getDefiningOp();
481 while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
482 op = conv.getValue().getDefiningOp();
483 return op;
484 };
485
486 auto isPositiveConstant = [](mlir::Value v) -> bool {
487 if (auto conOp =
488 mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
489 if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
490 return iattr.getInt() > 0;
491 return false;
492 };
493
494 auto *op1 = removeConvert(v1);
495 auto *op2 = removeConvert(v2);
496 if (!op1 || !op2)
497 return false;
498 if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
499 if ((addi.getLhs().getDefiningOp() == op1 &&
500 isPositiveConstant(addi.getRhs())) ||
501 (addi.getRhs().getDefiningOp() == op1 &&
502 isPositiveConstant(addi.getLhs())))
503 return true;
504 if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
505 if (subi.getLhs().getDefiningOp() == op2 &&
506 isPositiveConstant(subi.getRhs()))
507 return true;
508 return false;
509 };
510
511 for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
512 // If both are loop invariant, skip to the next triple.
513 if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
514 mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
515 // Unless either is a vector index, then be conservative.
516 if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
517 mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
518 return false;
519 continue;
520 }
521 // If identical, skip to the next triple.
522 if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
523 ldTriples[i + 2] == stTriples[i + 2])
524 continue;
525 // If ubound and lbound are the same with a constant offset, skip to the
526 // next triple.
527 if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
528 displacedByConstant(stTriples[i + 1], ldTriples[i]))
529 continue;
530 return false;
531 }
532 LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
533 << " and " << st << ", which is not a conflict\n");
534 return true;
535 }
536
537 /// Is there a conflict between the array value that was updated and to be
538 /// stored to `st` and the set of arrays loaded (`reach`) and used to compute
539 /// the updated value?
conflictOnLoad(llvm::ArrayRef<mlir::Operation * > reach,ArrayMergeStoreOp st)540 static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
541 ArrayMergeStoreOp st) {
542 mlir::Value load;
543 mlir::Value addr = st.getMemref();
544 const bool storeHasPointerType = hasPointerType(addr.getType());
545 for (auto *op : reach)
546 if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
547 mlir::Type ldTy = ld.getMemref().getType();
548 if (ld.getMemref() == addr) {
549 if (mutuallyExclusiveSliceRange(ld, st))
550 continue;
551 if (ld.getResult() != st.getOriginal())
552 return true;
553 if (load) {
554 // TODO: extend this to allow checking if the first `load` and this
555 // `ld` are mutually exclusive accesses but not identical.
556 return true;
557 }
558 load = ld;
559 } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
560 // TODO: Use target attribute to restrict this case further.
561 // TODO: Check if types can also allow ruling out some cases. For now,
562 // the fact that equivalences is using pointer attribute to enforce
563 // aliasing is preventing any attempt to do so, and in general, it may
564 // be wrong to use this if any of the types is a complex or a derived
565 // for which it is possible to create a pointer to a part with a
566 // different type than the whole, although this deserve some more
567 // investigation because existing compiler behavior seem to diverge
568 // here.
569 return true;
570 }
571 }
572 return false;
573 }
574
575 /// Is there an access vector conflict on the array being merged into? If the
576 /// access vectors diverge, then assume that there are potentially overlapping
577 /// loop-carried references.
conflictOnMerge(llvm::ArrayRef<mlir::Operation * > mentions)578 static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
579 if (mentions.size() < 2)
580 return false;
581 llvm::SmallVector<mlir::Value> indices;
582 LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << mentions.size()
583 << " mentions on the list\n");
584 bool valSeen = false;
585 bool refSeen = false;
586 for (auto *op : mentions) {
587 llvm::SmallVector<mlir::Value> compareVector;
588 if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
589 valSeen = true;
590 if (indices.empty()) {
591 indices = u.getIndices();
592 continue;
593 }
594 compareVector = u.getIndices();
595 } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
596 valSeen = true;
597 if (indices.empty()) {
598 indices = f.getIndices();
599 continue;
600 }
601 compareVector = f.getIndices();
602 } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
603 valSeen = true;
604 if (indices.empty()) {
605 indices = f.getIndices();
606 continue;
607 }
608 compareVector = f.getIndices();
609 } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
610 refSeen = true;
611 if (indices.empty()) {
612 indices = f.getIndices();
613 continue;
614 }
615 compareVector = f.getIndices();
616 } else if (mlir::isa<ArrayAmendOp>(op)) {
617 refSeen = true;
618 continue;
619 } else {
620 mlir::emitError(op->getLoc(), "unexpected operation in analysis");
621 }
622 if (compareVector.size() != indices.size() ||
623 llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
624 return std::get<0>(pair) != std::get<1>(pair);
625 }))
626 return true;
627 LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
628 }
629 return valSeen && refSeen;
630 }
631
632 /// With element-by-reference semantics, an amended array with more than once
633 /// access to the same loaded array are conservatively considered a conflict.
634 /// Note: the array copy can still be eliminated in subsequent optimizations.
conflictOnReference(llvm::ArrayRef<mlir::Operation * > mentions)635 static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
636 LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
637 << '\n');
638 if (mentions.size() < 3)
639 return false;
640 unsigned amendCount = 0;
641 unsigned accessCount = 0;
642 for (auto *op : mentions) {
643 if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
644 LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
645 return true;
646 }
647 if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
648 LLVM_DEBUG(llvm::dbgs()
649 << "conflict: multiple accesses of array value\n");
650 return true;
651 }
652 if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
653 LLVM_DEBUG(llvm::dbgs()
654 << "conflict: array value has both uses by-value and uses "
655 "by-reference. conservative assumption.\n");
656 return true;
657 }
658 }
659 return false;
660 }
661
662 static mlir::Operation *
amendingAccess(llvm::ArrayRef<mlir::Operation * > mentions)663 amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
664 for (auto *op : mentions)
665 if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
666 return amend.getMemref().getDefiningOp();
667 return {};
668 }
669
670 // Are any conflicts present? The conflicts detected here are described above.
conflictDetected(llvm::ArrayRef<mlir::Operation * > reach,llvm::ArrayRef<mlir::Operation * > mentions,ArrayMergeStoreOp st)671 static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
672 llvm::ArrayRef<mlir::Operation *> mentions,
673 ArrayMergeStoreOp st) {
674 return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
675 }
676
677 // Assume that any call to a function that uses host-associations will be
678 // modifying the output array.
679 static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation * > reaches)680 conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
681 return llvm::any_of(reaches, [](mlir::Operation *op) {
682 if (auto call = mlir::dyn_cast<fir::CallOp>(op))
683 if (auto callee =
684 call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
685 auto module = op->getParentOfType<mlir::ModuleOp>();
686 return hasHostAssociationArgument(
687 module.lookupSymbol<mlir::func::FuncOp>(callee));
688 }
689 return false;
690 });
691 }
692
693 /// Constructor of the array copy analysis.
694 /// This performs the analysis and saves the intermediate results.
construct(mlir::Operation * topLevelOp)695 void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
696 topLevelOp->walk([&](Operation *op) {
697 if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
698 llvm::SmallVector<mlir::Operation *> values;
699 ReachCollector::reachingValues(values, st.getSequence());
700 bool callConflict = conservativeCallConflict(values);
701 llvm::SmallVector<mlir::Operation *> mentions;
702 arrayMentions(mentions,
703 mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
704 bool conflict = conflictDetected(values, mentions, st);
705 bool refConflict = conflictOnReference(mentions);
706 if (callConflict || conflict || refConflict) {
707 LLVM_DEBUG(llvm::dbgs()
708 << "CONFLICT: copies required for " << st << '\n'
709 << " adding conflicts on: " << *op << " and "
710 << st.getOriginal() << '\n');
711 conflicts.insert(op);
712 conflicts.insert(st.getOriginal().getDefiningOp());
713 if (auto *access = amendingAccess(mentions))
714 amendAccesses.insert(access);
715 }
716 auto *ld = st.getOriginal().getDefiningOp();
717 LLVM_DEBUG(llvm::dbgs()
718 << "map: adding {" << *ld << " -> " << st << "}\n");
719 useMap.insert({ld, op});
720 } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
721 llvm::SmallVector<mlir::Operation *> mentions;
722 arrayMentions(mentions, load);
723 LLVM_DEBUG(llvm::dbgs() << "process load: " << load
724 << ", mentions: " << mentions.size() << '\n');
725 for (auto *acc : mentions) {
726 LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
727 if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
728 ArrayModifyOp>(acc)) {
729 if (useMap.count(acc)) {
730 mlir::emitError(
731 load.getLoc(),
732 "The parallel semantics of multiple array_merge_stores per "
733 "array_load are not supported.");
734 continue;
735 }
736 LLVM_DEBUG(llvm::dbgs()
737 << "map: adding {" << *acc << "} -> {" << load << "}\n");
738 useMap.insert({acc, op});
739 }
740 }
741 }
742 });
743 }
744
745 //===----------------------------------------------------------------------===//
746 // Conversions for converting out of array value form.
747 //===----------------------------------------------------------------------===//
748
749 namespace {
750 class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
751 public:
752 using OpRewritePattern::OpRewritePattern;
753
754 mlir::LogicalResult
matchAndRewrite(ArrayLoadOp load,mlir::PatternRewriter & rewriter) const755 matchAndRewrite(ArrayLoadOp load,
756 mlir::PatternRewriter &rewriter) const override {
757 LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
758 rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
759 return mlir::success();
760 }
761 };
762
763 class ArrayMergeStoreConversion
764 : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
765 public:
766 using OpRewritePattern::OpRewritePattern;
767
768 mlir::LogicalResult
matchAndRewrite(ArrayMergeStoreOp store,mlir::PatternRewriter & rewriter) const769 matchAndRewrite(ArrayMergeStoreOp store,
770 mlir::PatternRewriter &rewriter) const override {
771 LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
772 rewriter.eraseOp(store);
773 return mlir::success();
774 }
775 };
776 } // namespace
777
getEleTy(mlir::Type ty)778 static mlir::Type getEleTy(mlir::Type ty) {
779 auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
780 // FIXME: keep ptr/heap/ref information.
781 return ReferenceType::get(eleTy);
782 }
783
// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
// Returns true iff the last extent had to be recovered from the slice triple
// (assumed-size array), in which case the copy-in/copy-out must apply the
// slice itself.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        // The last (lb, ub, stride) triple of the slice describes the final
        // dimension; compute its extent from those three values.
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    // Overwrite the unknown extent of the last dimension with the computed
    // value (or 1 when no slice information is available).
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}
824
825 /// Place the extents of the array load, \p arrLoad, into \p result and
826 /// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
827 /// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
829 /// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
830 /// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
831 /// if slicing of the output array is to be done in the copy-in/copy-out rather
832 /// than in the elemental computation step.
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    // Boxed case: read the extents out of the box descriptor at runtime.
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      // Result #1 of fir.box_dims is the extent of this dimension.
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      // No shift operand on the load: build a plain fir.shape from the box
      // extents.
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    // A fir.shift is present: interleave its lower bounds with the box
    // extents into a fir.shape_shift.
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  // Unboxed case: extract the extents from the existing shape operand.
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}
868
toRefType(mlir::Type ty)869 static mlir::Type toRefType(mlir::Type ty) {
870 if (fir::isa_ref_type(ty))
871 return ty;
872 return fir::ReferenceType::get(ty);
873 }
874
875 static llvm::SmallVector<mlir::Value>
getTypeParamsIfRawData(mlir::Location loc,FirOpBuilder & builder,ArrayLoadOp arrLoad,mlir::Type ty)876 getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
877 ArrayLoadOp arrLoad, mlir::Type ty) {
878 if (ty.isa<BoxType>())
879 return {};
880 return fir::factory::getTypeParams(loc, builder, arrLoad);
881 }
882
/// Generate the address of an array element for \p load. A fir.array_coor
/// addresses the array element; if the index list extends past the array
/// rank (component path), a trailing fir.coordinate_of addresses the
/// component within the element.
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
                             mlir::Location loc, mlir::Type eleTy,
                             mlir::Type resTy, mlir::Value alloc,
                             mlir::Value shape, mlir::Value slice,
                             mlir::ValueRange indices, ArrayLoadOp load,
                             bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    // Rebase the indices onto the array's origin (lower bounds).
    originated = factory::originateIndices(loc, rewriter, alloc.getType(),
                                           shape, indices);
  auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<SequenceType>());
  const auto dimension = seqTy.cast<SequenceType>().getDimension();
  auto module = load->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
  // The first `dimension` indices select the array element.
  mlir::Value result = rewriter.create<ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  // Any remaining indices select a component inside the element.
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}
912
/// Compute the CHARACTER length of the elements addressed by \p load.
static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      // Length in characters = element size in bytes / character byte width.
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}
/// Generate a shallow array copy. This is used for both copy-in and copy-out.
/// \p CopyIn selects the direction: true copies from \p src (original array)
/// into \p dst (temporary); false copies the temporary back out.
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    // Loops iterate [0, extent-1]; indices are rebased via originateIndices.
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
  fir::KindMapping kindMap = getKindMapping(module);
  FirOpBuilder builder(rewriter, kindMap);
  // When the copy must honor the slice (assumed-size case), the slice is
  // applied on the original array side of the copy only.
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      getTypeParamsIfRawData(loc, builder, arrLoad, dst.getType()));
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  // Copy from (to) object to (from) temp copy of same object.
  if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
    // CHARACTER elements carry a length; use a character-aware assignment.
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}
990
/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
        // Dynamic-length CHARACTER with no explicit type parameters: the
        // memref must be a box, so the length is read from the box.
        assert(load.getMemref().getType().isa<BoxType>());
        auto module = load->getParentOfType<mlir::ModuleOp>();
        fir::KindMapping kindMap = getKindMapping(module);
        FirOpBuilder builder(rewriter, kindMap);
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    // Statically sized elements need no type parameters.
    return {};
  }
  return load.getTypeparams();
}
1013
1014 static llvm::SmallVector<mlir::Value>
findNonconstantExtents(mlir::Type memrefTy,llvm::ArrayRef<mlir::Value> extents)1015 findNonconstantExtents(mlir::Type memrefTy,
1016 llvm::ArrayRef<mlir::Value> extents) {
1017 llvm::SmallVector<mlir::Value> nce;
1018 auto arrTy = unwrapPassByRefType(memrefTy);
1019 auto seqTy = arrTy.cast<SequenceType>();
1020 for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
1021 if (s == SequenceType::getUnknownExtent())
1022 nce.emplace_back(x);
1023 if (extents.size() > seqTy.getShape().size())
1024 for (auto x : extents.drop_front(seqTy.getShape().size()))
1025 nce.emplace_back(x);
1026 return nce;
1027 }
1028
/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  // Only the dynamic extents are passed to fir.allocmem.
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(baseType), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    fir::KindMapping kindMap = getKindMapping(module);
    FirOpBuilder builder(rewriter, kindMap);
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    // The lambda captures by value so it can run at a later insertion point.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      fir::KindMapping kindMap = getKindMapping(module);
      FirOpBuilder builder(r, kindMap);
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  // Simple case: the clean-up only frees the heap allocation.
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}
1074
1075 namespace {
1076 /// Conversion of fir.array_update and fir.array_modify Ops.
1077 /// If there is a conflict for the update, then we need to perform a
1078 /// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be to a cloned copy due to a potential
  /// conflict. Uses copy-in/copy-out semantics and not copy/swap.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    // Map the access to its array_load; the same map later takes the load to
    // its array_merge_store.
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
        copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    // Copy out (just before the merge store).
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);

      genArrayCopy</*copyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      // Address the element being updated inside the temporary.
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load,
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*copyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (a possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor =
        genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
                  load.getShape(), load.getSlice(), update.getIndices(), load,
                  update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  /// Result of the conflict analysis for the enclosing function.
  const ArrayCopyAnalysis &analysis;
  /// Maps array ops to their array_load, and array_loads to their
  /// array_merge_store.
  const OperationUseMapT &useMap;
};
1194
/// Conversion of fir.array_update: lower the update to a store through an
/// element address, inserting copy-in/copy-out when a conflict was found.
class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    // Store the merged value into the LHS element address.
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    // The array_update result is replaced by the array_load result.
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};
1222
/// Conversion of fir.array_modify: the element assignment was already
/// materialized during lowering, so only the LHS element address (with
/// copy-in/copy-out if a conflict was found) needs to be produced.
class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // Assignment already materialized by lowering using lhs element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    // array_modify yields both the element address and the array value.
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};
1245
/// Conversion of fir.array_fetch: rewrite the fetch as an element address
/// computation followed by a fir.load, unless the fetched type is itself a
/// reference (in which case the address is the result).
class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
        load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  /// Maps each array_fetch to its defining array_load.
  const OperationUseMapT &useMap;
};
1273
/// An array_access op is like an array_fetch op, except that it does not imply
1275 /// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    // No conflict: address the element in place.
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(
        rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
        load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
        load, access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};
1306
1307 /// An array_amend op is a marker to record which array access is being used to
1308 /// update an array value. After this pass runs, an array_amend has no
1309 /// semantics. We rewrite these to undefined values here to remove them while
1310 /// preserving SSA form.
1311 class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
1312 public:
ArrayAmendConversion(mlir::MLIRContext * ctx)1313 explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
1314 : OpRewritePattern{ctx} {}
1315
1316 mlir::LogicalResult
matchAndRewrite(ArrayAmendOp amend,mlir::PatternRewriter & rewriter) const1317 matchAndRewrite(ArrayAmendOp amend,
1318 mlir::PatternRewriter &rewriter) const override {
1319 auto *op = amend.getOperation();
1320 rewriter.setInsertionPoint(op);
1321 auto loc = amend.getLoc();
1322 auto undef = rewriter.create<UndefOp>(loc, amend.getType());
1323 rewriter.replaceOp(amend, undef.getResult());
1324 return mlir::success();
1325 }
1326 };
1327
/// Pass driver: run the conflict analysis, then rewrite array ops in two
/// phases so that no fir.array_* value ops remain.
class ArrayValueCopyConverter
    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    // Phase 1: rewrite ops that reference array values (fetch, update,
    // modify, access, amend) while their loads/stores are still present.
    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    // Phase 2: the array loads and merge stores are now dead; erase them.
    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
1373 } // namespace
1374
/// Create an instance of the array-value-copy pass.
std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}
1378