//===-- ArrayValueCopy.cpp ------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Factory.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Support/FIRContext.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "flang-array-value-copy" using namespace fir; using namespace mlir; using OperationUseMapT = llvm::DenseMap; namespace { /// Array copy analysis. /// Perform an interference analysis between array values. /// /// Lowering will generate a sequence of the following form. /// ```mlir /// %a_1 = fir.array_load %array_1(%shape) : ... /// ... /// %a_j = fir.array_load %array_j(%shape) : ... /// ... /// %a_n = fir.array_load %array_n(%shape) : ... /// ... /// %v_i = fir.array_fetch %a_i, ... /// %a_j1 = fir.array_update %a_j, ... /// ... /// fir.array_merge_store %a_j, %a_jn to %array_j : ... /// ``` /// /// The analysis is to determine if there are any conflicts. A conflict is when /// one the following cases occurs. /// /// 1. There is an `array_update` to an array value, a_j, such that a_j was /// loaded from the same array memory reference (array_j) but with a different /// shape as the other array values a_i, where i != j. [Possible overlapping /// arrays.] /// /// 2. There is either an array_fetch or array_update of a_j with a different /// set of index values. [Possible loop-carried dependence.] /// /// If none of the array values overlap in storage and the accesses are not /// loop-carried, then the arrays are conflict-free and no copies are required. class ArrayCopyAnalysis { public: using ConflictSetT = llvm::SmallPtrSet; using UseSetT = llvm::SmallPtrSet; using LoadMapSetsT = llvm::DenseMap>; ArrayCopyAnalysis(mlir::Operation *op) : operation{op} { construct(op); } mlir::Operation *getOperation() const { return operation; } /// Return true iff the `array_merge_store` has potential conflicts. bool hasPotentialConflict(mlir::Operation *op) const { LLVM_DEBUG(llvm::dbgs() << "looking for a conflict on " << *op << " and the set has a total of " << conflicts.size() << '\n'); return conflicts.contains(op); } /// Return the use map. The use map maps array fetch and update operations /// back to the array load that is the original source of the array value. const OperationUseMapT &getUseMap() const { return useMap; } /// Find all the array operations that access the array value that is loaded /// by the array load operation, `load`. const llvm::SmallVector &arrayAccesses(ArrayLoadOp load); private: void construct(mlir::Operation *topLevelOp); mlir::Operation *operation; // operation that analysis ran upon ConflictSetT conflicts; // set of conflicts (loads and merge stores) OperationUseMapT useMap; LoadMapSetsT loadMapSets; }; } // namespace namespace { /// Helper class to collect all array operations that produced an array value. class ReachCollector { private: // If provided, the `loopRegion` is the body of a loop that produces the array // of interest. ReachCollector(llvm::SmallVectorImpl &reach, mlir::Region *loopRegion) : reach{reach}, loopRegion{loopRegion} {} void collectArrayAccessFrom(mlir::Operation *op, mlir::ValueRange range) { llvm::errs() << "COLLECT " << *op << "\n"; if (range.empty()) { collectArrayAccessFrom(op, mlir::Value{}); return; } for (mlir::Value v : range) collectArrayAccessFrom(v); } // TODO: Replace recursive algorithm on def-use chain with an iterative one // with an explicit stack. void collectArrayAccessFrom(mlir::Operation *op, mlir::Value val) { // `val` is defined by an Op, process the defining Op. // If `val` is defined by a region containing Op, we want to drill down // and through that Op's region(s). llvm::errs() << "COLLECT " << *op << "\n"; LLVM_DEBUG(llvm::dbgs() << "popset: " << *op << '\n'); auto popFn = [&](auto rop) { assert(val && "op must have a result value"); auto resNum = val.cast().getResultNumber(); llvm::SmallVector results; rop.resultToSourceOps(results, resNum); for (auto u : results) collectArrayAccessFrom(u); }; if (auto rop = mlir::dyn_cast(op)) { popFn(rop); return; } if (auto rop = mlir::dyn_cast(op)) { popFn(rop); return; } if (auto mergeStore = mlir::dyn_cast(op)) { if (opIsInsideLoops(mergeStore)) collectArrayAccessFrom(mergeStore.getSequence()); return; } if (mlir::isa(op)) { // Look for any stores inside the loops, and collect an array operation // that produced the value being stored to it. for (mlir::Operation *user : op->getUsers()) if (auto store = mlir::dyn_cast(user)) if (opIsInsideLoops(store)) collectArrayAccessFrom(store.getValue()); return; } // Otherwise, Op does not contain a region so just chase its operands. if (mlir::isa( op)) { LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n"); reach.emplace_back(op); } // Array modify assignment is performed on the result. So the analysis // must look at the what is done with the result. if (mlir::isa(op)) for (mlir::Operation *user : op->getResult(0).getUsers()) followUsers(user); for (auto u : op->getOperands()) collectArrayAccessFrom(u); } void collectArrayAccessFrom(mlir::BlockArgument ba) { auto *parent = ba.getOwner()->getParentOp(); // If inside an Op holding a region, the block argument corresponds to an // argument passed to the containing Op. auto popFn = [&](auto rop) { collectArrayAccessFrom(rop.blockArgToSourceOp(ba.getArgNumber())); }; if (auto rop = mlir::dyn_cast(parent)) { popFn(rop); return; } if (auto rop = mlir::dyn_cast(parent)) { popFn(rop); return; } // Otherwise, a block argument is provided via the pred blocks. for (auto *pred : ba.getOwner()->getPredecessors()) { auto u = pred->getTerminator()->getOperand(ba.getArgNumber()); collectArrayAccessFrom(u); } } // Recursively trace operands to find all array operations relating to the // values merged. void collectArrayAccessFrom(mlir::Value val) { if (!val || visited.contains(val)) return; visited.insert(val); // Process a block argument. if (auto ba = val.dyn_cast()) { collectArrayAccessFrom(ba); return; } // Process an Op. if (auto *op = val.getDefiningOp()) { collectArrayAccessFrom(op, val); return; } fir::emitFatalError(val.getLoc(), "unhandled value"); } /// Is \op inside the loop nest region ? bool opIsInsideLoops(mlir::Operation *op) const { return loopRegion && loopRegion->isAncestor(op->getParentRegion()); } /// Recursively trace the use of an operation results, calling /// collectArrayAccessFrom on the direct and indirect user operands. /// TODO: Replace recursive algorithm on def-use chain with an iterative one /// with an explicit stack. void followUsers(mlir::Operation *op) { for (auto userOperand : op->getOperands()) collectArrayAccessFrom(userOperand); // Go through potential converts/coordinate_op. for (mlir::Operation *indirectUser : op->getUsers()) followUsers(indirectUser); } llvm::SmallVectorImpl &reach; llvm::SmallPtrSet visited; /// Region of the loops nest that produced the array value. mlir::Region *loopRegion; public: /// Return all ops that produce the array value that is stored into the /// `array_merge_store`. static void reachingValues(llvm::SmallVectorImpl &reach, mlir::Value seq) { reach.clear(); mlir::Region *loopRegion = nullptr; // Only `DoLoopOp` is tested here since array operations are currently only // associated with this kind of loop. if (auto doLoop = mlir::dyn_cast_or_null(seq.getDefiningOp())) loopRegion = &doLoop->getRegion(0); ReachCollector collector(reach, loopRegion); collector.collectArrayAccessFrom(seq); } }; } // namespace /// Find all the array operations that access the array value that is loaded by /// the array load operation, `load`. const llvm::SmallVector & ArrayCopyAnalysis::arrayAccesses(ArrayLoadOp load) { auto lmIter = loadMapSets.find(load); if (lmIter != loadMapSets.end()) return lmIter->getSecond(); llvm::SmallVector accesses; UseSetT visited; llvm::SmallVector queue; // uses of ArrayLoad[orig] auto appendToQueue = [&](mlir::Value val) { for (mlir::OpOperand &use : val.getUses()) if (!visited.count(&use)) { visited.insert(&use); queue.push_back(&use); } }; // Build the set of uses of `original`. // let USES = { uses of original fir.load } appendToQueue(load); // Process the worklist until done. while (!queue.empty()) { mlir::OpOperand *operand = queue.pop_back_val(); mlir::Operation *owner = operand->getOwner(); auto structuredLoop = [&](auto ro) { if (auto blockArg = ro.iterArgToBlockArg(operand->get())) { int64_t arg = blockArg.getArgNumber(); mlir::Value output = ro.getResult(ro.getFinalValue() ? arg : arg - 1); appendToQueue(output); appendToQueue(blockArg); } }; // TODO: this need to be updated to use the control-flow interface. auto branchOp = [&](mlir::Block *dest, OperandRange operands) { if (operands.empty()) return; // Check if this operand is within the range. unsigned operandIndex = operand->getOperandNumber(); unsigned operandsStart = operands.getBeginOperandIndex(); if (operandIndex < operandsStart || operandIndex >= (operandsStart + operands.size())) return; // Index the successor. unsigned argIndex = operandIndex - operandsStart; appendToQueue(dest->getArgument(argIndex)); }; // Thread uses into structured loop bodies and return value uses. if (auto ro = mlir::dyn_cast(owner)) { structuredLoop(ro); } else if (auto ro = mlir::dyn_cast(owner)) { structuredLoop(ro); } else if (auto rs = mlir::dyn_cast(owner)) { // Thread any uses of fir.if that return the marked array value. if (auto ifOp = rs->getParentOfType()) appendToQueue(ifOp.getResult(operand->getOperandNumber())); } else if (mlir::isa(owner)) { // Keep track of array value fetches. LLVM_DEBUG(llvm::dbgs() << "add fetch {" << *owner << "} to array value set\n"); accesses.push_back(owner); } else if (auto update = mlir::dyn_cast(owner)) { // Keep track of array value updates and thread the return value uses. LLVM_DEBUG(llvm::dbgs() << "add update {" << *owner << "} to array value set\n"); accesses.push_back(owner); appendToQueue(update.getResult()); } else if (auto update = mlir::dyn_cast(owner)) { // Keep track of array value modification and thread the return value // uses. LLVM_DEBUG(llvm::dbgs() << "add modify {" << *owner << "} to array value set\n"); accesses.push_back(owner); appendToQueue(update.getResult(1)); } else if (auto br = mlir::dyn_cast(owner)) { branchOp(br.getDest(), br.getDestOperands()); } else if (auto br = mlir::dyn_cast(owner)) { branchOp(br.getTrueDest(), br.getTrueOperands()); branchOp(br.getFalseDest(), br.getFalseOperands()); } else if (mlir::isa(owner)) { // do nothing } else { llvm::report_fatal_error("array value reached unexpected op"); } } return loadMapSets.insert({load, accesses}).first->getSecond(); } /// Is there a conflict between the array value that was updated and to be /// stored to `st` and the set of arrays loaded (`reach`) and used to compute /// the updated value? static bool conflictOnLoad(llvm::ArrayRef reach, ArrayMergeStoreOp st) { mlir::Value load; mlir::Value addr = st.getMemref(); auto stEleTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType()); for (auto *op : reach) { auto ld = mlir::dyn_cast(op); if (!ld) continue; mlir::Type ldTy = ld.getMemref().getType(); if (auto boxTy = ldTy.dyn_cast()) ldTy = boxTy.getEleTy(); if (ldTy.isa() && stEleTy == dyn_cast_ptrEleTy(ldTy)) return true; if (ld.getMemref() == addr) { if (ld.getResult() != st.getOriginal()) return true; if (load) return true; load = ld; } } return false; } /// Check if there is any potential conflict in the chained update operations /// (ArrayFetchOp, ArrayUpdateOp, ArrayModifyOp) while merging back to the /// array. A potential conflict is detected if two operations work on the same /// indices. static bool conflictOnMerge(llvm::ArrayRef accesses) { if (accesses.size() < 2) return false; llvm::SmallVector indices; LLVM_DEBUG(llvm::dbgs() << "check merge conflict on with " << accesses.size() << " accesses on the list\n"); for (auto *op : accesses) { assert((mlir::isa(op)) && "unexpected operation in analysis"); llvm::SmallVector compareVector; if (auto u = mlir::dyn_cast(op)) { if (indices.empty()) { indices = u.getIndices(); continue; } compareVector = u.getIndices(); } else if (auto f = mlir::dyn_cast(op)) { if (indices.empty()) { indices = f.getIndices(); continue; } compareVector = f.getIndices(); } else if (auto f = mlir::dyn_cast(op)) { if (indices.empty()) { indices = f.getIndices(); continue; } compareVector = f.getIndices(); } if (compareVector != indices) return true; LLVM_DEBUG(llvm::dbgs() << "vectors compare equal\n"); } return false; } // Are either of types of conflicts present? inline bool conflictDetected(llvm::ArrayRef reach, llvm::ArrayRef accesses, ArrayMergeStoreOp st) { return conflictOnLoad(reach, st) || conflictOnMerge(accesses); } /// Constructor of the array copy analysis. /// This performs the analysis and saves the intermediate results. void ArrayCopyAnalysis::construct(mlir::Operation *topLevelOp) { topLevelOp->walk([&](Operation *op) { if (auto st = mlir::dyn_cast(op)) { llvm::SmallVector values; ReachCollector::reachingValues(values, st.getSequence()); const llvm::SmallVector &accesses = arrayAccesses( mlir::cast(st.getOriginal().getDefiningOp())); if (conflictDetected(values, accesses, st)) { LLVM_DEBUG(llvm::dbgs() << "CONFLICT: copies required for " << st << '\n' << " adding conflicts on: " << op << " and " << st.getOriginal() << '\n'); conflicts.insert(op); conflicts.insert(st.getOriginal().getDefiningOp()); } auto *ld = st.getOriginal().getDefiningOp(); LLVM_DEBUG(llvm::dbgs() << "map: adding {" << *ld << " -> " << st << "}\n"); useMap.insert({ld, op}); } else if (auto load = mlir::dyn_cast(op)) { const llvm::SmallVector &accesses = arrayAccesses(load); LLVM_DEBUG(llvm::dbgs() << "process load: " << load << ", accesses: " << accesses.size() << '\n'); for (auto *acc : accesses) { LLVM_DEBUG(llvm::dbgs() << " access: " << *acc << '\n'); assert((mlir::isa(acc))); if (!useMap.insert({acc, op}).second) { mlir::emitError( load.getLoc(), "The parallel semantics of multiple array_merge_stores per " "array_load are not supported."); return; } LLVM_DEBUG(llvm::dbgs() << "map: adding {" << *acc << "} -> {" << load << "}\n"); } } }); } namespace { class ArrayLoadConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(ArrayLoadOp load, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "replace load " << load << " with undef.\n"); rewriter.replaceOpWithNewOp(load, load.getType()); return mlir::success(); } }; class ArrayMergeStoreConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(ArrayMergeStoreOp store, mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "marking store " << store << " as dead.\n"); rewriter.eraseOp(store); return mlir::success(); } }; } // namespace static mlir::Type getEleTy(mlir::Type ty) { if (auto t = dyn_cast_ptrEleTy(ty)) ty = t; if (auto t = ty.dyn_cast()) ty = t.getEleTy(); // FIXME: keep ptr/heap/ref information. return ReferenceType::get(ty); } // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. // TODO: getExtents on op should return a ValueRange instead of a vector. static void getExtents(llvm::SmallVectorImpl &result, mlir::Value shape) { auto *shapeOp = shape.getDefiningOp(); if (auto s = mlir::dyn_cast(shapeOp)) { auto e = s.getExtents(); result.insert(result.end(), e.begin(), e.end()); return; } if (auto s = mlir::dyn_cast(shapeOp)) { auto e = s.getExtents(); result.insert(result.end(), e.begin(), e.end()); return; } llvm::report_fatal_error("not a fir.shape/fir.shape_shift op"); } // Place the extents of the array loaded by an ArrayLoadOp into the result // vector and return a ShapeOp/ShapeShiftOp with the corresponding extents. If // the ArrayLoadOp is loading a fir.box, code will be generated to read the // extents from the fir.box, and a the retunred ShapeOp is built with the read // extents. // Otherwise, the extents will be extracted from the ShapeOp/ShapeShiftOp // argument of the ArrayLoadOp that is returned. static mlir::Value getOrReadExtentsAndShapeOp(mlir::Location loc, mlir::PatternRewriter &rewriter, fir::ArrayLoadOp loadOp, llvm::SmallVectorImpl &result) { assert(result.empty()); if (auto boxTy = loadOp.getMemref().getType().dyn_cast()) { auto rank = fir::dyn_cast_ptrOrBoxEleTy(boxTy) .cast() .getDimension(); auto idxTy = rewriter.getIndexType(); for (decltype(rank) dim = 0; dim < rank; ++dim) { auto dimVal = rewriter.create(loc, dim); auto dimInfo = rewriter.create( loc, idxTy, idxTy, idxTy, loadOp.getMemref(), dimVal); result.emplace_back(dimInfo.getResult(1)); } auto shapeType = fir::ShapeType::get(rewriter.getContext(), rank); return rewriter.create(loc, shapeType, result); } getExtents(result, loadOp.getShape()); return loadOp.getShape(); } static mlir::Type toRefType(mlir::Type ty) { if (fir::isa_ref_type(ty)) return ty; return fir::ReferenceType::get(ty); } static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy, mlir::Type resTy, mlir::Value alloc, mlir::Value shape, mlir::Value slice, mlir::ValueRange indices, mlir::ValueRange typeparams, bool skipOrig = false) { llvm::SmallVector originated; if (skipOrig) originated.assign(indices.begin(), indices.end()); else originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(), shape, indices); auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType()); assert(seqTy && seqTy.isa()); const auto dimension = seqTy.cast().getDimension(); mlir::Value result = rewriter.create( loc, eleTy, alloc, shape, slice, llvm::ArrayRef{originated}.take_front(dimension), typeparams); if (dimension < originated.size()) result = rewriter.create( loc, resTy, result, llvm::ArrayRef{originated}.drop_front(dimension)); return result; } namespace { /// Conversion of fir.array_update and fir.array_modify Ops. /// If there is a conflict for the update, then we need to perform a /// copy-in/copy-out to preserve the original values of the array. If there is /// no conflict, then it is save to eschew making any copies. template class ArrayUpdateConversionBase : public mlir::OpRewritePattern { public: explicit ArrayUpdateConversionBase(mlir::MLIRContext *ctx, const ArrayCopyAnalysis &a, const OperationUseMapT &m) : mlir::OpRewritePattern{ctx}, analysis{a}, useMap{m} {} void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value dst, mlir::Value src, mlir::Value shapeOp, mlir::Type arrTy) const { auto insPt = rewriter.saveInsertionPoint(); llvm::SmallVector indices; llvm::SmallVector extents; getExtents(extents, shapeOp); // Build loop nest from column to row. for (auto sh : llvm::reverse(extents)) { auto idxTy = rewriter.getIndexType(); auto ubi = rewriter.create(loc, idxTy, sh); auto zero = rewriter.create(loc, 0); auto one = rewriter.create(loc, 1); auto ub = rewriter.create(loc, idxTy, ubi, one); auto loop = rewriter.create(loc, zero, ub, one); rewriter.setInsertionPointToStart(loop.getBody()); indices.push_back(loop.getInductionVar()); } // Reverse the indices so they are in column-major order. std::reverse(indices.begin(), indices.end()); auto ty = getEleTy(arrTy); auto fromAddr = rewriter.create( loc, ty, src, shapeOp, mlir::Value{}, fir::factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices), mlir::ValueRange{}); auto load = rewriter.create(loc, fromAddr); auto toAddr = rewriter.create( loc, ty, dst, shapeOp, mlir::Value{}, fir::factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices), mlir::ValueRange{}); rewriter.create(loc, load, toAddr); rewriter.restoreInsertionPoint(insPt); } /// Copy the RHS element into the LHS and insert copy-in/copy-out between a /// temp and the LHS if the analysis found potential overlaps between the RHS /// and LHS arrays. The element copy generator must be provided through \p /// assignElement. \p update must be the ArrayUpdateOp or the ArrayModifyOp. /// Returns the address of the LHS element inside the loop and the LHS /// ArrayLoad result. std::pair materializeAssignment(mlir::Location loc, mlir::PatternRewriter &rewriter, ArrayOp update, llvm::function_ref assignElement, mlir::Type lhsEltRefType) const { auto *op = update.getOperation(); mlir::Operation *loadOp = useMap.lookup(op); auto load = mlir::cast(loadOp); LLVM_DEBUG(llvm::outs() << "does " << load << " have a conflict?\n"); if (analysis.hasPotentialConflict(loadOp)) { // If there is a conflict between the arrays, then we copy the lhs array // to a temporary, update the temporary, and copy the temporary back to // the lhs array. This yields Fortran's copy-in copy-out array semantics. LLVM_DEBUG(llvm::outs() << "Yes, conflict was found\n"); rewriter.setInsertionPoint(loadOp); // Copy in. llvm::SmallVector extents; mlir::Value shapeOp = getOrReadExtentsAndShapeOp(loc, rewriter, load, extents); auto allocmem = rewriter.create( loc, dyn_cast_ptrOrBoxEleTy(load.getMemref().getType()), load.getTypeparams(), extents); genArrayCopy(load.getLoc(), rewriter, allocmem, load.getMemref(), shapeOp, load.getType()); rewriter.setInsertionPoint(op); mlir::Value coor = genCoorOp( rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem, shapeOp, load.getSlice(), update.getIndices(), load.getTypeparams(), update->hasAttr(fir::factory::attrFortranArrayOffsets())); assignElement(coor); mlir::Operation *storeOp = useMap.lookup(loadOp); auto store = mlir::cast(storeOp); rewriter.setInsertionPoint(storeOp); // Copy out. genArrayCopy(store.getLoc(), rewriter, store.getMemref(), allocmem, shapeOp, load.getType()); rewriter.create(loc, allocmem); return {coor, load.getResult()}; } // Otherwise, when there is no conflict (a possible loop-carried // dependence), the lhs array can be updated in place. LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n"); rewriter.setInsertionPoint(op); auto coorTy = getEleTy(load.getType()); mlir::Value coor = genCoorOp( rewriter, loc, coorTy, lhsEltRefType, load.getMemref(), load.getShape(), load.getSlice(), update.getIndices(), load.getTypeparams(), update->hasAttr(fir::factory::attrFortranArrayOffsets())); assignElement(coor); return {coor, load.getResult()}; } private: const ArrayCopyAnalysis &analysis; const OperationUseMapT &useMap; }; class ArrayUpdateConversion : public ArrayUpdateConversionBase { public: explicit ArrayUpdateConversion(mlir::MLIRContext *ctx, const ArrayCopyAnalysis &a, const OperationUseMapT &m) : ArrayUpdateConversionBase{ctx, a, m} {} mlir::LogicalResult matchAndRewrite(ArrayUpdateOp update, mlir::PatternRewriter &rewriter) const override { auto loc = update.getLoc(); auto assignElement = [&](mlir::Value coor) { rewriter.create(loc, update.getMerge(), coor); }; auto lhsEltRefType = toRefType(update.getMerge().getType()); auto [_, lhsLoadResult] = materializeAssignment( loc, rewriter, update, assignElement, lhsEltRefType); update.replaceAllUsesWith(lhsLoadResult); rewriter.replaceOp(update, lhsLoadResult); return mlir::success(); } }; class ArrayModifyConversion : public ArrayUpdateConversionBase { public: explicit ArrayModifyConversion(mlir::MLIRContext *ctx, const ArrayCopyAnalysis &a, const OperationUseMapT &m) : ArrayUpdateConversionBase{ctx, a, m} {} mlir::LogicalResult matchAndRewrite(ArrayModifyOp modify, mlir::PatternRewriter &rewriter) const override { auto loc = modify.getLoc(); auto assignElement = [](mlir::Value) { // Assignment already materialized by lowering using lhs element address. }; auto lhsEltRefType = modify.getResult(0).getType(); auto [lhsEltCoor, lhsLoadResult] = materializeAssignment( loc, rewriter, modify, assignElement, lhsEltRefType); modify.replaceAllUsesWith(mlir::ValueRange{lhsEltCoor, lhsLoadResult}); rewriter.replaceOp(modify, mlir::ValueRange{lhsEltCoor, lhsLoadResult}); return mlir::success(); } }; class ArrayFetchConversion : public mlir::OpRewritePattern { public: explicit ArrayFetchConversion(mlir::MLIRContext *ctx, const OperationUseMapT &m) : OpRewritePattern{ctx}, useMap{m} {} mlir::LogicalResult matchAndRewrite(ArrayFetchOp fetch, mlir::PatternRewriter &rewriter) const override { auto *op = fetch.getOperation(); rewriter.setInsertionPoint(op); auto load = mlir::cast(useMap.lookup(op)); auto loc = fetch.getLoc(); mlir::Value coor = genCoorOp(rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()), load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(), load.getTypeparams(), fetch->hasAttr(fir::factory::attrFortranArrayOffsets())); rewriter.replaceOpWithNewOp(fetch, coor); return mlir::success(); } private: const OperationUseMapT &useMap; }; } // namespace namespace { class ArrayValueCopyConverter : public ArrayValueCopyBase { public: void runOnOperation() override { auto func = getOperation(); LLVM_DEBUG(llvm::dbgs() << "\n\narray-value-copy pass on function '" << func.getName() << "'\n"); auto *context = &getContext(); // Perform the conflict analysis. auto &analysis = getAnalysis(); const auto &useMap = analysis.getUseMap(); // Phase 1 is performing a rewrite on the array accesses. Once all the // array accesses are rewritten we can go on phase 2. // Phase 2 gets rid of the useless copy-in/copyout operations. The copy-in // /copy-out refers the Fortran copy-in/copy-out semantics on statements. mlir::RewritePatternSet patterns1(context); patterns1.insert(context, useMap); patterns1.insert(context, analysis, useMap); patterns1.insert(context, analysis, useMap); mlir::ConversionTarget target(*context); target.addLegalDialect< FIROpsDialect, mlir::scf::SCFDialect, mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect, mlir::func::FuncDialect>(); target.addIllegalOp(); // Rewrite the array fetch and array update ops. if (mlir::failed( mlir::applyPartialConversion(func, target, std::move(patterns1)))) { mlir::emitError(mlir::UnknownLoc::get(context), "failure in array-value-copy pass, phase 1"); signalPassFailure(); } mlir::RewritePatternSet patterns2(context); patterns2.insert(context); patterns2.insert(context); target.addIllegalOp(); if (mlir::failed( mlir::applyPartialConversion(func, target, std::move(patterns2)))) { mlir::emitError(mlir::UnknownLoc::get(context), "failure in array-value-copy pass, phase 2"); signalPassFailure(); } } }; } // namespace std::unique_ptr fir::createArrayValueCopyPass() { return std::make_unique(); }