//===-- ArrayValueCopy.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/Array.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Factory.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "flang-array-value-copy"

using namespace fir;
using namespace mlir;

using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>;

namespace {

/// Array copy analysis.
/// Perform an interference analysis between array values.
///
/// Lowering will generate a sequence of the following form.
/// ```mlir
///   %a_1 = fir.array_load %array_1(%shape) : ...
///   ...
///   %a_j = fir.array_load %array_j(%shape) : ...
///   ...
///   %a_n = fir.array_load %array_n(%shape) : ...
///     ...
///     %v_i = fir.array_fetch %a_i, ...
///     %a_j1 = fir.array_update %a_j, ...
///     ...
///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape than the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
class ArrayCopyAnalysis {
public:
  using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>;
  using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>;
  using LoadMapSetsT = llvm::DenseMap<mlir::Operation *, UseSetT>;
  using AmendAccessSetT = llvm::SmallPtrSet<mlir::Operation *, 4>;

  ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); }

  mlir::Operation *getOperation() const { return operation; }

  /// Return true iff the `array_merge_store` has potential conflicts.
  bool hasPotentialConflict(mlir::Operation *op) const {
    LLVM_DEBUG(llvm::dbgs()
               << "looking for a conflict on " << *op
               << " and the set has a total of " << conflicts.size() << '\n');
    return conflicts.contains(op);
  }

  /// Return the use map.
  /// The use map maps array access, amend, fetch and update operations back to
  /// the array load that is the original source of the array value.
  /// It maps an array_load to an array_merge_store, if and only if the loaded
  /// array value has pending modifications to be merged.
  const OperationUseMapT &getUseMap() const { return useMap; }

  /// Return the set of array_access ops directly associated with array_amend
  /// ops.
  bool inAmendAccessSet(mlir::Operation *op) const {
    return amendAccesses.count(op);
  }

  /// For ArrayLoad `load`, return the transitive set of all OpOperands.
  UseSetT getLoadUseSet(mlir::Operation *load) const {
    assert(loadMapSets.count(load) && "analysis missed an array load?");
    return loadMapSets.lookup(load);
  }

  void arrayMentions(llvm::SmallVectorImpl<mlir::Operation *> &mentions,
                     ArrayLoadOp load);

private:
  void construct(mlir::Operation *topLevelOp);

  mlir::Operation *operation; // operation that analysis ran upon
  ConflictSetT conflicts;     // set of conflicts (loads and merge stores)
  OperationUseMapT useMap;
  LoadMapSetsT loadMapSets;
  // Set of array_access ops associated with array_amend ops.
  AmendAccessSetT amendAccesses;
};
} // namespace

namespace {
/// Helper class to collect all array operations that produced an array value.
class ReachCollector {
public:
  ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                 mlir::Region *loopRegion)
      : reach{reach}, loopRegion{loopRegion} {}

  void collectArrayMentionFrom(mlir::Operation *op, mlir::ValueRange range) {
    if (range.empty()) {
      collectArrayMentionFrom(op, mlir::Value{});
      return;
    }
    for (mlir::Value v : range)
      collectArrayMentionFrom(v);
  }

  // Collect all the array_access ops in `block`. This recursively looks into
  // blocks in ops with regions.
  // FIXME: This is temporarily relying on the array_amend appearing in a
  // do_loop Region. This phase ordering assumption can be eliminated by using
  // dominance information to find the array_access ops or by scanning the
  // transitive closure of the amending array_access's users and the defs that
  // reach them.
  void collectAccesses(llvm::SmallVector<ArrayAccessOp> &result,
                       mlir::Block *block) {
    for (auto &op : *block) {
      if (auto access = mlir::dyn_cast<ArrayAccessOp>(op)) {
        LLVM_DEBUG(llvm::dbgs() << "adding access: " << access << '\n');
        result.push_back(access);
        continue;
      }
      for (auto &region : op.getRegions())
        for (auto &bb : region.getBlocks())
          collectAccesses(result, &bb);
    }
  }

  void collectArrayMentionFrom(mlir::Operation *op, mlir::Value val) {
    // `val` is defined by an Op, process the defining Op.
    // If `val` is defined by a region containing Op, we want to drill down
    // and through that Op's region(s).
    LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n');
    auto popFn = [&](auto rop) {
      assert(val && "op must have a result value");
      auto resNum = val.cast<mlir::OpResult>().getResultNumber();
      llvm::SmallVector<mlir::Value> results;
      rop.resultToSourceOps(results, resNum);
      for (auto u : results)
        collectArrayMentionFrom(u);
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) {
      popFn(rop);
      return;
    }
    if (auto box = mlir::dyn_cast<EmboxOp>(op)) {
      for (auto *user : box.getMemref().getUsers())
        if (user != op)
          collectArrayMentionFrom(user, user->getResults());
      return;
    }
    if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
      if (opIsInsideLoops(mergeStore))
        collectArrayMentionFrom(mergeStore.getSequence());
      return;
    }

    if (mlir::isa<AllocaOp, AllocMemOp>(op)) {
      // Look for any stores inside the loops, and collect the array operations
      // that produced the values being stored.
      for (auto *user : op->getUsers())
        if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
          if (opIsInsideLoops(store))
            collectArrayMentionFrom(store.getValue());
      return;
    }

    // Scan the uses of the amend's memref.
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op)) {
      reach.push_back(op);
      llvm::SmallVector<ArrayAccessOp> accesses;
      collectAccesses(accesses, op->getBlock());
      for (auto access : accesses)
        collectArrayMentionFrom(access.getResult());
    }

    // Otherwise, the Op does not contain a region, so just chase its operands.
    if (mlir::isa<ArrayAccessOp, ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp,
                  ArrayFetchOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    // Include all array_access ops using an array_load.
    if (auto arrLd = mlir::dyn_cast<ArrayLoadOp>(op))
      for (auto *user : arrLd.getResult().getUsers())
        if (mlir::isa<ArrayAccessOp>(user)) {
          LLVM_DEBUG(llvm::dbgs() << "add " << *user << " to reachable set\n");
          reach.push_back(user);
        }

    // Array modify assignment is performed on the result. So the analysis
    // must look at what is done with the result.
    if (mlir::isa<ArrayModifyOp>(op))
      for (auto *user : op->getResult(0).getUsers())
        followUsers(user);

    if (mlir::isa<fir::CallOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
      reach.push_back(op);
    }

    for (auto u : op->getOperands())
      collectArrayMentionFrom(u);
  }

  void collectArrayMentionFrom(mlir::BlockArgument ba) {
    auto *parent = ba.getOwner()->getParentOp();
    // If inside an Op holding a region, the block argument corresponds to an
    // argument passed to the containing Op.
    auto popFn = [&](auto rop) {
      collectArrayMentionFrom(rop.blockArgToSourceOp(ba.getArgNumber()));
    };
    if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) {
      popFn(rop);
      return;
    }
    if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) {
      popFn(rop);
      return;
    }
    // Otherwise, a block argument is provided via the pred blocks.
    for (auto *pred : ba.getOwner()->getPredecessors()) {
      auto u = pred->getTerminator()->getOperand(ba.getArgNumber());
      collectArrayMentionFrom(u);
    }
  }

  // Recursively trace operands to find all array operations relating to the
  // values merged.
  void collectArrayMentionFrom(mlir::Value val) {
    if (!val || visited.contains(val))
      return;
    visited.insert(val);

    // Process a block argument.
    if (auto ba = val.dyn_cast<mlir::BlockArgument>()) {
      collectArrayMentionFrom(ba);
      return;
    }

    // Process an Op.
    if (auto *op = val.getDefiningOp()) {
      collectArrayMentionFrom(op, val);
      return;
    }

    emitFatalError(val.getLoc(), "unhandled value");
  }

  /// Return all ops that produce the array value that is stored into the
  /// `array_merge_store`.
  static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach,
                             mlir::Value seq) {
    reach.clear();
    mlir::Region *loopRegion = nullptr;
    if (auto doLoop = mlir::dyn_cast_or_null<DoLoopOp>(seq.getDefiningOp()))
      loopRegion = &doLoop->getRegion(0);
    ReachCollector collector(reach, loopRegion);
    collector.collectArrayMentionFrom(seq);
  }

private:
  /// Is \p op inside the loop nest region?
  /// FIXME: replace this structural dependence with graph properties.
  bool opIsInsideLoops(mlir::Operation *op) const {
    auto *region = op->getParentRegion();
    while (region) {
      if (region == loopRegion)
        return true;
      region = region->getParentRegion();
    }
    return false;
  }

  /// Recursively trace the uses of an operation's results, calling
  /// collectArrayMentionFrom on the direct and indirect user operands.
  void followUsers(mlir::Operation *op) {
    for (auto userOperand : op->getOperands())
      collectArrayMentionFrom(userOperand);
    // Go through potential converts/coordinate_op.
    for (auto indirectUser : op->getUsers())
      followUsers(indirectUser);
  }

  llvm::SmallVectorImpl<mlir::Operation *> &reach;
  llvm::SmallPtrSet<mlir::Value, 16> visited;
  /// Region of the loop nest that produced the array value.
  mlir::Region *loopRegion;
};
} // namespace

/// Find all the array operations that access the array value that is loaded by
/// the array load operation, `load`.
void ArrayCopyAnalysis::arrayMentions(
    llvm::SmallVectorImpl<mlir::Operation *> &mentions, ArrayLoadOp load) {
  mentions.clear();
  auto lmIter = loadMapSets.find(load);
  if (lmIter != loadMapSets.end()) {
    for (auto *opnd : lmIter->second) {
      auto *owner = opnd->getOwner();
      if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                    ArrayModifyOp>(owner))
        mentions.push_back(owner);
    }
    return;
  }

  UseSetT visited;
  llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig]

  auto appendToQueue = [&](mlir::Value val) {
    for (auto &use : val.getUses())
      if (!visited.count(&use)) {
        visited.insert(&use);
        queue.push_back(&use);
      }
  };

  // Build the set of uses of `original`.
  // let USES = { uses of original fir.load }
  appendToQueue(load);

  // Process the worklist until done.
  while (!queue.empty()) {
    mlir::OpOperand *operand = queue.pop_back_val();
    mlir::Operation *owner = operand->getOwner();
    if (!owner)
      continue;
    auto structuredLoop = [&](auto ro) {
      if (auto blockArg = ro.iterArgToBlockArg(operand->get())) {
        int64_t arg = blockArg.getArgNumber();
        mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1);
        appendToQueue(output);
        appendToQueue(blockArg);
      }
    };
    // TODO: this needs to be updated to use the control-flow interface.
    auto branchOp = [&](mlir::Block *dest, OperandRange operands) {
      if (operands.empty())
        return;

      // Check if this operand is within the range.
      unsigned operandIndex = operand->getOperandNumber();
      unsigned operandsStart = operands.getBeginOperandIndex();
      if (operandIndex < operandsStart ||
          operandIndex >= (operandsStart + operands.size()))
        return;

      // Index the successor.
      unsigned argIndex = operandIndex - operandsStart;
      appendToQueue(dest->getArgument(argIndex));
    };
    // Thread uses into structured loop bodies and return value uses.
    if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) {
      structuredLoop(ro);
    } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) {
      structuredLoop(ro);
    } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) {
      // Thread any uses of fir.if that return the marked array value.
      mlir::Operation *parent = rs->getParentRegion()->getParentOp();
      if (auto ifOp = mlir::dyn_cast<fir::IfOp>(parent))
        appendToQueue(ifOp.getResult(operand->getOperandNumber()));
    } else if (mlir::isa<ArrayFetchOp>(owner)) {
      // Keep track of array value fetches.
      LLVM_DEBUG(llvm::dbgs()
                 << "add fetch {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
    } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) {
      // Keep track of array value updates and thread the return value uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add update {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult());
    } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) {
      // Keep track of array value modifications and thread the return value
      // uses.
      LLVM_DEBUG(llvm::dbgs()
                 << "add modify {" << *owner << "} to array value set\n");
      mentions.push_back(owner);
      appendToQueue(update.getResult(1));
    } else if (auto mention = mlir::dyn_cast<ArrayAccessOp>(owner)) {
      mentions.push_back(owner);
    } else if (auto amend = mlir::dyn_cast<ArrayAmendOp>(owner)) {
      mentions.push_back(owner);
      appendToQueue(amend.getResult());
    } else if (auto br = mlir::dyn_cast<mlir::cf::BranchOp>(owner)) {
      branchOp(br.getDest(), br.getDestOperands());
    } else if (auto br = mlir::dyn_cast<mlir::cf::CondBranchOp>(owner)) {
      branchOp(br.getTrueDest(), br.getTrueOperands());
      branchOp(br.getFalseDest(), br.getFalseOperands());
    } else if (mlir::isa<ArrayMergeStoreOp>(owner)) {
      // do nothing
    } else {
      llvm::report_fatal_error("array value reached unexpected op");
    }
  }
  loadMapSets.insert({load, visited});
}

static bool hasPointerType(mlir::Type type) {
  if (auto boxTy = type.dyn_cast<BoxType>())
    type = boxTy.getEleTy();
  return type.isa<fir::PointerType>();
}

// This is an NF performance hack. It performs a simple test to determine
// whether the slices of the load, \p ld, and the merge store, \p st, are
// trivially mutually exclusive.
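//
// A hedged illustration at the Fortran level (symbolic bounds; this is the
// pattern the test below recognizes, not a guarantee about lowering): for an
// assignment such as `a(j+1:k) = a(i:j)`, the store slice's lower bound is
// the SSA value `j + 1`, an arith.addi of the load slice's upper bound `j`
// and the positive constant 1, so the two ranges are provably disjoint and
// no temporary is needed. With unrelated bounds, e.g. `a(m:n) = a(i:j)`, the
// displacement test fails and a potential overlap is conservatively assumed.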
static bool mutuallyExclusiveSliceRange(ArrayLoadOp ld, ArrayMergeStoreOp st) {
  // If the same array_load, then no further testing is warranted.
  if (ld.getResult() == st.getOriginal())
    return false;

  auto getSliceOp = [](mlir::Value val) -> SliceOp {
    if (!val)
      return {};
    auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(val.getDefiningOp());
    if (!sliceOp)
      return {};
    return sliceOp;
  };

  auto ldSlice = getSliceOp(ld.getSlice());
  auto stSlice = getSliceOp(st.getSlice());
  if (!ldSlice || !stSlice)
    return false;

  // Resign on subobject slices.
  if (!ldSlice.getFields().empty() || !stSlice.getFields().empty() ||
      !ldSlice.getSubstr().empty() || !stSlice.getSubstr().empty())
    return false;

  // Crudely test that the two slices do not overlap by looking for the
  // following general condition. If the slices look like (i:j) and (j+1:k),
  // then these ranges do not overlap. The addend must be a constant.
  auto ldTriples = ldSlice.getTriples();
  auto stTriples = stSlice.getTriples();
  const auto size = ldTriples.size();
  if (size != stTriples.size())
    return false;

  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  for (std::remove_const_t<decltype(size)> i = 0; i < size; i += 3) {
    // If both are loop invariant, skip to the next triple.
    if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i + 1].getDefiningOp()) &&
        mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i + 1].getDefiningOp())) {
      // If either is a vector index, be conservative.
      if (mlir::isa_and_nonnull<fir::UndefOp>(ldTriples[i].getDefiningOp()) ||
          mlir::isa_and_nonnull<fir::UndefOp>(stTriples[i].getDefiningOp()))
        return false;
      continue;
    }
    // If identical, skip to the next triple.
    if (ldTriples[i] == stTriples[i] && ldTriples[i + 1] == stTriples[i + 1] &&
        ldTriples[i + 2] == stTriples[i + 2])
      continue;
    // If the ubound and lbound are the same with a constant offset, skip to
    // the next triple.
    if (displacedByConstant(ldTriples[i + 1], stTriples[i]) ||
        displacedByConstant(stTriples[i + 1], ldTriples[i]))
      continue;
    return false;
  }
  LLVM_DEBUG(llvm::dbgs() << "detected non-overlapping slice ranges on " << ld
                          << " and " << st << ", which is not a conflict\n");
  return true;
}

/// Is there a conflict between the array value that was updated and is to be
/// stored to `st` and the set of arrays loaded (`reach`) and used to compute
/// the updated value?
static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach,
                           ArrayMergeStoreOp st) {
  mlir::Value load;
  mlir::Value addr = st.getMemref();
  const bool storeHasPointerType = hasPointerType(addr.getType());
  for (auto *op : reach)
    if (auto ld = mlir::dyn_cast<ArrayLoadOp>(op)) {
      mlir::Type ldTy = ld.getMemref().getType();
      if (ld.getMemref() == addr) {
        if (mutuallyExclusiveSliceRange(ld, st))
          continue;
        if (ld.getResult() != st.getOriginal())
          return true;
        if (load) {
          // TODO: extend this to allow checking if the first `load` and this
          // `ld` are mutually exclusive accesses but not identical.
          return true;
        }
        load = ld;
      } else if ((hasPointerType(ldTy) || storeHasPointerType)) {
        // TODO: Use the target attribute to restrict this case further.
        // TODO: Check if types can also allow ruling out some cases. For now,
        // the fact that equivalences are implemented with the pointer
        // attribute to enforce aliasing prevents any attempt to do so. In
        // general, it may also be wrong to rely on types if any of them is a
        // complex or a derived type, for which it is possible to create a
        // pointer to a part with a different type than the whole; this
        // deserves more investigation because existing compiler behaviors
        // seem to diverge here.
        return true;
      }
    }
  return false;
}

/// Is there an access vector conflict on the array being merged into? If the
/// access vectors diverge, then assume that there are potentially overlapping
/// loop-carried references.
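///
/// A hedged Fortran-level illustration: in `forall (i=1:n) a(i) = a(i+1)`,
/// the fetch of `a` uses the index vector (i+1) while the update uses (i);
/// the vectors diverge, so a loop-carried dependence is assumed and a copy is
/// required. In `forall (i=1:n) a(i) = a(i) + 1`, every mention uses the same
/// index vector and no copy is needed on this account.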
static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> mentions) {
  if (mentions.size() < 2)
    return false;
  llvm::SmallVector<mlir::Value> indices;
  LLVM_DEBUG(llvm::dbgs() << "check merge conflict with " << mentions.size()
                          << " mentions on the list\n");
  bool valSeen = false;
  bool refSeen = false;
  for (auto *op : mentions) {
    llvm::SmallVector<mlir::Value> compareVector;
    if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = u.getIndices();
        continue;
      }
      compareVector = u.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) {
      valSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (auto f = mlir::dyn_cast<ArrayAccessOp>(op)) {
      refSeen = true;
      if (indices.empty()) {
        indices = f.getIndices();
        continue;
      }
      compareVector = f.getIndices();
    } else if (mlir::isa<ArrayAmendOp>(op)) {
      refSeen = true;
      continue;
    } else {
      mlir::emitError(op->getLoc(), "unexpected operation in analysis");
    }
    if (compareVector.size() != indices.size() ||
        llvm::any_of(llvm::zip(compareVector, indices), [&](auto pair) {
          return std::get<0>(pair) != std::get<1>(pair);
        }))
      return true;
    LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n");
  }
  return valSeen && refSeen;
}

/// With element-by-reference semantics, an amended array with more than one
/// access to the same loaded array is conservatively considered a conflict.
/// Note: the array copy can still be eliminated in subsequent optimizations.
static bool conflictOnReference(llvm::ArrayRef<mlir::Operation *> mentions) {
  LLVM_DEBUG(llvm::dbgs() << "checking reference semantics " << mentions.size()
                          << '\n');
  if (mentions.size() < 3)
    return false;
  unsigned amendCount = 0;
  unsigned accessCount = 0;
  for (auto *op : mentions) {
    if (mlir::isa<ArrayAmendOp>(op) && ++amendCount > 1) {
      LLVM_DEBUG(llvm::dbgs() << "conflict: multiple amends of array value\n");
      return true;
    }
    if (mlir::isa<ArrayAccessOp>(op) && ++accessCount > 1) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: multiple accesses of array value\n");
      return true;
    }
    if (mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) {
      LLVM_DEBUG(llvm::dbgs()
                 << "conflict: array value has both uses by-value and uses "
                    "by-reference. conservative assumption.\n");
      return true;
    }
  }
  return false;
}

static mlir::Operation *
amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
  for (auto *op : mentions)
    if (auto amend = mlir::dyn_cast<ArrayAmendOp>(op))
      return amend.getMemref().getDefiningOp();
  return {};
}

// Is either type of conflict present?
inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
                             llvm::ArrayRef<mlir::Operation *> accesses,
                             ArrayMergeStoreOp st) {
  return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
}

// Assume that any call to a function that uses host-associations will be
// modifying the output array.
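// A hedged Fortran-level illustration: in `a = f()`, where `f` is an internal
// procedure that host-associates `a`, the call could read or write `a`
// directly, so the assignment conservatively gets copy-in/copy-out semantics.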
static bool
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
  return llvm::any_of(reaches, [](mlir::Operation *op) {
    if (auto call = mlir::dyn_cast<fir::CallOp>(op))
      if (auto callee =
              call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
        auto module = op->getParentOfType<mlir::ModuleOp>();
        return hasHostAssociationArgument(
            module.lookupSymbol<mlir::FuncOp>(callee));
      }
    return false;
  });
}

/// Constructor of the array copy analysis.
/// This performs the analysis and saves the intermediate results.
void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) {
  topLevelOp->walk([&](Operation *op) {
    if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) {
      llvm::SmallVector<mlir::Operation *> values;
      ReachCollector::reachingValues(values, st.getSequence());
      bool callConflict = conservativeCallConflict(values);
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions,
                    mlir::cast<ArrayLoadOp>(st.getOriginal().getDefiningOp()));
      bool conflict = conflictDetected(values, mentions, st);
      bool refConflict = conflictOnReference(mentions);
      if (callConflict || conflict || refConflict) {
        LLVM_DEBUG(llvm::dbgs()
                   << "CONFLICT: copies required for " << st << '\n'
                   << " adding conflicts on: " << op << " and "
                   << st.getOriginal() << '\n');
        conflicts.insert(op);
        conflicts.insert(st.getOriginal().getDefiningOp());
        if (auto *access = amendingAccess(mentions))
          amendAccesses.insert(access);
      }
      auto *ld = st.getOriginal().getDefiningOp();
      LLVM_DEBUG(llvm::dbgs()
                 << "map: adding {" << *ld << " -> " << st << "}\n");
      useMap.insert({ld, op});
    } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) {
      llvm::SmallVector<mlir::Operation *> mentions;
      arrayMentions(mentions, load);
      LLVM_DEBUG(llvm::dbgs() << "process load: " << load
                              << ", mentions: " << mentions.size() << '\n');
      for (auto *acc : mentions) {
        LLVM_DEBUG(llvm::dbgs() << " mention: " << *acc << '\n');
        if (mlir::isa<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp, ArrayUpdateOp,
                      ArrayModifyOp>(acc)) {
          if (useMap.count(acc)) {
            mlir::emitError(
                load.getLoc(),
                "The parallel semantics of multiple array_merge_stores per "
                "array_load are not supported.");
            continue;
          }
          LLVM_DEBUG(llvm::dbgs()
                     << "map: adding {" << *acc << "} -> {" << load << "}\n");
          useMap.insert({acc, op});
        }
      }
    }
  });
}

//===----------------------------------------------------------------------===//
// Conversions for converting out of array value form.
//===----------------------------------------------------------------------===//

namespace {
class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayLoadOp load,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n");
    rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType());
    return mlir::success();
  }
};

class ArrayMergeStoreConversion
    : public mlir::OpRewritePattern<ArrayMergeStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ArrayMergeStoreOp store,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n");
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
} // namespace

static mlir::Type getEleTy(mlir::Type ty) {
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(ty));
  // FIXME: keep ptr/heap/ref information.
  return ReferenceType::get(eleTy);
}

// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
static bool getAdjustedExtents(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayLoadOp arrLoad,
                               llvm::SmallVectorImpl<mlir::Value> &result,
                               mlir::Value shape) {
  bool copyUsingSlice = false;
  auto *shapeOp = shape.getDefiningOp();
  if (auto s = mlir::dyn_cast_or_null<ShapeOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else if (auto s = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeOp)) {
    auto e = s.getExtents();
    result.insert(result.end(), e.begin(), e.end());
  } else {
    emitFatalError(loc, "not a fir.shape/fir.shape_shift op");
  }
  auto idxTy = rewriter.getIndexType();
  if (factory::isAssumedSize(result)) {
    // Use slice information to compute the extent of the column.
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value size = one;
    if (mlir::Value sliceArg = arrLoad.getSlice()) {
      if (auto sliceOp =
              mlir::dyn_cast_or_null<SliceOp>(sliceArg.getDefiningOp())) {
        auto triples = sliceOp.getTriples();
        const std::size_t tripleSize = triples.size();
        auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, getKindMapping(module));
        size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3],
                                            triples[tripleSize - 2],
                                            triples[tripleSize - 1], idxTy);
        copyUsingSlice = true;
      }
    }
    result[result.size() - 1] = size;
  }
  return copyUsingSlice;
}

/// Place the extents of the array load, \p arrLoad, into \p result and
/// return a ShapeOp or ShapeShiftOp with the same extents. If \p arrLoad is
/// loading a `!fir.box`, code will be generated to read the extents from the
/// boxed value, and the returned shape Op will be built with the extents read
/// from the box. Otherwise, the extents will be extracted from the ShapeOp (or
/// ShapeShiftOp) argument of \p arrLoad. \p copyUsingSlice will be set to true
/// if slicing of the output array is to be done in the copy-in/copy-out rather
/// than in the elemental computation step.
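///
/// A minimal sketch of the boxed case (illustrative SSA names only):
/// ```mlir
///   %c0 = arith.constant 0 : index
///   %dims:3 = fir.box_dims %box, %c0
///       : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
///   %shape = fir.shape %dims#1 : (index) -> !fir.shape<1>
/// ```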
static mlir::Value getOrReadExtentsAndShapeOp(
    mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayLoadOp arrLoad,
    llvm::SmallVectorImpl<mlir::Value> &result, bool &copyUsingSlice) {
  assert(result.empty());
  if (arrLoad->hasAttr(fir::getOptionalAttrName()))
    fir::emitFatalError(
        loc, "shapes from array load of OPTIONAL arrays must not be used");
  if (auto boxTy = arrLoad.getMemref().getType().dyn_cast<BoxType>()) {
    auto rank =
        dyn_cast_ptrOrBoxEleTy(boxTy).cast<SequenceType>().getDimension();
    auto idxTy = rewriter.getIndexType();
    for (decltype(rank) dim = 0; dim < rank; ++dim) {
      auto dimVal = rewriter.create<mlir::arith::ConstantIndexOp>(loc, dim);
      auto dimInfo = rewriter.create<BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                arrLoad.getMemref(), dimVal);
      result.emplace_back(dimInfo.getResult(1));
    }
    if (!arrLoad.getShape()) {
      auto shapeType = ShapeType::get(rewriter.getContext(), rank);
      return rewriter.create<ShapeOp>(loc, shapeType, result);
    }
    auto shiftOp = arrLoad.getShape().getDefiningOp<ShiftOp>();
    auto shapeShiftType = ShapeShiftType::get(rewriter.getContext(), rank);
    llvm::SmallVector<mlir::Value> shapeShiftOperands;
    for (auto [lb, extent] : llvm::zip(shiftOp.getOrigins(), result)) {
      shapeShiftOperands.push_back(lb);
      shapeShiftOperands.push_back(extent);
    }
    return rewriter.create<ShapeShiftOp>(loc, shapeShiftType,
                                         shapeShiftOperands);
  }
  copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, result, arrLoad.getShape());
  return arrLoad.getShape();
}

static mlir::Type toRefType(mlir::Type ty) {
  if (fir::isa_ref_type(ty))
    return ty;
  return fir::ReferenceType::get(ty);
}

static mlir::Value
genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
          mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
          mlir::Value slice, mlir::ValueRange indices,
          mlir::ValueRange typeparams, bool skipOrig = false) {
  llvm::SmallVector<mlir::Value> originated;
  if (skipOrig)
    originated.assign(indices.begin(), indices.end());
  else
    originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
                                                shape, indices);
  auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
  assert(seqTy && seqTy.isa<fir::SequenceType>());
  const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
  mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
      loc, eleTy, alloc, shape, slice,
      llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
      typeparams);
  if (dimension < originated.size())
    result = rewriter.create<fir::CoordinateOp>(
        loc, resTy, result,
        llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension));
  return result;
}

static mlir::Value getCharacterLen(mlir::Location loc, FirOpBuilder &builder,
                                   ArrayLoadOp load, CharacterType charTy) {
  auto charLenTy = builder.getCharacterLengthType();
  if (charTy.hasDynamicLen()) {
    if (load.getMemref().getType().isa<BoxType>()) {
      // The loaded array is an emboxed value. Get the CHARACTER length from
      // the box value.
      auto eleSzInBytes =
          builder.create<BoxEleSizeOp>(loc, charLenTy, load.getMemref());
      auto kindSize =
          builder.getKindMap().getCharacterBitsize(charTy.getFKind());
      auto kindByteSize =
          builder.createIntegerConstant(loc, charLenTy, kindSize / 8);
      return builder.create<mlir::arith::DivSIOp>(loc, eleSzInBytes,
                                                  kindByteSize);
    }
    // The loaded array is a (set of) unboxed values. If the CHARACTER's
    // length is not a constant, it must be provided as a type parameter to
    // the array_load.
    auto typeparams = load.getTypeparams();
    assert(typeparams.size() > 0 && "expected type parameters on array_load");
    return typeparams.back();
  }
  // The typical case: the length of the CHARACTER is a compile-time
  // constant that is encoded in the type information.
  return builder.createIntegerConstant(loc, charLenTy, charTy.getLen());
}

/// Generate a shallow array copy. This is used for both copy-in and copy-out.
template <bool CopyIn>
void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  mlir::Value dst, mlir::Value src, mlir::Value shapeOp,
                  mlir::Value sliceOp, ArrayLoadOp arrLoad) {
  auto insPt = rewriter.saveInsertionPoint();
  llvm::SmallVector<mlir::Value> indices;
  llvm::SmallVector<mlir::Value> extents;
  bool copyUsingSlice =
      getAdjustedExtents(loc, rewriter, arrLoad, extents, shapeOp);
  auto idxTy = rewriter.getIndexType();
  // Build loop nest from column to row.
  for (auto sh : llvm::reverse(extents)) {
    auto ubi = rewriter.create<ConvertOp>(loc, idxTy, sh);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    auto ub = rewriter.create<mlir::arith::SubIOp>(loc, idxTy, ubi, one);
    auto loop = rewriter.create<DoLoopOp>(loc, zero, ub, one);
    rewriter.setInsertionPointToStart(loop.getBody());
    indices.push_back(loop.getInductionVar());
  }
  // Reverse the indices so they are in column-major order.
  std::reverse(indices.begin(), indices.end());
  auto typeparams = arrLoad.getTypeparams();
  auto fromAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(src.getType()), src, shapeOp,
      CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
      typeparams);
  auto toAddr = rewriter.create<ArrayCoorOp>(
      loc, getEleTy(dst.getType()), dst, shapeOp,
      !CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
      factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
      typeparams);
  auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
  auto module = toAddr->getParentOfType<mlir::ModuleOp>();
  FirOpBuilder builder(rewriter, getKindMapping(module));
  // Copy from (to) object to (from) temp copy of same object.
  if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
    auto len = getCharacterLen(loc, builder, arrLoad, charTy);
    CharBoxValue toChar(toAddr, len);
    CharBoxValue fromChar(fromAddr, len);
    factory::genScalarAssignment(builder, loc, toChar, fromChar);
  } else {
    if (hasDynamicSize(eleTy))
      TODO(loc, "copy element of dynamic size");
    factory::genScalarAssignment(builder, loc, toAddr, fromAddr);
  }
  rewriter.restoreInsertionPoint(insPt);
}

/// The array load may be either a boxed or unboxed value. If the value is
/// boxed, we read the type parameters from the boxed value.
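///
/// A hedged sketch of the boxed dynamic-length CHARACTER case (illustrative
/// names; `%bytes_per_char` is a hypothetical constant; the division mirrors
/// getCharacterLen above):
/// ```mlir
///   %bytes = fir.box_elesize %box
///       : (!fir.box<!fir.array<?x!fir.char<1,?>>>) -> index
///   %len = arith.divsi %bytes, %bytes_per_char : index
/// ```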
static llvm::SmallVector<mlir::Value>
genArrayLoadTypeParameters(mlir::Location loc, mlir::PatternRewriter &rewriter,
                           ArrayLoadOp load) {
  if (load.getTypeparams().empty()) {
    auto eleTy =
        unwrapSequenceType(unwrapPassByRefType(load.getMemref().getType()));
    if (hasDynamicSize(eleTy)) {
      if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
        assert(load.getMemref().getType().isa<BoxType>());
        auto module = load->getParentOfType<mlir::ModuleOp>();
        FirOpBuilder builder(rewriter, getKindMapping(module));
        return {getCharacterLen(loc, builder, load, charTy)};
      }
      TODO(loc, "unhandled dynamic type parameters");
    }
    return {};
  }
  return load.getTypeparams();
}

static llvm::SmallVector<mlir::Value>
findNonconstantExtents(mlir::Type memrefTy,
                       llvm::ArrayRef<mlir::Value> extents) {
  llvm::SmallVector<mlir::Value> nce;
  auto arrTy = unwrapPassByRefType(memrefTy);
  auto seqTy = arrTy.cast<SequenceType>();
  for (auto [s, x] : llvm::zip(seqTy.getShape(), extents))
    if (s == SequenceType::getUnknownExtent())
      nce.emplace_back(x);
  if (extents.size() > seqTy.getShape().size())
    for (auto x : extents.drop_front(seqTy.getShape().size()))
      nce.emplace_back(x);
  return nce;
}

/// Allocate temporary storage for an ArrayLoadOp \p load and initialize any
/// allocatable direct components of the array elements with an unallocated
/// status. Returns the temporary address as well as a callback to generate the
/// temporary clean-up once it has been used. The clean-up will take care of
/// deallocating all the element allocatable components that may have been
/// allocated while using the temporary.
static std::pair<mlir::Value,
                 std::function<void(mlir::PatternRewriter &rewriter)>>
allocateArrayTemp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                  ArrayLoadOp load, llvm::ArrayRef<mlir::Value> extents,
                  mlir::Value shape) {
  mlir::Type baseType = load.getMemref().getType();
  llvm::SmallVector<mlir::Value> nonconstantExtents =
      findNonconstantExtents(baseType, extents);
  llvm::SmallVector<mlir::Value> typeParams =
      genArrayLoadTypeParameters(loc, rewriter, load);
  mlir::Value allocmem = rewriter.create<AllocMemOp>(
      loc, dyn_cast_ptrOrBoxEleTy(baseType), typeParams, nonconstantExtents);
  mlir::Type eleType =
      fir::unwrapSequenceType(fir::unwrapPassByRefType(baseType));
  if (fir::isRecordWithAllocatableMember(eleType)) {
    // The allocatable component descriptors need to be set to a clean
    // deallocated status before anything is done with them.
    mlir::Value box = rewriter.create<fir::EmboxOp>(
        loc, fir::BoxType::get(baseType), allocmem, shape,
        /*slice=*/mlir::Value{}, typeParams);
    auto module = load->getParentOfType<mlir::ModuleOp>();
    FirOpBuilder builder(rewriter, getKindMapping(module));
    runtime::genDerivedTypeInitialize(builder, loc, box);
    // Any allocatable component that may have been allocated must be
    // deallocated during the clean-up.
    auto cleanup = [=](mlir::PatternRewriter &r) {
      FirOpBuilder builder(r, getKindMapping(module));
      runtime::genDerivedTypeDestroy(builder, loc, box);
      r.create<FreeMemOp>(loc, allocmem);
    };
    return {allocmem, cleanup};
  }
  auto cleanup = [=](mlir::PatternRewriter &r) {
    r.create<FreeMemOp>(loc, allocmem);
  };
  return {allocmem, cleanup};
}

namespace {
/// Conversion of fir.array_update and fir.array_modify Ops.
/// If there is a conflict for the update, then we need to perform a
/// copy-in/copy-out to preserve the original values of the array. If there is
/// no conflict, then it is safe to eschew making any copies.
template <typename ArrayOp>
class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> {
public:
  // TODO: Implement copy/swap semantics?
  explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx,
                                     const ArrayCopyAnalysis &a,
                                     const OperationUseMapT &m)
      : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {}

  /// The array_access, \p access, is to be made to a cloned copy due to a
  /// potential conflict. Uses copy-in/copy-out semantics and not copy/swap.
  mlir::Value referenceToClone(mlir::Location loc,
                               mlir::PatternRewriter &rewriter,
                               ArrayOp access) const {
    LLVM_DEBUG(llvm::dbgs()
               << "generating copy-in/copy-out loops for " << access << '\n');
    auto *op = access.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    auto eleTy = access.getType();
    rewriter.setInsertionPoint(loadOp);
    // Copy in.
    llvm::SmallVector<mlir::Value> extents;
    bool copyUsingSlice = false;
    auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                              copyUsingSlice);
    auto [allocmem, genTempCleanUp] =
        allocateArrayTemp(loc, rewriter, load, extents, shapeOp);
    genArrayCopy</*CopyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                  load.getMemref(), shapeOp, load.getSlice(),
                                  load);
    // Generate the reference for the access.
    rewriter.setInsertionPoint(op);
    auto coor =
        genCoorOp(rewriter, loc, getEleTy(load.getType()), eleTy, allocmem,
                  shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
                  access.getIndices(), load.getTypeparams(),
                  access->hasAttr(factory::attrFortranArrayOffsets()));
    auto *storeOp = useMap.lookup(loadOp);
    auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
    rewriter.setInsertionPoint(storeOp);
    // Copy out.
    genArrayCopy</*CopyIn=*/false>(store.getLoc(), rewriter, store.getMemref(),
                                   allocmem, shapeOp, store.getSlice(), load);
    genTempCleanUp(rewriter);
    return coor;
  }

  /// Copy the RHS element into the LHS and insert copy-in/copy-out between a
  /// temp and the LHS if the analysis found potential overlaps between the RHS
  /// and LHS arrays. The element copy generator must be provided in \p
  /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp.
  /// Returns the address of the LHS element inside the loop and the LHS
  /// ArrayLoad result.
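  ///
  /// A hedged outline of the conflict case (pseudocode, not exact output):
  /// ```
  ///   %tmp = fir.allocmem ...          // temporary for the LHS
  ///   <loop nest copying lhs -> %tmp>  // copy-in
  ///   <assignElement applied to the %tmp element at update's indices>
  ///   <loop nest copying %tmp -> lhs>  // copy-out, before the merge_store
  ///   fir.freemem %tmp
  /// ```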
  std::pair<mlir::Value, mlir::Value>
  materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        ArrayOp update,
                        const std::function<void(mlir::Value)> &assignElement,
                        mlir::Type lhsEltRefType) const {
    auto *op = update.getOperation();
    auto *loadOp = useMap.lookup(op);
    auto load = mlir::cast<ArrayLoadOp>(loadOp);
    LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n");
    if (analysis.hasPotentialConflict(loadOp)) {
      // If there is a conflict between the arrays, then we copy the lhs array
      // to a temporary, update the temporary, and copy the temporary back to
      // the lhs array. This yields Fortran's copy-in copy-out array semantics.
      LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n");
      rewriter.setInsertionPoint(loadOp);
      // Copy in.
      llvm::SmallVector<mlir::Value> extents;
      bool copyUsingSlice = false;
      auto shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents,
                                                copyUsingSlice);
      auto [allocmem, genTempCleanUp] =
          allocateArrayTemp(loc, rewriter, load, extents, shapeOp);

      genArrayCopy</*CopyIn=*/true>(load.getLoc(), rewriter, allocmem,
                                    load.getMemref(), shapeOp, load.getSlice(),
                                    load);
      rewriter.setInsertionPoint(op);
      auto coor = genCoorOp(
          rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
          shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
          update.getIndices(), load.getTypeparams(),
          update->hasAttr(factory::attrFortranArrayOffsets()));
      assignElement(coor);
      auto *storeOp = useMap.lookup(loadOp);
      auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
      rewriter.setInsertionPoint(storeOp);
      // Copy out.
      genArrayCopy</*CopyIn=*/false>(store.getLoc(), rewriter,
                                     store.getMemref(), allocmem, shapeOp,
                                     store.getSlice(), load);
      genTempCleanUp(rewriter);
      return {coor, load.getResult()};
    }
    // Otherwise, when there is no conflict (no possible loop-carried
    // dependence), the lhs array can be updated in place.
    LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
    rewriter.setInsertionPoint(op);
    auto coorTy = getEleTy(load.getType());
    auto coor = genCoorOp(rewriter, loc, coorTy, lhsEltRefType,
                          load.getMemref(), load.getShape(), load.getSlice(),
                          update.getIndices(), load.getTypeparams(),
                          update->hasAttr(factory::attrFortranArrayOffsets()));
    assignElement(coor);
    return {coor, load.getResult()};
  }

protected:
  const ArrayCopyAnalysis &analysis;
  const OperationUseMapT &useMap;
};

class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> {
public:
  explicit ArrayUpdateConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayUpdateOp update,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = update.getLoc();
    auto assignElement = [&](mlir::Value coor) {
      auto input = update.getMerge();
      if (auto inEleTy = dyn_cast_ptrEleTy(input.getType())) {
        emitFatalError(loc, "array_update on references not supported");
      } else {
        rewriter.create<fir::StoreOp>(loc, input, coor);
      }
    };
    auto lhsEltRefType = toRefType(update.getMerge().getType());
    auto [_, lhsLoadResult] = materializeAssignment(
        loc, rewriter, update, assignElement, lhsEltRefType);
    update.replaceAllUsesWith(lhsLoadResult);
    rewriter.replaceOp(update, lhsLoadResult);
    return mlir::success();
  }
};

class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> {
public:
  explicit ArrayModifyConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayModifyOp modify,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = modify.getLoc();
    auto assignElement = [](mlir::Value) {
      // The assignment was already materialized by lowering using the lhs
      // element address.
    };
    auto lhsEltRefType = modify.getResult(0).getType();
    auto [lhsEltCoor, lhsLoadResult] = materializeAssignment(
        loc, rewriter, modify, assignElement, lhsEltRefType);
    modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult});
    return mlir::success();
  }
};

class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> {
public:
  explicit ArrayFetchConversion(mlir::MLIRContext *ctx,
                                const OperationUseMapT &m)
      : OpRewritePattern{ctx}, useMap{m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayFetchOp fetch,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = fetch.getOperation();
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto loc = fetch.getLoc();
    auto coor =
        genCoorOp(rewriter, loc, getEleTy(load.getType()),
                  toRefType(fetch.getType()), load.getMemref(), load.getShape(),
                  load.getSlice(), fetch.getIndices(), load.getTypeparams(),
                  fetch->hasAttr(factory::attrFortranArrayOffsets()));
    if (isa_ref_type(fetch.getType()))
      rewriter.replaceOp(fetch, coor);
    else
      rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor);
    return mlir::success();
  }

private:
  const OperationUseMapT &useMap;
};

/// An array_access op is like an array_fetch op, except that it does not imply
/// a load op. (It operates in the reference domain.)
class ArrayAccessConversion : public ArrayUpdateConversionBase<ArrayAccessOp> {
public:
  explicit ArrayAccessConversion(mlir::MLIRContext *ctx,
                                 const ArrayCopyAnalysis &a,
                                 const OperationUseMapT &m)
      : ArrayUpdateConversionBase{ctx, a, m} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAccessOp access,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = access.getOperation();
    auto loc = access.getLoc();
    if (analysis.inAmendAccessSet(op)) {
      // This array_access is associated with an array_amend and there is a
      // conflict. Make a copy to store into.
      auto result = referenceToClone(loc, rewriter, access);
      access.replaceAllUsesWith(result);
      rewriter.replaceOp(access, result);
      return mlir::success();
    }
    rewriter.setInsertionPoint(op);
    auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
    auto coor = genCoorOp(rewriter, loc, getEleTy(load.getType()),
                          toRefType(access.getType()), load.getMemref(),
                          load.getShape(), load.getSlice(), access.getIndices(),
                          load.getTypeparams(),
                          access->hasAttr(factory::attrFortranArrayOffsets()));
    rewriter.replaceOp(access, coor);
    return mlir::success();
  }
};

/// An array_amend op is a marker to record which array access is being used to
/// update an array value. After this pass runs, an array_amend has no
/// semantics. We rewrite these to undefined values here to remove them while
/// preserving SSA form.
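///
/// A hedged before/after sketch (illustrative types and names):
/// ```mlir
///   %v = fir.array_amend %arr, %ref
///       : (!fir.array<?xf32>, !fir.ref<f32>) -> !fir.array<?xf32>
///   // becomes
///   %v = fir.undefined !fir.array<?xf32>
/// ```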
class ArrayAmendConversion : public mlir::OpRewritePattern<ArrayAmendOp> {
public:
  explicit ArrayAmendConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(ArrayAmendOp amend,
                  mlir::PatternRewriter &rewriter) const override {
    auto *op = amend.getOperation();
    rewriter.setInsertionPoint(op);
    auto loc = amend.getLoc();
    auto undef = rewriter.create<UndefOp>(loc, amend.getType());
    rewriter.replaceOp(amend, undef.getResult());
    return mlir::success();
  }
};

class ArrayValueCopyConverter
    : public ArrayValueCopyBase<ArrayValueCopyConverter> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '"
                            << func.getName() << "'\n");
    auto *context = &getContext();

    // Perform the conflict analysis.
    const auto &analysis = getAnalysis<ArrayCopyAnalysis>();
    const auto &useMap = analysis.getUseMap();

    mlir::RewritePatternSet patterns1(context);
    patterns1.insert<ArrayFetchConversion>(context, useMap);
    patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap);
    patterns1.insert<ArrayModifyConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAccessConversion>(context, analysis, useMap);
    patterns1.insert<ArrayAmendConversion>(context);
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::func::FuncDialect>();
    target.addIllegalOp<ArrayAccessOp, ArrayAmendOp, ArrayFetchOp,
                        ArrayUpdateOp, ArrayModifyOp>();
    // Rewrite the array fetch and array update ops.
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns1)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 1");
      signalPassFailure();
    }

    mlir::RewritePatternSet patterns2(context);
    patterns2.insert<ArrayLoadConversion>(context);
    patterns2.insert<ArrayMergeStoreConversion>(context);
    target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>();
    if (mlir::failed(
            mlir::applyPartialConversion(func, target, std::move(patterns2)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in array-value-copy pass, phase 2");
      signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() {
  return std::make_unique<ArrayValueCopyConverter>();
}