//===- MemoryAllocation.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
83d092e31SEric Schweitz 
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
203d092e31SEric Schweitz 
#define DEBUG_TYPE "flang-memory-allocation-opt"

/// Sentinel threshold (all bits set) meaning "no limit": with this value the
/// number of elements in an array does not determine where it is allocated.
static constexpr std::size_t unlimitedArraySize = ~std::size_t{0};
253d092e31SEric Schweitz 
263d092e31SEric Schweitz namespace {
273d092e31SEric Schweitz struct MemoryAllocationOptions {
283d092e31SEric Schweitz   // Always move dynamic array allocations to the heap. This may result in more
293d092e31SEric Schweitz   // heap fragmentation, so may impact performance negatively.
303d092e31SEric Schweitz   bool dynamicArrayOnHeap = false;
313d092e31SEric Schweitz 
323d092e31SEric Schweitz   // Number of elements in array threshold for moving to heap. In environments
333d092e31SEric Schweitz   // with limited stack size, moving large arrays to the heap can avoid running
343d092e31SEric Schweitz   // out of stack space.
354d53f88dSValentin Clement   std::size_t maxStackArraySize = unlimitedArraySize;
363d092e31SEric Schweitz };
373d092e31SEric Schweitz 
383d092e31SEric Schweitz class ReturnAnalysis {
393d092e31SEric Schweitz public:
405e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReturnAnalysis)
415e50dd04SRiver Riddle 
ReturnAnalysis(mlir::Operation * op)423d092e31SEric Schweitz   ReturnAnalysis(mlir::Operation *op) {
43*58ceae95SRiver Riddle     if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op))
443d092e31SEric Schweitz       for (mlir::Block &block : func)
453d092e31SEric Schweitz         for (mlir::Operation &i : block)
4623aa5a74SRiver Riddle           if (mlir::isa<mlir::func::ReturnOp>(i)) {
473d092e31SEric Schweitz             returnMap[op].push_back(&i);
483d092e31SEric Schweitz             break;
493d092e31SEric Schweitz           }
503d092e31SEric Schweitz   }
513d092e31SEric Schweitz 
getReturns(mlir::Operation * func) const523d092e31SEric Schweitz   llvm::SmallVector<mlir::Operation *> getReturns(mlir::Operation *func) const {
533d092e31SEric Schweitz     auto iter = returnMap.find(func);
543d092e31SEric Schweitz     if (iter != returnMap.end())
553d092e31SEric Schweitz       return iter->second;
563d092e31SEric Schweitz     return {};
573d092e31SEric Schweitz   }
583d092e31SEric Schweitz 
593d092e31SEric Schweitz private:
603d092e31SEric Schweitz   llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
613d092e31SEric Schweitz       returnMap;
623d092e31SEric Schweitz };
633d092e31SEric Schweitz } // namespace
643d092e31SEric Schweitz 
653d092e31SEric Schweitz /// Return `true` if this allocation is to remain on the stack (`fir.alloca`).
663d092e31SEric Schweitz /// Otherwise the allocation should be moved to the heap (`fir.allocmem`).
keepStackAllocation(fir::AllocaOp alloca,mlir::Block * entry,const MemoryAllocationOptions & options)673d092e31SEric Schweitz static inline bool keepStackAllocation(fir::AllocaOp alloca, mlir::Block *entry,
683d092e31SEric Schweitz                                        const MemoryAllocationOptions &options) {
693d092e31SEric Schweitz   // Limitation: only arrays allocated on the stack in the entry block are
703d092e31SEric Schweitz   // considered for now.
713d092e31SEric Schweitz   // TODO: Generalize the algorithm and placement of the freemem nodes.
723d092e31SEric Schweitz   if (alloca->getBlock() != entry)
733d092e31SEric Schweitz     return true;
743d092e31SEric Schweitz   if (auto seqTy = alloca.getInType().dyn_cast<fir::SequenceType>()) {
753d092e31SEric Schweitz     if (fir::hasDynamicSize(seqTy)) {
763d092e31SEric Schweitz       // Move all arrays with runtime determined size to the heap.
773d092e31SEric Schweitz       if (options.dynamicArrayOnHeap)
783d092e31SEric Schweitz         return false;
793d092e31SEric Schweitz     } else {
803d092e31SEric Schweitz       std::int64_t numberOfElements = 1;
813d092e31SEric Schweitz       for (std::int64_t i : seqTy.getShape()) {
823d092e31SEric Schweitz         numberOfElements *= i;
833d092e31SEric Schweitz         // If the count is suspicious, then don't change anything here.
843d092e31SEric Schweitz         if (numberOfElements <= 0)
853d092e31SEric Schweitz           return true;
863d092e31SEric Schweitz       }
873d092e31SEric Schweitz       // If the number of elements exceeds the threshold, move the allocation to
883d092e31SEric Schweitz       // the heap.
893d092e31SEric Schweitz       if (static_cast<std::size_t>(numberOfElements) >
903d092e31SEric Schweitz           options.maxStackArraySize) {
913d092e31SEric Schweitz         LLVM_DEBUG(llvm::dbgs()
923d092e31SEric Schweitz                    << "memory allocation opt: found " << alloca << '\n');
933d092e31SEric Schweitz         return false;
943d092e31SEric Schweitz       }
953d092e31SEric Schweitz     }
963d092e31SEric Schweitz   }
973d092e31SEric Schweitz   return true;
983d092e31SEric Schweitz }
993d092e31SEric Schweitz 
1003d092e31SEric Schweitz namespace {
1013d092e31SEric Schweitz class AllocaOpConversion : public mlir::OpRewritePattern<fir::AllocaOp> {
1023d092e31SEric Schweitz public:
1033d092e31SEric Schweitz   using OpRewritePattern::OpRewritePattern;
1043d092e31SEric Schweitz 
AllocaOpConversion(mlir::MLIRContext * ctx,llvm::ArrayRef<mlir::Operation * > rets)1053d092e31SEric Schweitz   AllocaOpConversion(mlir::MLIRContext *ctx,
1063d092e31SEric Schweitz                      llvm::ArrayRef<mlir::Operation *> rets)
1073d092e31SEric Schweitz       : OpRewritePattern(ctx), returnOps(rets) {}
1083d092e31SEric Schweitz 
1093d092e31SEric Schweitz   mlir::LogicalResult
matchAndRewrite(fir::AllocaOp alloca,mlir::PatternRewriter & rewriter) const1103d092e31SEric Schweitz   matchAndRewrite(fir::AllocaOp alloca,
1113d092e31SEric Schweitz                   mlir::PatternRewriter &rewriter) const override {
1123d092e31SEric Schweitz     auto loc = alloca.getLoc();
1133d092e31SEric Schweitz     mlir::Type varTy = alloca.getInType();
1143d092e31SEric Schweitz     auto unpackName =
1153d092e31SEric Schweitz         [](llvm::Optional<llvm::StringRef> opt) -> llvm::StringRef {
1163d092e31SEric Schweitz       if (opt)
1173d092e31SEric Schweitz         return *opt;
1183d092e31SEric Schweitz       return {};
1193d092e31SEric Schweitz     };
120149ad3d5SShraiysh Vaishay     auto uniqName = unpackName(alloca.getUniqName());
121149ad3d5SShraiysh Vaishay     auto bindcName = unpackName(alloca.getBindcName());
1223d092e31SEric Schweitz     auto heap = rewriter.create<fir::AllocMemOp>(
123149ad3d5SShraiysh Vaishay         loc, varTy, uniqName, bindcName, alloca.getTypeparams(),
124149ad3d5SShraiysh Vaishay         alloca.getShape());
1253d092e31SEric Schweitz     auto insPt = rewriter.saveInsertionPoint();
1263d092e31SEric Schweitz     for (mlir::Operation *retOp : returnOps) {
1273d092e31SEric Schweitz       rewriter.setInsertionPoint(retOp);
1283d092e31SEric Schweitz       [[maybe_unused]] auto free = rewriter.create<fir::FreeMemOp>(loc, heap);
1293d092e31SEric Schweitz       LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: add free " << free
1303d092e31SEric Schweitz                               << " for " << heap << '\n');
1313d092e31SEric Schweitz     }
1323d092e31SEric Schweitz     rewriter.restoreInsertionPoint(insPt);
1333d092e31SEric Schweitz     rewriter.replaceOpWithNewOp<fir::ConvertOp>(
1343d092e31SEric Schweitz         alloca, fir::ReferenceType::get(varTy), heap);
1353d092e31SEric Schweitz     LLVM_DEBUG(llvm::dbgs() << "memory allocation opt: replaced " << alloca
1363d092e31SEric Schweitz                             << " with " << heap << '\n');
1373d092e31SEric Schweitz     return mlir::success();
1383d092e31SEric Schweitz   }
1393d092e31SEric Schweitz 
1403d092e31SEric Schweitz private:
1413d092e31SEric Schweitz   llvm::ArrayRef<mlir::Operation *> returnOps;
1423d092e31SEric Schweitz };
1433d092e31SEric Schweitz 
1443d092e31SEric Schweitz /// This pass can reclassify memory allocations (fir.alloca, fir.allocmem) based
1453d092e31SEric Schweitz /// on heuristics and settings. The intention is to allow better performance and
1463d092e31SEric Schweitz /// workarounds for conditions such as environments with limited stack space.
1473d092e31SEric Schweitz ///
1483d092e31SEric Schweitz /// Currently, implements two conversions from stack to heap allocation.
1493d092e31SEric Schweitz ///   1. If a stack allocation is an array larger than some threshold value
1503d092e31SEric Schweitz ///      make it a heap allocation.
1513d092e31SEric Schweitz ///   2. If a stack allocation is an array with a runtime evaluated size make
1523d092e31SEric Schweitz ///      it a heap allocation.
1533d092e31SEric Schweitz class MemoryAllocationOpt
1543d092e31SEric Schweitz     : public fir::MemoryAllocationOptBase<MemoryAllocationOpt> {
1553d092e31SEric Schweitz public:
MemoryAllocationOpt()1564d53f88dSValentin Clement   MemoryAllocationOpt() {
1574d53f88dSValentin Clement     // Set options with default values. (See Passes.td.) Note that the
1584d53f88dSValentin Clement     // command-line options, e.g. dynamicArrayOnHeap,  are not set yet.
1594d53f88dSValentin Clement     options = {dynamicArrayOnHeap, maxStackArraySize};
1604d53f88dSValentin Clement   }
1614d53f88dSValentin Clement 
MemoryAllocationOpt(bool dynOnHeap,std::size_t maxStackSize)1624d53f88dSValentin Clement   MemoryAllocationOpt(bool dynOnHeap, std::size_t maxStackSize) {
1634d53f88dSValentin Clement     // Set options with default values. (See Passes.td.)
1644d53f88dSValentin Clement     options = {dynOnHeap, maxStackSize};
1654d53f88dSValentin Clement   }
1664d53f88dSValentin Clement 
1674d53f88dSValentin Clement   /// Override `options` if command-line options have been set.
useCommandLineOptions()1684d53f88dSValentin Clement   inline void useCommandLineOptions() {
1694d53f88dSValentin Clement     if (dynamicArrayOnHeap)
1704d53f88dSValentin Clement       options.dynamicArrayOnHeap = dynamicArrayOnHeap;
1714d53f88dSValentin Clement     if (maxStackArraySize != unlimitedArraySize)
1724d53f88dSValentin Clement       options.maxStackArraySize = maxStackArraySize;
1734d53f88dSValentin Clement   }
1744d53f88dSValentin Clement 
runOnOperation()1753d092e31SEric Schweitz   void runOnOperation() override {
1763d092e31SEric Schweitz     auto *context = &getContext();
1773d092e31SEric Schweitz     auto func = getOperation();
1789f85c198SRiver Riddle     mlir::RewritePatternSet patterns(context);
1793d092e31SEric Schweitz     mlir::ConversionTarget target(*context);
1804d53f88dSValentin Clement 
1814d53f88dSValentin Clement     useCommandLineOptions();
1824d53f88dSValentin Clement     LLVM_DEBUG(llvm::dbgs()
1834d53f88dSValentin Clement                << "dynamic arrays on heap: " << options.dynamicArrayOnHeap
1844d53f88dSValentin Clement                << "\nmaximum number of elements of array on stack: "
1854d53f88dSValentin Clement                << options.maxStackArraySize << '\n');
1863d092e31SEric Schweitz 
1873d092e31SEric Schweitz     // If func is a declaration, skip it.
1883d092e31SEric Schweitz     if (func.empty())
1893d092e31SEric Schweitz       return;
1903d092e31SEric Schweitz 
1913d092e31SEric Schweitz     const auto &analysis = getAnalysis<ReturnAnalysis>();
1923d092e31SEric Schweitz 
1933d092e31SEric Schweitz     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
19423aa5a74SRiver Riddle                            mlir::func::FuncDialect>();
1953d092e31SEric Schweitz     target.addDynamicallyLegalOp<fir::AllocaOp>([&](fir::AllocaOp alloca) {
1963d092e31SEric Schweitz       return keepStackAllocation(alloca, &func.front(), options);
1973d092e31SEric Schweitz     });
1983d092e31SEric Schweitz 
1993d092e31SEric Schweitz     patterns.insert<AllocaOpConversion>(context, analysis.getReturns(func));
2003d092e31SEric Schweitz     if (mlir::failed(
2013d092e31SEric Schweitz             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
2023d092e31SEric Schweitz       mlir::emitError(func.getLoc(),
2033d092e31SEric Schweitz                       "error in memory allocation optimization\n");
2043d092e31SEric Schweitz       signalPassFailure();
2053d092e31SEric Schweitz     }
2063d092e31SEric Schweitz   }
2074d53f88dSValentin Clement 
2084d53f88dSValentin Clement private:
2094d53f88dSValentin Clement   MemoryAllocationOptions options;
2103d092e31SEric Schweitz };
2113d092e31SEric Schweitz } // namespace
2123d092e31SEric Schweitz 
createMemoryAllocationPass()2133d092e31SEric Schweitz std::unique_ptr<mlir::Pass> fir::createMemoryAllocationPass() {
2143d092e31SEric Schweitz   return std::make_unique<MemoryAllocationOpt>();
2153d092e31SEric Schweitz }
2164d53f88dSValentin Clement 
2174d53f88dSValentin Clement std::unique_ptr<mlir::Pass>
createMemoryAllocationPass(bool dynOnHeap,std::size_t maxStackSize)2184d53f88dSValentin Clement fir::createMemoryAllocationPass(bool dynOnHeap, std::size_t maxStackSize) {
2194d53f88dSValentin Clement   return std::make_unique<MemoryAllocationOpt>(dynOnHeap, maxStackSize);
2204d53f88dSValentin Clement }
221