1*47f75930SValentin Clement //===-- ArrayValueCopy.cpp ------------------------------------------------===// 2*47f75930SValentin Clement // 3*47f75930SValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*47f75930SValentin Clement // See https://llvm.org/LICENSE.txt for license information. 5*47f75930SValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*47f75930SValentin Clement // 7*47f75930SValentin Clement //===----------------------------------------------------------------------===// 8*47f75930SValentin Clement 9*47f75930SValentin Clement #include "PassDetail.h" 10*47f75930SValentin Clement #include "flang/Optimizer/Builder/BoxValue.h" 11*47f75930SValentin Clement #include "flang/Optimizer/Builder/FIRBuilder.h" 12*47f75930SValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h" 13*47f75930SValentin Clement #include "flang/Optimizer/Support/FIRContext.h" 14*47f75930SValentin Clement #include "flang/Optimizer/Transforms/Factory.h" 15*47f75930SValentin Clement #include "flang/Optimizer/Transforms/Passes.h" 16*47f75930SValentin Clement #include "mlir/Dialect/SCF/SCF.h" 17*47f75930SValentin Clement #include "mlir/Transforms/DialectConversion.h" 18*47f75930SValentin Clement #include "llvm/Support/Debug.h" 19*47f75930SValentin Clement 20*47f75930SValentin Clement #define DEBUG_TYPE "flang-array-value-copy" 21*47f75930SValentin Clement 22*47f75930SValentin Clement using namespace fir; 23*47f75930SValentin Clement 24*47f75930SValentin Clement using OperationUseMapT = llvm::DenseMap<mlir::Operation *, mlir::Operation *>; 25*47f75930SValentin Clement 26*47f75930SValentin Clement namespace { 27*47f75930SValentin Clement 28*47f75930SValentin Clement /// Array copy analysis. 29*47f75930SValentin Clement /// Perform an interference analysis between array values. 30*47f75930SValentin Clement /// 31*47f75930SValentin Clement /// Lowering will generate a sequence of the following form. 
/// ```mlir
///   %a_1 = fir.array_load %array_1(%shape) : ...
///   ...
///   %a_j = fir.array_load %array_j(%shape) : ...
///   ...
///   %a_n = fir.array_load %array_n(%shape) : ...
///     ...
///     %v_i = fir.array_fetch %a_i, ...
///     %a_j1 = fir.array_update %a_j, ...
///     ...
///   fir.array_merge_store %a_j, %a_jn to %array_j : ...
/// ```
///
/// The analysis is to determine if there are any conflicts. A conflict is when
/// one of the following cases occurs.
///
/// 1. There is an `array_update` to an array value, a_j, such that a_j was
/// loaded from the same array memory reference (array_j) but with a different
/// shape as the other array values a_i, where i != j. [Possible overlapping
/// arrays.]
///
/// 2. There is either an array_fetch or array_update of a_j with a different
/// set of index values. [Possible loop-carried dependence.]
///
/// If none of the array values overlap in storage and the accesses are not
/// loop-carried, then the arrays are conflict-free and no copies are required.
58*47f75930SValentin Clement class ArrayCopyAnalysis { 59*47f75930SValentin Clement public: 60*47f75930SValentin Clement using ConflictSetT = llvm::SmallPtrSet<mlir::Operation *, 16>; 61*47f75930SValentin Clement using UseSetT = llvm::SmallPtrSet<mlir::OpOperand *, 8>; 62*47f75930SValentin Clement using LoadMapSetsT = 63*47f75930SValentin Clement llvm::DenseMap<mlir::Operation *, SmallVector<Operation *>>; 64*47f75930SValentin Clement 65*47f75930SValentin Clement ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); } 66*47f75930SValentin Clement 67*47f75930SValentin Clement mlir::Operation *getOperation() const { return operation; } 68*47f75930SValentin Clement 69*47f75930SValentin Clement /// Return true iff the `array_merge_store` has potential conflicts. 70*47f75930SValentin Clement bool hasPotentialConflict(mlir::Operation *op) const { 71*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 72*47f75930SValentin Clement << "looking for a conflict on " << *op 73*47f75930SValentin Clement << " and the set has a total of " << conflicts.size() << '\n'); 74*47f75930SValentin Clement return conflicts.contains(op); 75*47f75930SValentin Clement } 76*47f75930SValentin Clement 77*47f75930SValentin Clement /// Return the use map. The use map maps array fetch and update operations 78*47f75930SValentin Clement /// back to the array load that is the original source of the array value. 79*47f75930SValentin Clement const OperationUseMapT &getUseMap() const { return useMap; } 80*47f75930SValentin Clement 81*47f75930SValentin Clement /// Find all the array operations that access the array value that is loaded 82*47f75930SValentin Clement /// by the array load operation, `load`. 
83*47f75930SValentin Clement const llvm::SmallVector<mlir::Operation *> &arrayAccesses(ArrayLoadOp load); 84*47f75930SValentin Clement 85*47f75930SValentin Clement private: 86*47f75930SValentin Clement void construct(mlir::Operation *topLevelOp); 87*47f75930SValentin Clement 88*47f75930SValentin Clement mlir::Operation *operation; // operation that analysis ran upon 89*47f75930SValentin Clement ConflictSetT conflicts; // set of conflicts (loads and merge stores) 90*47f75930SValentin Clement OperationUseMapT useMap; 91*47f75930SValentin Clement LoadMapSetsT loadMapSets; 92*47f75930SValentin Clement }; 93*47f75930SValentin Clement } // namespace 94*47f75930SValentin Clement 95*47f75930SValentin Clement namespace { 96*47f75930SValentin Clement /// Helper class to collect all array operations that produced an array value. 97*47f75930SValentin Clement class ReachCollector { 98*47f75930SValentin Clement private: 99*47f75930SValentin Clement // If provided, the `loopRegion` is the body of a loop that produces the array 100*47f75930SValentin Clement // of interest. 
101*47f75930SValentin Clement ReachCollector(llvm::SmallVectorImpl<mlir::Operation *> &reach, 102*47f75930SValentin Clement mlir::Region *loopRegion) 103*47f75930SValentin Clement : reach{reach}, loopRegion{loopRegion} {} 104*47f75930SValentin Clement 105*47f75930SValentin Clement void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { 106*47f75930SValentin Clement llvm::errs() << "COLLECT " << *op << "\n"; 107*47f75930SValentin Clement if (range.empty()) { 108*47f75930SValentin Clement collectArrayAccessFrom(op, mlir::Value{}); 109*47f75930SValentin Clement return; 110*47f75930SValentin Clement } 111*47f75930SValentin Clement for (mlir::Value v : range) 112*47f75930SValentin Clement collectArrayAccessFrom(v); 113*47f75930SValentin Clement } 114*47f75930SValentin Clement 115*47f75930SValentin Clement // TODO: Replace recursive algorithm on def-use chain with an iterative one 116*47f75930SValentin Clement // with an explicit stack. 117*47f75930SValentin Clement void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { 118*47f75930SValentin Clement // `val` is defined by an Op, process the defining Op. 119*47f75930SValentin Clement // If `val` is defined by a region containing Op, we want to drill down 120*47f75930SValentin Clement // and through that Op's region(s). 
121*47f75930SValentin Clement llvm::errs() << "COLLECT " << *op << "\n"; 122*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); 123*47f75930SValentin Clement auto popFn = [&](auto rop) { 124*47f75930SValentin Clement assert(val && "op must have a result value"); 125*47f75930SValentin Clement auto resNum = val.cast<mlir::OpResult>().getResultNumber(); 126*47f75930SValentin Clement llvm::SmallVector<mlir::Value> results; 127*47f75930SValentin Clement rop.resultToSourceOps(results, resNum); 128*47f75930SValentin Clement for (auto u : results) 129*47f75930SValentin Clement collectArrayAccessFrom(u); 130*47f75930SValentin Clement }; 131*47f75930SValentin Clement if (auto rop = mlir::dyn_cast<fir::DoLoopOp>(op)) { 132*47f75930SValentin Clement popFn(rop); 133*47f75930SValentin Clement return; 134*47f75930SValentin Clement } 135*47f75930SValentin Clement if (auto rop = mlir::dyn_cast<fir::IfOp>(op)) { 136*47f75930SValentin Clement popFn(rop); 137*47f75930SValentin Clement return; 138*47f75930SValentin Clement } 139*47f75930SValentin Clement if (auto mergeStore = mlir::dyn_cast<ArrayMergeStoreOp>(op)) { 140*47f75930SValentin Clement if (opIsInsideLoops(mergeStore)) 141*47f75930SValentin Clement collectArrayAccessFrom(mergeStore.sequence()); 142*47f75930SValentin Clement return; 143*47f75930SValentin Clement } 144*47f75930SValentin Clement 145*47f75930SValentin Clement if (mlir::isa<AllocaOp, AllocMemOp>(op)) { 146*47f75930SValentin Clement // Look for any stores inside the loops, and collect an array operation 147*47f75930SValentin Clement // that produced the value being stored to it. 
148*47f75930SValentin Clement for (mlir::Operation *user : op->getUsers()) 149*47f75930SValentin Clement if (auto store = mlir::dyn_cast<fir::StoreOp>(user)) 150*47f75930SValentin Clement if (opIsInsideLoops(store)) 151*47f75930SValentin Clement collectArrayAccessFrom(store.value()); 152*47f75930SValentin Clement return; 153*47f75930SValentin Clement } 154*47f75930SValentin Clement 155*47f75930SValentin Clement // Otherwise, Op does not contain a region so just chase its operands. 156*47f75930SValentin Clement if (mlir::isa<ArrayLoadOp, ArrayUpdateOp, ArrayModifyOp, ArrayFetchOp>( 157*47f75930SValentin Clement op)) { 158*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); 159*47f75930SValentin Clement reach.emplace_back(op); 160*47f75930SValentin Clement } 161*47f75930SValentin Clement // Array modify assignment is performed on the result. So the analysis 162*47f75930SValentin Clement // must look at the what is done with the result. 163*47f75930SValentin Clement if (mlir::isa<ArrayModifyOp>(op)) 164*47f75930SValentin Clement for (mlir::Operation *user : op->getResult(0).getUsers()) 165*47f75930SValentin Clement followUsers(user); 166*47f75930SValentin Clement 167*47f75930SValentin Clement for (auto u : op->getOperands()) 168*47f75930SValentin Clement collectArrayAccessFrom(u); 169*47f75930SValentin Clement } 170*47f75930SValentin Clement 171*47f75930SValentin Clement void collectArrayAccessFrom(mlir::BlockArgument ba) { 172*47f75930SValentin Clement auto *parent = ba.getOwner()->getParentOp(); 173*47f75930SValentin Clement // If inside an Op holding a region, the block argument corresponds to an 174*47f75930SValentin Clement // argument passed to the containing Op. 
175*47f75930SValentin Clement auto popFn = [&](auto rop) { 176*47f75930SValentin Clement collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); 177*47f75930SValentin Clement }; 178*47f75930SValentin Clement if (auto rop = mlir::dyn_cast<DoLoopOp>(parent)) { 179*47f75930SValentin Clement popFn(rop); 180*47f75930SValentin Clement return; 181*47f75930SValentin Clement } 182*47f75930SValentin Clement if (auto rop = mlir::dyn_cast<IterWhileOp>(parent)) { 183*47f75930SValentin Clement popFn(rop); 184*47f75930SValentin Clement return; 185*47f75930SValentin Clement } 186*47f75930SValentin Clement // Otherwise, a block argument is provided via the pred blocks. 187*47f75930SValentin Clement for (auto *pred : ba.getOwner()->getPredecessors()) { 188*47f75930SValentin Clement auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); 189*47f75930SValentin Clement collectArrayAccessFrom(u); 190*47f75930SValentin Clement } 191*47f75930SValentin Clement } 192*47f75930SValentin Clement 193*47f75930SValentin Clement // Recursively trace operands to find all array operations relating to the 194*47f75930SValentin Clement // values merged. 195*47f75930SValentin Clement void collectArrayAccessFrom(mlir::Value val) { 196*47f75930SValentin Clement if (!val || visited.contains(val)) 197*47f75930SValentin Clement return; 198*47f75930SValentin Clement visited.insert(val); 199*47f75930SValentin Clement 200*47f75930SValentin Clement // Process a block argument. 201*47f75930SValentin Clement if (auto ba = val.dyn_cast<mlir::BlockArgument>()) { 202*47f75930SValentin Clement collectArrayAccessFrom(ba); 203*47f75930SValentin Clement return; 204*47f75930SValentin Clement } 205*47f75930SValentin Clement 206*47f75930SValentin Clement // Process an Op. 
207*47f75930SValentin Clement if (auto *op = val.getDefiningOp()) { 208*47f75930SValentin Clement collectArrayAccessFrom(op, val); 209*47f75930SValentin Clement return; 210*47f75930SValentin Clement } 211*47f75930SValentin Clement 212*47f75930SValentin Clement fir::emitFatalError(val.getLoc(), "unhandled value"); 213*47f75930SValentin Clement } 214*47f75930SValentin Clement 215*47f75930SValentin Clement /// Is \op inside the loop nest region ? 216*47f75930SValentin Clement bool opIsInsideLoops(mlir::Operation *op) const { 217*47f75930SValentin Clement return loopRegion && loopRegion->isAncestor(op->getParentRegion()); 218*47f75930SValentin Clement } 219*47f75930SValentin Clement 220*47f75930SValentin Clement /// Recursively trace the use of an operation results, calling 221*47f75930SValentin Clement /// collectArrayAccessFrom on the direct and indirect user operands. 222*47f75930SValentin Clement /// TODO: Replace recursive algorithm on def-use chain with an iterative one 223*47f75930SValentin Clement /// with an explicit stack. 224*47f75930SValentin Clement void followUsers(mlir::Operation *op) { 225*47f75930SValentin Clement for (auto userOperand : op->getOperands()) 226*47f75930SValentin Clement collectArrayAccessFrom(userOperand); 227*47f75930SValentin Clement // Go through potential converts/coordinate_op. 228*47f75930SValentin Clement for (mlir::Operation *indirectUser : op->getUsers()) 229*47f75930SValentin Clement followUsers(indirectUser); 230*47f75930SValentin Clement } 231*47f75930SValentin Clement 232*47f75930SValentin Clement llvm::SmallVectorImpl<mlir::Operation *> &reach; 233*47f75930SValentin Clement llvm::SmallPtrSet<mlir::Value, 16> visited; 234*47f75930SValentin Clement /// Region of the loops nest that produced the array value. 
235*47f75930SValentin Clement mlir::Region *loopRegion; 236*47f75930SValentin Clement 237*47f75930SValentin Clement public: 238*47f75930SValentin Clement /// Return all ops that produce the array value that is stored into the 239*47f75930SValentin Clement /// `array_merge_store`. 240*47f75930SValentin Clement static void reachingValues(llvm::SmallVectorImpl<mlir::Operation *> &reach, 241*47f75930SValentin Clement mlir::Value seq) { 242*47f75930SValentin Clement reach.clear(); 243*47f75930SValentin Clement mlir::Region *loopRegion = nullptr; 244*47f75930SValentin Clement // Only `DoLoopOp` is tested here since array operations are currently only 245*47f75930SValentin Clement // associated with this kind of loop. 246*47f75930SValentin Clement if (auto doLoop = 247*47f75930SValentin Clement mlir::dyn_cast_or_null<fir::DoLoopOp>(seq.getDefiningOp())) 248*47f75930SValentin Clement loopRegion = &doLoop->getRegion(0); 249*47f75930SValentin Clement ReachCollector collector(reach, loopRegion); 250*47f75930SValentin Clement collector.collectArrayAccessFrom(seq); 251*47f75930SValentin Clement } 252*47f75930SValentin Clement }; 253*47f75930SValentin Clement } // namespace 254*47f75930SValentin Clement 255*47f75930SValentin Clement /// Find all the array operations that access the array value that is loaded by 256*47f75930SValentin Clement /// the array load operation, `load`. 
257*47f75930SValentin Clement const llvm::SmallVector<mlir::Operation *> & 258*47f75930SValentin Clement ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { 259*47f75930SValentin Clement auto lmIter = loadMapSets.find(load); 260*47f75930SValentin Clement if (lmIter != loadMapSets.end()) 261*47f75930SValentin Clement return lmIter->getSecond(); 262*47f75930SValentin Clement 263*47f75930SValentin Clement llvm::SmallVector<mlir::Operation *> accesses; 264*47f75930SValentin Clement UseSetT visited; 265*47f75930SValentin Clement llvm::SmallVector<mlir::OpOperand *> queue; // uses of ArrayLoad[orig] 266*47f75930SValentin Clement 267*47f75930SValentin Clement auto appendToQueue = [&](mlir::Value val) { 268*47f75930SValentin Clement for (mlir::OpOperand &use : val.getUses()) 269*47f75930SValentin Clement if (!visited.count(&use)) { 270*47f75930SValentin Clement visited.insert(&use); 271*47f75930SValentin Clement queue.push_back(&use); 272*47f75930SValentin Clement } 273*47f75930SValentin Clement }; 274*47f75930SValentin Clement 275*47f75930SValentin Clement // Build the set of uses of `original`. 276*47f75930SValentin Clement // let USES = { uses of original fir.load } 277*47f75930SValentin Clement appendToQueue(load); 278*47f75930SValentin Clement 279*47f75930SValentin Clement // Process the worklist until done. 280*47f75930SValentin Clement while (!queue.empty()) { 281*47f75930SValentin Clement mlir::OpOperand *operand = queue.pop_back_val(); 282*47f75930SValentin Clement mlir::Operation *owner = operand->getOwner(); 283*47f75930SValentin Clement 284*47f75930SValentin Clement auto structuredLoop = [&](auto ro) { 285*47f75930SValentin Clement if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { 286*47f75930SValentin Clement int64_t arg = blockArg.getArgNumber(); 287*47f75930SValentin Clement mlir::Value output = ro.getResult(ro.finalValue() ? 
arg : arg - 1); 288*47f75930SValentin Clement appendToQueue(output); 289*47f75930SValentin Clement appendToQueue(blockArg); 290*47f75930SValentin Clement } 291*47f75930SValentin Clement }; 292*47f75930SValentin Clement // TODO: this need to be updated to use the control-flow interface. 293*47f75930SValentin Clement auto branchOp = [&](mlir::Block *dest, OperandRange operands) { 294*47f75930SValentin Clement if (operands.empty()) 295*47f75930SValentin Clement return; 296*47f75930SValentin Clement 297*47f75930SValentin Clement // Check if this operand is within the range. 298*47f75930SValentin Clement unsigned operandIndex = operand->getOperandNumber(); 299*47f75930SValentin Clement unsigned operandsStart = operands.getBeginOperandIndex(); 300*47f75930SValentin Clement if (operandIndex < operandsStart || 301*47f75930SValentin Clement operandIndex >= (operandsStart + operands.size())) 302*47f75930SValentin Clement return; 303*47f75930SValentin Clement 304*47f75930SValentin Clement // Index the successor. 305*47f75930SValentin Clement unsigned argIndex = operandIndex - operandsStart; 306*47f75930SValentin Clement appendToQueue(dest->getArgument(argIndex)); 307*47f75930SValentin Clement }; 308*47f75930SValentin Clement // Thread uses into structured loop bodies and return value uses. 309*47f75930SValentin Clement if (auto ro = mlir::dyn_cast<DoLoopOp>(owner)) { 310*47f75930SValentin Clement structuredLoop(ro); 311*47f75930SValentin Clement } else if (auto ro = mlir::dyn_cast<IterWhileOp>(owner)) { 312*47f75930SValentin Clement structuredLoop(ro); 313*47f75930SValentin Clement } else if (auto rs = mlir::dyn_cast<ResultOp>(owner)) { 314*47f75930SValentin Clement // Thread any uses of fir.if that return the marked array value. 
315*47f75930SValentin Clement if (auto ifOp = rs->getParentOfType<fir::IfOp>()) 316*47f75930SValentin Clement appendToQueue(ifOp.getResult(operand->getOperandNumber())); 317*47f75930SValentin Clement } else if (mlir::isa<ArrayFetchOp>(owner)) { 318*47f75930SValentin Clement // Keep track of array value fetches. 319*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 320*47f75930SValentin Clement << "add fetch {" << *owner << "} to array value set\n"); 321*47f75930SValentin Clement accesses.push_back(owner); 322*47f75930SValentin Clement } else if (auto update = mlir::dyn_cast<ArrayUpdateOp>(owner)) { 323*47f75930SValentin Clement // Keep track of array value updates and thread the return value uses. 324*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 325*47f75930SValentin Clement << "add update {" << *owner << "} to array value set\n"); 326*47f75930SValentin Clement accesses.push_back(owner); 327*47f75930SValentin Clement appendToQueue(update.getResult()); 328*47f75930SValentin Clement } else if (auto update = mlir::dyn_cast<ArrayModifyOp>(owner)) { 329*47f75930SValentin Clement // Keep track of array value modification and thread the return value 330*47f75930SValentin Clement // uses. 
331*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 332*47f75930SValentin Clement << "add modify {" << *owner << "} to array value set\n"); 333*47f75930SValentin Clement accesses.push_back(owner); 334*47f75930SValentin Clement appendToQueue(update.getResult(1)); 335*47f75930SValentin Clement } else if (auto br = mlir::dyn_cast<mlir::BranchOp>(owner)) { 336*47f75930SValentin Clement branchOp(br.getDest(), br.destOperands()); 337*47f75930SValentin Clement } else if (auto br = mlir::dyn_cast<mlir::CondBranchOp>(owner)) { 338*47f75930SValentin Clement branchOp(br.getTrueDest(), br.getTrueOperands()); 339*47f75930SValentin Clement branchOp(br.getFalseDest(), br.getFalseOperands()); 340*47f75930SValentin Clement } else if (mlir::isa<ArrayMergeStoreOp>(owner)) { 341*47f75930SValentin Clement // do nothing 342*47f75930SValentin Clement } else { 343*47f75930SValentin Clement llvm::report_fatal_error("array value reached unexpected op"); 344*47f75930SValentin Clement } 345*47f75930SValentin Clement } 346*47f75930SValentin Clement return loadMapSets.insert({load, accesses}).first->getSecond(); 347*47f75930SValentin Clement } 348*47f75930SValentin Clement 349*47f75930SValentin Clement /// Is there a conflict between the array value that was updated and to be 350*47f75930SValentin Clement /// stored to `st` and the set of arrays loaded (`reach`) and used to compute 351*47f75930SValentin Clement /// the updated value? 
352*47f75930SValentin Clement static bool conflictOnLoad(llvm::ArrayRef<mlir::Operation *> reach, 353*47f75930SValentin Clement ArrayMergeStoreOp st) { 354*47f75930SValentin Clement mlir::Value load; 355*47f75930SValentin Clement mlir::Value addr = st.memref(); 356*47f75930SValentin Clement auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); 357*47f75930SValentin Clement for (auto *op : reach) { 358*47f75930SValentin Clement auto ld = mlir::dyn_cast<ArrayLoadOp>(op); 359*47f75930SValentin Clement if (!ld) 360*47f75930SValentin Clement continue; 361*47f75930SValentin Clement mlir::Type ldTy = ld.memref().getType(); 362*47f75930SValentin Clement if (auto boxTy = ldTy.dyn_cast<fir::BoxType>()) 363*47f75930SValentin Clement ldTy = boxTy.getEleTy(); 364*47f75930SValentin Clement if (ldTy.isa<fir::PointerType>() && stEleTy == dyn_cast_ptrEleTy(ldTy)) 365*47f75930SValentin Clement return true; 366*47f75930SValentin Clement if (ld.memref() == addr) { 367*47f75930SValentin Clement if (ld.getResult() != st.original()) 368*47f75930SValentin Clement return true; 369*47f75930SValentin Clement if (load) 370*47f75930SValentin Clement return true; 371*47f75930SValentin Clement load = ld; 372*47f75930SValentin Clement } 373*47f75930SValentin Clement } 374*47f75930SValentin Clement return false; 375*47f75930SValentin Clement } 376*47f75930SValentin Clement 377*47f75930SValentin Clement /// Check if there is any potential conflict in the chained update operations 378*47f75930SValentin Clement /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the 379*47f75930SValentin Clement /// array. A potential conflict is detected if two operations work on the same 380*47f75930SValentin Clement /// indices. 
381*47f75930SValentin Clement static bool conflictOnMerge(llvm::ArrayRef<mlir::Operation *> accesses) { 382*47f75930SValentin Clement if (accesses.size() < 2) 383*47f75930SValentin Clement return false; 384*47f75930SValentin Clement llvm::SmallVector<mlir::Value> indices; 385*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() 386*47f75930SValentin Clement << " accesses on the list\n"); 387*47f75930SValentin Clement for (auto *op : accesses) { 388*47f75930SValentin Clement assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(op)) && 389*47f75930SValentin Clement "unexpected operation in analysis"); 390*47f75930SValentin Clement llvm::SmallVector<mlir::Value> compareVector; 391*47f75930SValentin Clement if (auto u = mlir::dyn_cast<ArrayUpdateOp>(op)) { 392*47f75930SValentin Clement if (indices.empty()) { 393*47f75930SValentin Clement indices = u.indices(); 394*47f75930SValentin Clement continue; 395*47f75930SValentin Clement } 396*47f75930SValentin Clement compareVector = u.indices(); 397*47f75930SValentin Clement } else if (auto f = mlir::dyn_cast<ArrayModifyOp>(op)) { 398*47f75930SValentin Clement if (indices.empty()) { 399*47f75930SValentin Clement indices = f.indices(); 400*47f75930SValentin Clement continue; 401*47f75930SValentin Clement } 402*47f75930SValentin Clement compareVector = f.indices(); 403*47f75930SValentin Clement } else if (auto f = mlir::dyn_cast<ArrayFetchOp>(op)) { 404*47f75930SValentin Clement if (indices.empty()) { 405*47f75930SValentin Clement indices = f.indices(); 406*47f75930SValentin Clement continue; 407*47f75930SValentin Clement } 408*47f75930SValentin Clement compareVector = f.indices(); 409*47f75930SValentin Clement } 410*47f75930SValentin Clement if (compareVector != indices) 411*47f75930SValentin Clement return true; 412*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); 413*47f75930SValentin Clement } 414*47f75930SValentin Clement return 
false; 415*47f75930SValentin Clement } 416*47f75930SValentin Clement 417*47f75930SValentin Clement // Are either of types of conflicts present? 418*47f75930SValentin Clement inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach, 419*47f75930SValentin Clement llvm::ArrayRef<mlir::Operation *> accesses, 420*47f75930SValentin Clement ArrayMergeStoreOp st) { 421*47f75930SValentin Clement return conflictOnLoad(reach, st) || conflictOnMerge(accesses); 422*47f75930SValentin Clement } 423*47f75930SValentin Clement 424*47f75930SValentin Clement /// Constructor of the array copy analysis. 425*47f75930SValentin Clement /// This performs the analysis and saves the intermediate results. 426*47f75930SValentin Clement void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) { 427*47f75930SValentin Clement topLevelOp->walk([&](Operation *op) { 428*47f75930SValentin Clement if (auto st = mlir::dyn_cast<fir::ArrayMergeStoreOp>(op)) { 429*47f75930SValentin Clement llvm::SmallVector<Operation *> values; 430*47f75930SValentin Clement ReachCollector::reachingValues(values, st.sequence()); 431*47f75930SValentin Clement const llvm::SmallVector<Operation *> &accesses = 432*47f75930SValentin Clement arrayAccesses(mlir::cast<ArrayLoadOp>(st.original().getDefiningOp())); 433*47f75930SValentin Clement if (conflictDetected(values, accesses, st)) { 434*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 435*47f75930SValentin Clement << "CONFLICT: copies required for " << st << '\n' 436*47f75930SValentin Clement << " adding conflicts on: " << op << " and " 437*47f75930SValentin Clement << st.original() << '\n'); 438*47f75930SValentin Clement conflicts.insert(op); 439*47f75930SValentin Clement conflicts.insert(st.original().getDefiningOp()); 440*47f75930SValentin Clement } 441*47f75930SValentin Clement auto *ld = st.original().getDefiningOp(); 442*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 443*47f75930SValentin Clement << "map: adding {" << *ld << " -> " << st << "}\n"); 
444*47f75930SValentin Clement useMap.insert({ld, op}); 445*47f75930SValentin Clement } else if (auto load = mlir::dyn_cast<ArrayLoadOp>(op)) { 446*47f75930SValentin Clement const llvm::SmallVector<mlir::Operation *> &accesses = 447*47f75930SValentin Clement arrayAccesses(load); 448*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "process load: " << load 449*47f75930SValentin Clement << ", accesses: " << accesses.size() << '\n'); 450*47f75930SValentin Clement for (auto *acc : accesses) { 451*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); 452*47f75930SValentin Clement assert((mlir::isa<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(acc))); 453*47f75930SValentin Clement if (!useMap.insert({acc, op}).second) { 454*47f75930SValentin Clement mlir::emitError( 455*47f75930SValentin Clement load.getLoc(), 456*47f75930SValentin Clement "The parallel semantics of multiple array_merge_stores per " 457*47f75930SValentin Clement "array_load are not supported."); 458*47f75930SValentin Clement return; 459*47f75930SValentin Clement } 460*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() 461*47f75930SValentin Clement << "map: adding {" << *acc << "} -> {" << load << "}\n"); 462*47f75930SValentin Clement } 463*47f75930SValentin Clement } 464*47f75930SValentin Clement }); 465*47f75930SValentin Clement } 466*47f75930SValentin Clement 467*47f75930SValentin Clement namespace { 468*47f75930SValentin Clement class ArrayLoadConversion : public mlir::OpRewritePattern<ArrayLoadOp> { 469*47f75930SValentin Clement public: 470*47f75930SValentin Clement using OpRewritePattern::OpRewritePattern; 471*47f75930SValentin Clement 472*47f75930SValentin Clement mlir::LogicalResult 473*47f75930SValentin Clement matchAndRewrite(ArrayLoadOp load, 474*47f75930SValentin Clement mlir::PatternRewriter &rewriter) const override { 475*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); 476*47f75930SValentin Clement 
rewriter.replaceOpWithNewOp<UndefOp>(load, load.getType()); 477*47f75930SValentin Clement return mlir::success(); 478*47f75930SValentin Clement } 479*47f75930SValentin Clement }; 480*47f75930SValentin Clement 481*47f75930SValentin Clement class ArrayMergeStoreConversion 482*47f75930SValentin Clement : public mlir::OpRewritePattern<ArrayMergeStoreOp> { 483*47f75930SValentin Clement public: 484*47f75930SValentin Clement using OpRewritePattern::OpRewritePattern; 485*47f75930SValentin Clement 486*47f75930SValentin Clement mlir::LogicalResult 487*47f75930SValentin Clement matchAndRewrite(ArrayMergeStoreOp store, 488*47f75930SValentin Clement mlir::PatternRewriter &rewriter) const override { 489*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); 490*47f75930SValentin Clement rewriter.eraseOp(store); 491*47f75930SValentin Clement return mlir::success(); 492*47f75930SValentin Clement } 493*47f75930SValentin Clement }; 494*47f75930SValentin Clement } // namespace 495*47f75930SValentin Clement 496*47f75930SValentin Clement static mlir::Type getEleTy(mlir::Type ty) { 497*47f75930SValentin Clement if (auto t = dyn_cast_ptrEleTy(ty)) 498*47f75930SValentin Clement ty = t; 499*47f75930SValentin Clement if (auto t = ty.dyn_cast<SequenceType>()) 500*47f75930SValentin Clement ty = t.getEleTy(); 501*47f75930SValentin Clement // FIXME: keep ptr/heap/ref information. 502*47f75930SValentin Clement return ReferenceType::get(ty); 503*47f75930SValentin Clement } 504*47f75930SValentin Clement 505*47f75930SValentin Clement // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 506*47f75930SValentin Clement // TODO: getExtents on op should return a ValueRange instead of a vector. 
507*47f75930SValentin Clement static void getExtents(llvm::SmallVectorImpl<mlir::Value> &result, 508*47f75930SValentin Clement mlir::Value shape) { 509*47f75930SValentin Clement auto *shapeOp = shape.getDefiningOp(); 510*47f75930SValentin Clement if (auto s = mlir::dyn_cast<fir::ShapeOp>(shapeOp)) { 511*47f75930SValentin Clement auto e = s.getExtents(); 512*47f75930SValentin Clement result.insert(result.end(), e.begin(), e.end()); 513*47f75930SValentin Clement return; 514*47f75930SValentin Clement } 515*47f75930SValentin Clement if (auto s = mlir::dyn_cast<fir::ShapeShiftOp>(shapeOp)) { 516*47f75930SValentin Clement auto e = s.getExtents(); 517*47f75930SValentin Clement result.insert(result.end(), e.begin(), e.end()); 518*47f75930SValentin Clement return; 519*47f75930SValentin Clement } 520*47f75930SValentin Clement llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); 521*47f75930SValentin Clement } 522*47f75930SValentin Clement 523*47f75930SValentin Clement // Place the extents of the array loaded by an ArrayLoadOp into the result 524*47f75930SValentin Clement // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If 525*47f75930SValentin Clement // the ArrayLoadOp is loading a fir.box, code will be generated to read the 526*47f75930SValentin Clement // extents from the fir.box, and a the retunred ShapeOp is built with the read 527*47f75930SValentin Clement // extents. 528*47f75930SValentin Clement // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp 529*47f75930SValentin Clement // argument of the ArrayLoadOp that is returned. 
530*47f75930SValentin Clement static mlir::Value 531*47f75930SValentin Clement getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, 532*47f75930SValentin Clement fir::ArrayLoadOp loadOp, 533*47f75930SValentin Clement llvm::SmallVectorImpl<mlir::Value> &result) { 534*47f75930SValentin Clement assert(result.empty()); 535*47f75930SValentin Clement if (auto boxTy = loadOp.memref().getType().dyn_cast<fir::BoxType>()) { 536*47f75930SValentin Clement auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) 537*47f75930SValentin Clement .cast<fir::SequenceType>() 538*47f75930SValentin Clement .getDimension(); 539*47f75930SValentin Clement auto idxTy = rewriter.getIndexType(); 540*47f75930SValentin Clement for (decltype(rank) dim = 0; dim < rank; ++dim) { 541*47f75930SValentin Clement auto dimVal = rewriter.create<arith::ConstantIndexOp>(loc, dim); 542*47f75930SValentin Clement auto dimInfo = rewriter.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, 543*47f75930SValentin Clement loadOp.memref(), dimVal); 544*47f75930SValentin Clement result.emplace_back(dimInfo.getResult(1)); 545*47f75930SValentin Clement } 546*47f75930SValentin Clement auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); 547*47f75930SValentin Clement return rewriter.create<fir::ShapeOp>(loc, shapeType, result); 548*47f75930SValentin Clement } 549*47f75930SValentin Clement getExtents(result, loadOp.shape()); 550*47f75930SValentin Clement return loadOp.shape(); 551*47f75930SValentin Clement } 552*47f75930SValentin Clement 553*47f75930SValentin Clement static mlir::Type toRefType(mlir::Type ty) { 554*47f75930SValentin Clement if (fir::isa_ref_type(ty)) 555*47f75930SValentin Clement return ty; 556*47f75930SValentin Clement return fir::ReferenceType::get(ty); 557*47f75930SValentin Clement } 558*47f75930SValentin Clement 559*47f75930SValentin Clement static mlir::Value 560*47f75930SValentin Clement genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, 
561*47f75930SValentin Clement mlir::Type resTy, mlir::Value alloc, mlir::Value shape, 562*47f75930SValentin Clement mlir::Value slice, mlir::ValueRange indices, 563*47f75930SValentin Clement mlir::ValueRange typeparams, bool skipOrig = false) { 564*47f75930SValentin Clement llvm::SmallVector<mlir::Value> originated; 565*47f75930SValentin Clement if (skipOrig) 566*47f75930SValentin Clement originated.assign(indices.begin(), indices.end()); 567*47f75930SValentin Clement else 568*47f75930SValentin Clement originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), 569*47f75930SValentin Clement shape, indices); 570*47f75930SValentin Clement auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); 571*47f75930SValentin Clement assert(seqTy && seqTy.isa<fir::SequenceType>()); 572*47f75930SValentin Clement const auto dimension = seqTy.cast<fir::SequenceType>().getDimension(); 573*47f75930SValentin Clement mlir::Value result = rewriter.create<fir::ArrayCoorOp>( 574*47f75930SValentin Clement loc, eleTy, alloc, shape, slice, 575*47f75930SValentin Clement llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension), 576*47f75930SValentin Clement typeparams); 577*47f75930SValentin Clement if (dimension < originated.size()) 578*47f75930SValentin Clement result = rewriter.create<fir::CoordinateOp>( 579*47f75930SValentin Clement loc, resTy, result, 580*47f75930SValentin Clement llvm::ArrayRef<mlir::Value>{originated}.drop_front(dimension)); 581*47f75930SValentin Clement return result; 582*47f75930SValentin Clement } 583*47f75930SValentin Clement 584*47f75930SValentin Clement namespace { 585*47f75930SValentin Clement /// Conversion of fir.array_update and fir.array_modify Ops. 586*47f75930SValentin Clement /// If there is a conflict for the update, then we need to perform a 587*47f75930SValentin Clement /// copy-in/copy-out to preserve the original values of the array. 
If there is 588*47f75930SValentin Clement /// no conflict, then it is save to eschew making any copies. 589*47f75930SValentin Clement template <typename ArrayOp> 590*47f75930SValentin Clement class ArrayUpdateConversionBase : public mlir::OpRewritePattern<ArrayOp> { 591*47f75930SValentin Clement public: 592*47f75930SValentin Clement explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, 593*47f75930SValentin Clement const ArrayCopyAnalysis &a, 594*47f75930SValentin Clement const OperationUseMapT &m) 595*47f75930SValentin Clement : mlir::OpRewritePattern<ArrayOp>{ctx}, analysis{a}, useMap{m} {} 596*47f75930SValentin Clement 597*47f75930SValentin Clement void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, 598*47f75930SValentin Clement mlir::Value dst, mlir::Value src, mlir::Value shapeOp, 599*47f75930SValentin Clement mlir::Type arrTy) const { 600*47f75930SValentin Clement auto insPt = rewriter.saveInsertionPoint(); 601*47f75930SValentin Clement llvm::SmallVector<mlir::Value> indices; 602*47f75930SValentin Clement llvm::SmallVector<mlir::Value> extents; 603*47f75930SValentin Clement getExtents(extents, shapeOp); 604*47f75930SValentin Clement // Build loop nest from column to row. 
605*47f75930SValentin Clement for (auto sh : llvm::reverse(extents)) { 606*47f75930SValentin Clement auto idxTy = rewriter.getIndexType(); 607*47f75930SValentin Clement auto ubi = rewriter.create<fir::ConvertOp>(loc, idxTy, sh); 608*47f75930SValentin Clement auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 609*47f75930SValentin Clement auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 610*47f75930SValentin Clement auto ub = rewriter.create<arith::SubIOp>(loc, idxTy, ubi, one); 611*47f75930SValentin Clement auto loop = rewriter.create<fir::DoLoopOp>(loc, zero, ub, one); 612*47f75930SValentin Clement rewriter.setInsertionPointToStart(loop.getBody()); 613*47f75930SValentin Clement indices.push_back(loop.getInductionVar()); 614*47f75930SValentin Clement } 615*47f75930SValentin Clement // Reverse the indices so they are in column-major order. 616*47f75930SValentin Clement std::reverse(indices.begin(), indices.end()); 617*47f75930SValentin Clement auto ty = getEleTy(arrTy); 618*47f75930SValentin Clement auto fromAddr = rewriter.create<fir::ArrayCoorOp>( 619*47f75930SValentin Clement loc, ty, src, shapeOp, mlir::Value{}, 620*47f75930SValentin Clement fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, 621*47f75930SValentin Clement indices), 622*47f75930SValentin Clement mlir::ValueRange{}); 623*47f75930SValentin Clement auto load = rewriter.create<fir::LoadOp>(loc, fromAddr); 624*47f75930SValentin Clement auto toAddr = rewriter.create<fir::ArrayCoorOp>( 625*47f75930SValentin Clement loc, ty, dst, shapeOp, mlir::Value{}, 626*47f75930SValentin Clement fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, 627*47f75930SValentin Clement indices), 628*47f75930SValentin Clement mlir::ValueRange{}); 629*47f75930SValentin Clement rewriter.create<fir::StoreOp>(loc, load, toAddr); 630*47f75930SValentin Clement rewriter.restoreInsertionPoint(insPt); 631*47f75930SValentin Clement } 632*47f75930SValentin Clement 
633*47f75930SValentin Clement /// Copy the RHS element into the LHS and insert copy-in/copy-out between a 634*47f75930SValentin Clement /// temp and the LHS if the analysis found potential overlaps between the RHS 635*47f75930SValentin Clement /// and LHS arrays. The element copy generator must be provided through \p 636*47f75930SValentin Clement /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. 637*47f75930SValentin Clement /// Returns the address of the LHS element inside the loop and the LHS 638*47f75930SValentin Clement /// ArrayLoad result. 639*47f75930SValentin Clement std::pair<mlir::Value, mlir::Value> 640*47f75930SValentin Clement materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, 641*47f75930SValentin Clement ArrayOp update, 642*47f75930SValentin Clement llvm::function_ref<void(mlir::Value)> assignElement, 643*47f75930SValentin Clement mlir::Type lhsEltRefType) const { 644*47f75930SValentin Clement auto *op = update.getOperation(); 645*47f75930SValentin Clement mlir::Operation *loadOp = useMap.lookup(op); 646*47f75930SValentin Clement auto load = mlir::cast<ArrayLoadOp>(loadOp); 647*47f75930SValentin Clement LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); 648*47f75930SValentin Clement if (analysis.hasPotentialConflict(loadOp)) { 649*47f75930SValentin Clement // If there is a conflict between the arrays, then we copy the lhs array 650*47f75930SValentin Clement // to a temporary, update the temporary, and copy the temporary back to 651*47f75930SValentin Clement // the lhs array. This yields Fortran's copy-in copy-out array semantics. 652*47f75930SValentin Clement LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); 653*47f75930SValentin Clement rewriter.setInsertionPoint(loadOp); 654*47f75930SValentin Clement // Copy in. 
655*47f75930SValentin Clement llvm::SmallVector<mlir::Value> extents; 656*47f75930SValentin Clement mlir::Value shapeOp = 657*47f75930SValentin Clement getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); 658*47f75930SValentin Clement auto allocmem = rewriter.create<AllocMemOp>( 659*47f75930SValentin Clement loc, dyn_cast_ptrOrBoxEleTy(load.memref().getType()), 660*47f75930SValentin Clement load.typeparams(), extents); 661*47f75930SValentin Clement genArrayCopy(load.getLoc(), rewriter, allocmem, load.memref(), shapeOp, 662*47f75930SValentin Clement load.getType()); 663*47f75930SValentin Clement rewriter.setInsertionPoint(op); 664*47f75930SValentin Clement mlir::Value coor = genCoorOp( 665*47f75930SValentin Clement rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, 666*47f75930SValentin Clement shapeOp, load.slice(), update.indices(), load.typeparams(), 667*47f75930SValentin Clement update->hasAttr(fir::factory::attrFortranArrayOffsets())); 668*47f75930SValentin Clement assignElement(coor); 669*47f75930SValentin Clement mlir::Operation *storeOp = useMap.lookup(loadOp); 670*47f75930SValentin Clement auto store = mlir::cast<ArrayMergeStoreOp>(storeOp); 671*47f75930SValentin Clement rewriter.setInsertionPoint(storeOp); 672*47f75930SValentin Clement // Copy out. 673*47f75930SValentin Clement genArrayCopy(store.getLoc(), rewriter, store.memref(), allocmem, shapeOp, 674*47f75930SValentin Clement load.getType()); 675*47f75930SValentin Clement rewriter.create<FreeMemOp>(loc, allocmem); 676*47f75930SValentin Clement return {coor, load.getResult()}; 677*47f75930SValentin Clement } 678*47f75930SValentin Clement // Otherwise, when there is no conflict (a possible loop-carried 679*47f75930SValentin Clement // dependence), the lhs array can be updated in place. 
680*47f75930SValentin Clement LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); 681*47f75930SValentin Clement rewriter.setInsertionPoint(op); 682*47f75930SValentin Clement auto coorTy = getEleTy(load.getType()); 683*47f75930SValentin Clement mlir::Value coor = genCoorOp( 684*47f75930SValentin Clement rewriter, loc, coorTy, lhsEltRefType, load.memref(), load.shape(), 685*47f75930SValentin Clement load.slice(), update.indices(), load.typeparams(), 686*47f75930SValentin Clement update->hasAttr(fir::factory::attrFortranArrayOffsets())); 687*47f75930SValentin Clement assignElement(coor); 688*47f75930SValentin Clement return {coor, load.getResult()}; 689*47f75930SValentin Clement } 690*47f75930SValentin Clement 691*47f75930SValentin Clement private: 692*47f75930SValentin Clement const ArrayCopyAnalysis &analysis; 693*47f75930SValentin Clement const OperationUseMapT &useMap; 694*47f75930SValentin Clement }; 695*47f75930SValentin Clement 696*47f75930SValentin Clement class ArrayUpdateConversion : public ArrayUpdateConversionBase<ArrayUpdateOp> { 697*47f75930SValentin Clement public: 698*47f75930SValentin Clement explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, 699*47f75930SValentin Clement const ArrayCopyAnalysis &a, 700*47f75930SValentin Clement const OperationUseMapT &m) 701*47f75930SValentin Clement : ArrayUpdateConversionBase{ctx, a, m} {} 702*47f75930SValentin Clement 703*47f75930SValentin Clement mlir::LogicalResult 704*47f75930SValentin Clement matchAndRewrite(ArrayUpdateOp update, 705*47f75930SValentin Clement mlir::PatternRewriter &rewriter) const override { 706*47f75930SValentin Clement auto loc = update.getLoc(); 707*47f75930SValentin Clement auto assignElement = [&](mlir::Value coor) { 708*47f75930SValentin Clement rewriter.create<fir::StoreOp>(loc, update.merge(), coor); 709*47f75930SValentin Clement }; 710*47f75930SValentin Clement auto lhsEltRefType = toRefType(update.merge().getType()); 711*47f75930SValentin Clement auto [_, lhsLoadResult] = 
materializeAssignment( 712*47f75930SValentin Clement loc, rewriter, update, assignElement, lhsEltRefType); 713*47f75930SValentin Clement update.replaceAllUsesWith(lhsLoadResult); 714*47f75930SValentin Clement rewriter.replaceOp(update, lhsLoadResult); 715*47f75930SValentin Clement return mlir::success(); 716*47f75930SValentin Clement } 717*47f75930SValentin Clement }; 718*47f75930SValentin Clement 719*47f75930SValentin Clement class ArrayModifyConversion : public ArrayUpdateConversionBase<ArrayModifyOp> { 720*47f75930SValentin Clement public: 721*47f75930SValentin Clement explicit ArrayModifyConversion(mlir::MLIRContext *ctx, 722*47f75930SValentin Clement const ArrayCopyAnalysis &a, 723*47f75930SValentin Clement const OperationUseMapT &m) 724*47f75930SValentin Clement : ArrayUpdateConversionBase{ctx, a, m} {} 725*47f75930SValentin Clement 726*47f75930SValentin Clement mlir::LogicalResult 727*47f75930SValentin Clement matchAndRewrite(ArrayModifyOp modify, 728*47f75930SValentin Clement mlir::PatternRewriter &rewriter) const override { 729*47f75930SValentin Clement auto loc = modify.getLoc(); 730*47f75930SValentin Clement auto assignElement = [](mlir::Value) { 731*47f75930SValentin Clement // Assignment already materialized by lowering using lhs element address. 
732*47f75930SValentin Clement }; 733*47f75930SValentin Clement auto lhsEltRefType = modify.getResult(0).getType(); 734*47f75930SValentin Clement auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( 735*47f75930SValentin Clement loc, rewriter, modify, assignElement, lhsEltRefType); 736*47f75930SValentin Clement modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 737*47f75930SValentin Clement rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); 738*47f75930SValentin Clement return mlir::success(); 739*47f75930SValentin Clement } 740*47f75930SValentin Clement }; 741*47f75930SValentin Clement 742*47f75930SValentin Clement class ArrayFetchConversion : public mlir::OpRewritePattern<ArrayFetchOp> { 743*47f75930SValentin Clement public: 744*47f75930SValentin Clement explicit ArrayFetchConversion(mlir::MLIRContext *ctx, 745*47f75930SValentin Clement const OperationUseMapT &m) 746*47f75930SValentin Clement : OpRewritePattern{ctx}, useMap{m} {} 747*47f75930SValentin Clement 748*47f75930SValentin Clement mlir::LogicalResult 749*47f75930SValentin Clement matchAndRewrite(ArrayFetchOp fetch, 750*47f75930SValentin Clement mlir::PatternRewriter &rewriter) const override { 751*47f75930SValentin Clement auto *op = fetch.getOperation(); 752*47f75930SValentin Clement rewriter.setInsertionPoint(op); 753*47f75930SValentin Clement auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op)); 754*47f75930SValentin Clement auto loc = fetch.getLoc(); 755*47f75930SValentin Clement mlir::Value coor = 756*47f75930SValentin Clement genCoorOp(rewriter, loc, getEleTy(load.getType()), 757*47f75930SValentin Clement toRefType(fetch.getType()), load.memref(), load.shape(), 758*47f75930SValentin Clement load.slice(), fetch.indices(), load.typeparams(), 759*47f75930SValentin Clement fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); 760*47f75930SValentin Clement rewriter.replaceOpWithNewOp<fir::LoadOp>(fetch, coor); 761*47f75930SValentin Clement return 
mlir::success(); 762*47f75930SValentin Clement } 763*47f75930SValentin Clement 764*47f75930SValentin Clement private: 765*47f75930SValentin Clement const OperationUseMapT &useMap; 766*47f75930SValentin Clement }; 767*47f75930SValentin Clement } // namespace 768*47f75930SValentin Clement 769*47f75930SValentin Clement namespace { 770*47f75930SValentin Clement class ArrayValueCopyConverter 771*47f75930SValentin Clement : public ArrayValueCopyBase<ArrayValueCopyConverter> { 772*47f75930SValentin Clement public: 773*47f75930SValentin Clement void runOnFunction() override { 774*47f75930SValentin Clement auto func = getFunction(); 775*47f75930SValentin Clement LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" 776*47f75930SValentin Clement << func.getName() << "'\n"); 777*47f75930SValentin Clement auto *context = &getContext(); 778*47f75930SValentin Clement 779*47f75930SValentin Clement // Perform the conflict analysis. 780*47f75930SValentin Clement auto &analysis = getAnalysis<ArrayCopyAnalysis>(); 781*47f75930SValentin Clement const auto &useMap = analysis.getUseMap(); 782*47f75930SValentin Clement 783*47f75930SValentin Clement // Phase 1 is performing a rewrite on the array accesses. Once all the 784*47f75930SValentin Clement // array accesses are rewritten we can go on phase 2. 785*47f75930SValentin Clement // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in 786*47f75930SValentin Clement // /copy-out refers the Fortran copy-in/copy-out semantics on statements. 
787*47f75930SValentin Clement mlir::OwningRewritePatternList patterns1(context); 788*47f75930SValentin Clement patterns1.insert<ArrayFetchConversion>(context, useMap); 789*47f75930SValentin Clement patterns1.insert<ArrayUpdateConversion>(context, analysis, useMap); 790*47f75930SValentin Clement patterns1.insert<ArrayModifyConversion>(context, analysis, useMap); 791*47f75930SValentin Clement mlir::ConversionTarget target(*context); 792*47f75930SValentin Clement target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 793*47f75930SValentin Clement mlir::arith::ArithmeticDialect, 794*47f75930SValentin Clement mlir::StandardOpsDialect>(); 795*47f75930SValentin Clement target.addIllegalOp<ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp>(); 796*47f75930SValentin Clement // Rewrite the array fetch and array update ops. 797*47f75930SValentin Clement if (mlir::failed( 798*47f75930SValentin Clement mlir::applyPartialConversion(func, target, std::move(patterns1)))) { 799*47f75930SValentin Clement mlir::emitError(mlir::UnknownLoc::get(context), 800*47f75930SValentin Clement "failure in array-value-copy pass, phase 1"); 801*47f75930SValentin Clement signalPassFailure(); 802*47f75930SValentin Clement } 803*47f75930SValentin Clement 804*47f75930SValentin Clement mlir::OwningRewritePatternList patterns2(context); 805*47f75930SValentin Clement patterns2.insert<ArrayLoadConversion>(context); 806*47f75930SValentin Clement patterns2.insert<ArrayMergeStoreConversion>(context); 807*47f75930SValentin Clement target.addIllegalOp<ArrayLoadOp, ArrayMergeStoreOp>(); 808*47f75930SValentin Clement if (mlir::failed( 809*47f75930SValentin Clement mlir::applyPartialConversion(func, target, std::move(patterns2)))) { 810*47f75930SValentin Clement mlir::emitError(mlir::UnknownLoc::get(context), 811*47f75930SValentin Clement "failure in array-value-copy pass, phase 2"); 812*47f75930SValentin Clement signalPassFailure(); 813*47f75930SValentin Clement } 814*47f75930SValentin Clement } 
815*47f75930SValentin Clement }; 816*47f75930SValentin Clement } // namespace 817*47f75930SValentin Clement 818*47f75930SValentin Clement std::unique_ptr<mlir::Pass> fir::createArrayValueCopyPass() { 819*47f75930SValentin Clement return std::make_unique<ArrayValueCopyConverter>(); 820*47f75930SValentin Clement } 821