10e9a4a3bSRiver Riddle //===- BufferOptimizations.cpp - pre-pass optimizations for bufferization -===//
20e9a4a3bSRiver Riddle //
30e9a4a3bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40e9a4a3bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
50e9a4a3bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60e9a4a3bSRiver Riddle //
70e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
80e9a4a3bSRiver Riddle //
90e9a4a3bSRiver Riddle // This file implements logic for three optimization passes. The first two
100e9a4a3bSRiver Riddle // passes try to move alloc nodes out of blocks to reduce the number of
110e9a4a3bSRiver Riddle // allocations and copies during buffer deallocation. The third pass tries to
120e9a4a3bSRiver Riddle // convert heap-based allocations to stack-based allocations, if possible.
130e9a4a3bSRiver Riddle
140e9a4a3bSRiver Riddle #include "PassDetail.h"
150e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
160e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
170e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
180e9a4a3bSRiver Riddle #include "mlir/IR/Operation.h"
190e9a4a3bSRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.h"
200e9a4a3bSRiver Riddle #include "mlir/Pass/Pass.h"
210e9a4a3bSRiver Riddle
220e9a4a3bSRiver Riddle using namespace mlir;
230e9a4a3bSRiver Riddle using namespace mlir::bufferization;
240e9a4a3bSRiver Riddle
250e9a4a3bSRiver Riddle /// Returns true if the given operation implements a known high-level region-
260e9a4a3bSRiver Riddle /// based control-flow interface.
isKnownControlFlowInterface(Operation * op)270e9a4a3bSRiver Riddle static bool isKnownControlFlowInterface(Operation *op) {
280e9a4a3bSRiver Riddle return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
290e9a4a3bSRiver Riddle }
300e9a4a3bSRiver Riddle
310e9a4a3bSRiver Riddle /// Check if the size of the allocation is less than the given size. The
320e9a4a3bSRiver Riddle /// transformation is only applied to small buffers since large buffers could
330e9a4a3bSRiver Riddle /// exceed the stack space.
defaultIsSmallAlloc(Value alloc,unsigned maximumSizeInBytes,unsigned maxRankOfAllocatedMemRef)340e9a4a3bSRiver Riddle static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
350e9a4a3bSRiver Riddle unsigned maxRankOfAllocatedMemRef) {
360e9a4a3bSRiver Riddle auto type = alloc.getType().dyn_cast<ShapedType>();
370e9a4a3bSRiver Riddle if (!type || !alloc.getDefiningOp<memref::AllocOp>())
380e9a4a3bSRiver Riddle return false;
390e9a4a3bSRiver Riddle if (!type.hasStaticShape()) {
400e9a4a3bSRiver Riddle // Check if the dynamic shape dimension of the alloc is produced by
410e9a4a3bSRiver Riddle // `memref.rank`. If this is the case, it is likely to be small.
420e9a4a3bSRiver Riddle // Furthermore, the dimension is limited to the maximum rank of the
430e9a4a3bSRiver Riddle // allocated memref to avoid large values by multiplying several small
440e9a4a3bSRiver Riddle // values.
450e9a4a3bSRiver Riddle if (type.getRank() <= maxRankOfAllocatedMemRef) {
460e9a4a3bSRiver Riddle return llvm::all_of(alloc.getDefiningOp()->getOperands(),
470e9a4a3bSRiver Riddle [&](Value operand) {
480e9a4a3bSRiver Riddle return operand.getDefiningOp<memref::RankOp>();
490e9a4a3bSRiver Riddle });
500e9a4a3bSRiver Riddle }
510e9a4a3bSRiver Riddle return false;
520e9a4a3bSRiver Riddle }
53b70366c9SBenjamin Kramer unsigned bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp())
54b70366c9SBenjamin Kramer .getTypeSizeInBits(type.getElementType());
550e9a4a3bSRiver Riddle return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;
560e9a4a3bSRiver Riddle }
570e9a4a3bSRiver Riddle
580e9a4a3bSRiver Riddle /// Checks whether the given aliases leave the allocation scope.
590e9a4a3bSRiver Riddle static bool
leavesAllocationScope(Region * parentRegion,const BufferViewFlowAnalysis::ValueSetT & aliases)600e9a4a3bSRiver Riddle leavesAllocationScope(Region *parentRegion,
610e9a4a3bSRiver Riddle const BufferViewFlowAnalysis::ValueSetT &aliases) {
620e9a4a3bSRiver Riddle for (Value alias : aliases) {
630e9a4a3bSRiver Riddle for (auto *use : alias.getUsers()) {
640e9a4a3bSRiver Riddle // If there is at least one alias that leaves the parent region, we know
650e9a4a3bSRiver Riddle // that this alias escapes the whole region and hence the associated
660e9a4a3bSRiver Riddle // allocation leaves allocation scope.
670e9a4a3bSRiver Riddle if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion)
680e9a4a3bSRiver Riddle return true;
690e9a4a3bSRiver Riddle }
700e9a4a3bSRiver Riddle }
710e9a4a3bSRiver Riddle return false;
720e9a4a3bSRiver Riddle }
730e9a4a3bSRiver Riddle
740e9a4a3bSRiver Riddle /// Checks, if an automated allocation scope for a given alloc value exists.
hasAllocationScope(Value alloc,const BufferViewFlowAnalysis & aliasAnalysis)750e9a4a3bSRiver Riddle static bool hasAllocationScope(Value alloc,
760e9a4a3bSRiver Riddle const BufferViewFlowAnalysis &aliasAnalysis) {
770e9a4a3bSRiver Riddle Region *region = alloc.getParentRegion();
780e9a4a3bSRiver Riddle do {
790e9a4a3bSRiver Riddle if (Operation *parentOp = region->getParentOp()) {
800e9a4a3bSRiver Riddle // Check if the operation is an automatic allocation scope and whether an
810e9a4a3bSRiver Riddle // alias leaves the scope. This means, an allocation yields out of
820e9a4a3bSRiver Riddle // this scope and can not be transformed in a stack-based allocation.
830e9a4a3bSRiver Riddle if (parentOp->hasTrait<OpTrait::AutomaticAllocationScope>() &&
840e9a4a3bSRiver Riddle !leavesAllocationScope(region, aliasAnalysis.resolve(alloc)))
850e9a4a3bSRiver Riddle return true;
860e9a4a3bSRiver Riddle // Check if the operation is a known control flow interface and break the
870e9a4a3bSRiver Riddle // loop to avoid transformation in loops. Furthermore skip transformation
880e9a4a3bSRiver Riddle // if the operation does not implement a RegionBeanchOpInterface.
890e9a4a3bSRiver Riddle if (BufferPlacementTransformationBase::isLoop(parentOp) ||
900e9a4a3bSRiver Riddle !isKnownControlFlowInterface(parentOp))
910e9a4a3bSRiver Riddle break;
920e9a4a3bSRiver Riddle }
930e9a4a3bSRiver Riddle } while ((region = region->getParentRegion()));
940e9a4a3bSRiver Riddle return false;
950e9a4a3bSRiver Riddle }
960e9a4a3bSRiver Riddle
970e9a4a3bSRiver Riddle namespace {
980e9a4a3bSRiver Riddle
990e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
1000e9a4a3bSRiver Riddle // BufferAllocationHoisting
1010e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
1020e9a4a3bSRiver Riddle
1030e9a4a3bSRiver Riddle /// A base implementation compatible with the `BufferAllocationHoisting` class.
1040e9a4a3bSRiver Riddle struct BufferAllocationHoistingStateBase {
1050e9a4a3bSRiver Riddle /// A pointer to the current dominance info.
1060e9a4a3bSRiver Riddle DominanceInfo *dominators;
1070e9a4a3bSRiver Riddle
1080e9a4a3bSRiver Riddle /// The current allocation value.
1090e9a4a3bSRiver Riddle Value allocValue;
1100e9a4a3bSRiver Riddle
1110e9a4a3bSRiver Riddle /// The current placement block (if any).
1120e9a4a3bSRiver Riddle Block *placementBlock;
1130e9a4a3bSRiver Riddle
1140e9a4a3bSRiver Riddle /// Initializes the state base.
BufferAllocationHoistingStateBase__anon68c8f7d70211::BufferAllocationHoistingStateBase1150e9a4a3bSRiver Riddle BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
1160e9a4a3bSRiver Riddle Block *placementBlock)
1170e9a4a3bSRiver Riddle : dominators(dominators), allocValue(allocValue),
1180e9a4a3bSRiver Riddle placementBlock(placementBlock) {}
1190e9a4a3bSRiver Riddle };
1200e9a4a3bSRiver Riddle
1210e9a4a3bSRiver Riddle /// Implements the actual hoisting logic for allocation nodes.
1220e9a4a3bSRiver Riddle template <typename StateT>
1230e9a4a3bSRiver Riddle class BufferAllocationHoisting : public BufferPlacementTransformationBase {
1240e9a4a3bSRiver Riddle public:
BufferAllocationHoisting(Operation * op)1250e9a4a3bSRiver Riddle BufferAllocationHoisting(Operation *op)
1260e9a4a3bSRiver Riddle : BufferPlacementTransformationBase(op), dominators(op),
1270e9a4a3bSRiver Riddle postDominators(op), scopeOp(op) {}
1280e9a4a3bSRiver Riddle
1290e9a4a3bSRiver Riddle /// Moves allocations upwards.
hoist()1300e9a4a3bSRiver Riddle void hoist() {
1310e9a4a3bSRiver Riddle SmallVector<Value> allocsAndAllocas;
1320e9a4a3bSRiver Riddle for (BufferPlacementAllocs::AllocEntry &entry : allocs)
1330e9a4a3bSRiver Riddle allocsAndAllocas.push_back(std::get<0>(entry));
134*136d746eSJacques Pienaar scopeOp->walk([&](memref::AllocaOp op) {
135*136d746eSJacques Pienaar allocsAndAllocas.push_back(op.getMemref());
136*136d746eSJacques Pienaar });
1370e9a4a3bSRiver Riddle
1380e9a4a3bSRiver Riddle for (auto allocValue : allocsAndAllocas) {
1390e9a4a3bSRiver Riddle if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
1400e9a4a3bSRiver Riddle continue;
1410e9a4a3bSRiver Riddle Operation *definingOp = allocValue.getDefiningOp();
1420e9a4a3bSRiver Riddle assert(definingOp && "No defining op");
1430e9a4a3bSRiver Riddle auto operands = definingOp->getOperands();
1440e9a4a3bSRiver Riddle auto resultAliases = aliases.resolve(allocValue);
1450e9a4a3bSRiver Riddle // Determine the common dominator block of all aliases.
1460e9a4a3bSRiver Riddle Block *dominatorBlock =
1470e9a4a3bSRiver Riddle findCommonDominator(allocValue, resultAliases, dominators);
1480e9a4a3bSRiver Riddle // Init the initial hoisting state.
1490e9a4a3bSRiver Riddle StateT state(&dominators, allocValue, allocValue.getParentBlock());
1500e9a4a3bSRiver Riddle // Check for additional allocation dependencies to compute an upper bound
1510e9a4a3bSRiver Riddle // for hoisting.
1520e9a4a3bSRiver Riddle Block *dependencyBlock = nullptr;
1530e9a4a3bSRiver Riddle // If this node has dependencies, check all dependent nodes. This ensures
1540e9a4a3bSRiver Riddle // that all dependency values have been computed before allocating the
1550e9a4a3bSRiver Riddle // buffer.
1560e9a4a3bSRiver Riddle for (Value depValue : operands) {
1570e9a4a3bSRiver Riddle Block *depBlock = depValue.getParentBlock();
1580e9a4a3bSRiver Riddle if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))
1590e9a4a3bSRiver Riddle dependencyBlock = depBlock;
1600e9a4a3bSRiver Riddle }
1610e9a4a3bSRiver Riddle
1620e9a4a3bSRiver Riddle // Find the actual placement block and determine the start operation using
1630e9a4a3bSRiver Riddle // an upper placement-block boundary. The idea is that placement block
1640e9a4a3bSRiver Riddle // cannot be moved any further upwards than the given upper bound.
1650e9a4a3bSRiver Riddle Block *placementBlock = findPlacementBlock(
1660e9a4a3bSRiver Riddle state, state.computeUpperBound(dominatorBlock, dependencyBlock));
1670e9a4a3bSRiver Riddle Operation *startOperation = BufferPlacementAllocs::getStartOperation(
1680e9a4a3bSRiver Riddle allocValue, placementBlock, liveness);
1690e9a4a3bSRiver Riddle
1700e9a4a3bSRiver Riddle // Move the alloc in front of the start operation.
1710e9a4a3bSRiver Riddle Operation *allocOperation = allocValue.getDefiningOp();
1720e9a4a3bSRiver Riddle allocOperation->moveBefore(startOperation);
1730e9a4a3bSRiver Riddle }
1740e9a4a3bSRiver Riddle }
1750e9a4a3bSRiver Riddle
1760e9a4a3bSRiver Riddle private:
1770e9a4a3bSRiver Riddle /// Finds a valid placement block by walking upwards in the CFG until we
1780e9a4a3bSRiver Riddle /// either cannot continue our walk due to constraints (given by the StateT
1790e9a4a3bSRiver Riddle /// implementation) or we have reached the upper-most dominator block.
findPlacementBlock(StateT & state,Block * upperBound)1800e9a4a3bSRiver Riddle Block *findPlacementBlock(StateT &state, Block *upperBound) {
1810e9a4a3bSRiver Riddle Block *currentBlock = state.placementBlock;
1820e9a4a3bSRiver Riddle // Walk from the innermost regions/loops to the outermost regions/loops and
1830e9a4a3bSRiver Riddle // find an appropriate placement block that satisfies the constraint of the
1840e9a4a3bSRiver Riddle // current StateT implementation. Walk until we reach the upperBound block
1850e9a4a3bSRiver Riddle // (if any).
1860e9a4a3bSRiver Riddle
1870e9a4a3bSRiver Riddle // If we are not able to find a valid parent operation or an associated
1880e9a4a3bSRiver Riddle // parent block, break the walk loop.
1890e9a4a3bSRiver Riddle Operation *parentOp;
1900e9a4a3bSRiver Riddle Block *parentBlock;
1910e9a4a3bSRiver Riddle while ((parentOp = currentBlock->getParentOp()) &&
1920e9a4a3bSRiver Riddle (parentBlock = parentOp->getBlock()) &&
1930e9a4a3bSRiver Riddle (!upperBound ||
1940e9a4a3bSRiver Riddle dominators.properlyDominates(upperBound, currentBlock))) {
1950e9a4a3bSRiver Riddle // Try to find an immediate dominator and check whether the parent block
1960e9a4a3bSRiver Riddle // is above the immediate dominator (if any).
1970e9a4a3bSRiver Riddle DominanceInfoNode *idom = nullptr;
1980e9a4a3bSRiver Riddle
1990e9a4a3bSRiver Riddle // DominanceInfo doesn't support getNode queries for single-block regions.
2000e9a4a3bSRiver Riddle if (!currentBlock->isEntryBlock())
2010e9a4a3bSRiver Riddle idom = dominators.getNode(currentBlock)->getIDom();
2020e9a4a3bSRiver Riddle
2030e9a4a3bSRiver Riddle if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {
2040e9a4a3bSRiver Riddle // If the current immediate dominator is below the placement block, move
2050e9a4a3bSRiver Riddle // to the immediate dominator block.
2060e9a4a3bSRiver Riddle currentBlock = idom->getBlock();
2070e9a4a3bSRiver Riddle state.recordMoveToDominator(currentBlock);
2080e9a4a3bSRiver Riddle } else {
2090e9a4a3bSRiver Riddle // We have to move to our parent block since an immediate dominator does
2100e9a4a3bSRiver Riddle // either not exist or is above our parent block. If we cannot move to
2110e9a4a3bSRiver Riddle // our parent operation due to constraints given by the StateT
2120e9a4a3bSRiver Riddle // implementation, break the walk loop. Furthermore, we should not move
2130e9a4a3bSRiver Riddle // allocations out of unknown region-based control-flow operations.
2140e9a4a3bSRiver Riddle if (!isKnownControlFlowInterface(parentOp) ||
2150e9a4a3bSRiver Riddle !state.isLegalPlacement(parentOp))
2160e9a4a3bSRiver Riddle break;
2170e9a4a3bSRiver Riddle // Move to our parent block by notifying the current StateT
2180e9a4a3bSRiver Riddle // implementation.
2190e9a4a3bSRiver Riddle currentBlock = parentBlock;
2200e9a4a3bSRiver Riddle state.recordMoveToParent(currentBlock);
2210e9a4a3bSRiver Riddle }
2220e9a4a3bSRiver Riddle }
2230e9a4a3bSRiver Riddle // Return the finally determined placement block.
2240e9a4a3bSRiver Riddle return state.placementBlock;
2250e9a4a3bSRiver Riddle }
2260e9a4a3bSRiver Riddle
2270e9a4a3bSRiver Riddle /// The dominator info to find the appropriate start operation to move the
2280e9a4a3bSRiver Riddle /// allocs.
2290e9a4a3bSRiver Riddle DominanceInfo dominators;
2300e9a4a3bSRiver Riddle
2310e9a4a3bSRiver Riddle /// The post dominator info to move the dependent allocs in the right
2320e9a4a3bSRiver Riddle /// position.
2330e9a4a3bSRiver Riddle PostDominanceInfo postDominators;
2340e9a4a3bSRiver Riddle
2350e9a4a3bSRiver Riddle /// The map storing the final placement blocks of a given alloc value.
2360e9a4a3bSRiver Riddle llvm::DenseMap<Value, Block *> placementBlocks;
2370e9a4a3bSRiver Riddle
2380e9a4a3bSRiver Riddle /// The operation that this transformation is working on. It is used to also
2390e9a4a3bSRiver Riddle /// gather allocas.
2400e9a4a3bSRiver Riddle Operation *scopeOp;
2410e9a4a3bSRiver Riddle };
2420e9a4a3bSRiver Riddle
2430e9a4a3bSRiver Riddle /// A state implementation compatible with the `BufferAllocationHoisting` class
2440e9a4a3bSRiver Riddle /// that hoists allocations into dominator blocks while keeping them inside of
2450e9a4a3bSRiver Riddle /// loops.
2460e9a4a3bSRiver Riddle struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
2470e9a4a3bSRiver Riddle using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
2480e9a4a3bSRiver Riddle
2490e9a4a3bSRiver Riddle /// Computes the upper bound for the placement block search.
computeUpperBound__anon68c8f7d70211::BufferAllocationHoistingState2500e9a4a3bSRiver Riddle Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
2510e9a4a3bSRiver Riddle // If we do not have a dependency block, the upper bound is given by the
2520e9a4a3bSRiver Riddle // dominator block.
2530e9a4a3bSRiver Riddle if (!dependencyBlock)
2540e9a4a3bSRiver Riddle return dominatorBlock;
2550e9a4a3bSRiver Riddle
2560e9a4a3bSRiver Riddle // Find the "lower" block of the dominator and the dependency block to
2570e9a4a3bSRiver Riddle // ensure that we do not move allocations above this block.
2580e9a4a3bSRiver Riddle return dominators->properlyDominates(dominatorBlock, dependencyBlock)
2590e9a4a3bSRiver Riddle ? dependencyBlock
2600e9a4a3bSRiver Riddle : dominatorBlock;
2610e9a4a3bSRiver Riddle }
2620e9a4a3bSRiver Riddle
2630e9a4a3bSRiver Riddle /// Returns true if the given operation does not represent a loop.
isLegalPlacement__anon68c8f7d70211::BufferAllocationHoistingState2640e9a4a3bSRiver Riddle bool isLegalPlacement(Operation *op) {
2650e9a4a3bSRiver Riddle return !BufferPlacementTransformationBase::isLoop(op);
2660e9a4a3bSRiver Riddle }
2670e9a4a3bSRiver Riddle
2680e9a4a3bSRiver Riddle /// Returns true if the given operation should be considered for hoisting.
shouldHoistOpType__anon68c8f7d70211::BufferAllocationHoistingState2690e9a4a3bSRiver Riddle static bool shouldHoistOpType(Operation *op) {
2700e9a4a3bSRiver Riddle return llvm::isa<memref::AllocOp>(op);
2710e9a4a3bSRiver Riddle }
2720e9a4a3bSRiver Riddle
2730e9a4a3bSRiver Riddle /// Sets the current placement block to the given block.
recordMoveToDominator__anon68c8f7d70211::BufferAllocationHoistingState2740e9a4a3bSRiver Riddle void recordMoveToDominator(Block *block) { placementBlock = block; }
2750e9a4a3bSRiver Riddle
2760e9a4a3bSRiver Riddle /// Sets the current placement block to the given block.
recordMoveToParent__anon68c8f7d70211::BufferAllocationHoistingState2770e9a4a3bSRiver Riddle void recordMoveToParent(Block *block) { recordMoveToDominator(block); }
2780e9a4a3bSRiver Riddle };
2790e9a4a3bSRiver Riddle
2800e9a4a3bSRiver Riddle /// A state implementation compatible with the `BufferAllocationHoisting` class
2810e9a4a3bSRiver Riddle /// that hoists allocations out of loops.
2820e9a4a3bSRiver Riddle struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
2830e9a4a3bSRiver Riddle using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
2840e9a4a3bSRiver Riddle
2850e9a4a3bSRiver Riddle /// Remembers the dominator block of all aliases.
2860e9a4a3bSRiver Riddle Block *aliasDominatorBlock = nullptr;
2870e9a4a3bSRiver Riddle
2880e9a4a3bSRiver Riddle /// Computes the upper bound for the placement block search.
computeUpperBound__anon68c8f7d70211::BufferAllocationLoopHoistingState2890e9a4a3bSRiver Riddle Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
2900e9a4a3bSRiver Riddle aliasDominatorBlock = dominatorBlock;
2910e9a4a3bSRiver Riddle // If there is a dependency block, we have to use this block as an upper
2920e9a4a3bSRiver Riddle // bound to satisfy all allocation value dependencies.
2930e9a4a3bSRiver Riddle return dependencyBlock ? dependencyBlock : nullptr;
2940e9a4a3bSRiver Riddle }
2950e9a4a3bSRiver Riddle
2960e9a4a3bSRiver Riddle /// Returns true if the given operation represents a loop and one of the
2970e9a4a3bSRiver Riddle /// aliases caused the `aliasDominatorBlock` to be "above" the block of the
2980e9a4a3bSRiver Riddle /// given loop operation. If this is the case, it indicates that the
2990e9a4a3bSRiver Riddle /// allocation is passed via a back edge.
isLegalPlacement__anon68c8f7d70211::BufferAllocationLoopHoistingState3000e9a4a3bSRiver Riddle bool isLegalPlacement(Operation *op) {
3010e9a4a3bSRiver Riddle return BufferPlacementTransformationBase::isLoop(op) &&
3020e9a4a3bSRiver Riddle !dominators->dominates(aliasDominatorBlock, op->getBlock());
3030e9a4a3bSRiver Riddle }
3040e9a4a3bSRiver Riddle
3050e9a4a3bSRiver Riddle /// Returns true if the given operation should be considered for hoisting.
shouldHoistOpType__anon68c8f7d70211::BufferAllocationLoopHoistingState3060e9a4a3bSRiver Riddle static bool shouldHoistOpType(Operation *op) {
3070e9a4a3bSRiver Riddle return llvm::isa<memref::AllocOp, memref::AllocaOp>(op);
3080e9a4a3bSRiver Riddle }
3090e9a4a3bSRiver Riddle
3100e9a4a3bSRiver Riddle /// Does not change the internal placement block, as we want to move
3110e9a4a3bSRiver Riddle /// operations out of loops only.
recordMoveToDominator__anon68c8f7d70211::BufferAllocationLoopHoistingState3120e9a4a3bSRiver Riddle void recordMoveToDominator(Block *block) {}
3130e9a4a3bSRiver Riddle
3140e9a4a3bSRiver Riddle /// Sets the current placement block to the given block.
recordMoveToParent__anon68c8f7d70211::BufferAllocationLoopHoistingState3150e9a4a3bSRiver Riddle void recordMoveToParent(Block *block) { placementBlock = block; }
3160e9a4a3bSRiver Riddle };
3170e9a4a3bSRiver Riddle
3180e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
3190e9a4a3bSRiver Riddle // BufferPlacementPromotion
3200e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
3210e9a4a3bSRiver Riddle
3220e9a4a3bSRiver Riddle /// Promotes heap-based allocations to stack-based allocations (if possible).
3230e9a4a3bSRiver Riddle class BufferPlacementPromotion : BufferPlacementTransformationBase {
3240e9a4a3bSRiver Riddle public:
BufferPlacementPromotion(Operation * op)3250e9a4a3bSRiver Riddle BufferPlacementPromotion(Operation *op)
3260e9a4a3bSRiver Riddle : BufferPlacementTransformationBase(op) {}
3270e9a4a3bSRiver Riddle
3280e9a4a3bSRiver Riddle /// Promote buffers to stack-based allocations.
promote(function_ref<bool (Value)> isSmallAlloc)3290e9a4a3bSRiver Riddle void promote(function_ref<bool(Value)> isSmallAlloc) {
3300e9a4a3bSRiver Riddle for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
3310e9a4a3bSRiver Riddle Value alloc = std::get<0>(entry);
3320e9a4a3bSRiver Riddle Operation *dealloc = std::get<1>(entry);
3330e9a4a3bSRiver Riddle // Checking several requirements to transform an AllocOp into an AllocaOp.
3340e9a4a3bSRiver Riddle // The transformation is done if the allocation is limited to a given
3350e9a4a3bSRiver Riddle // size. Furthermore, a deallocation must not be defined for this
3360e9a4a3bSRiver Riddle // allocation entry and a parent allocation scope must exist.
3370e9a4a3bSRiver Riddle if (!isSmallAlloc(alloc) || dealloc ||
3380e9a4a3bSRiver Riddle !hasAllocationScope(alloc, aliases))
3390e9a4a3bSRiver Riddle continue;
3400e9a4a3bSRiver Riddle
3410e9a4a3bSRiver Riddle Operation *startOperation = BufferPlacementAllocs::getStartOperation(
3420e9a4a3bSRiver Riddle alloc, alloc.getParentBlock(), liveness);
3430e9a4a3bSRiver Riddle // Build a new alloca that is associated with its parent
3440e9a4a3bSRiver Riddle // `AutomaticAllocationScope` determined during the initialization phase.
3450e9a4a3bSRiver Riddle OpBuilder builder(startOperation);
3460e9a4a3bSRiver Riddle Operation *allocOp = alloc.getDefiningOp();
3470e9a4a3bSRiver Riddle Operation *alloca = builder.create<memref::AllocaOp>(
3480e9a4a3bSRiver Riddle alloc.getLoc(), alloc.getType().cast<MemRefType>(),
3490e9a4a3bSRiver Riddle allocOp->getOperands());
3500e9a4a3bSRiver Riddle
3510e9a4a3bSRiver Riddle // Replace the original alloc by a newly created alloca.
3520e9a4a3bSRiver Riddle allocOp->replaceAllUsesWith(alloca);
3530e9a4a3bSRiver Riddle allocOp->erase();
3540e9a4a3bSRiver Riddle }
3550e9a4a3bSRiver Riddle }
3560e9a4a3bSRiver Riddle };
3570e9a4a3bSRiver Riddle
3580e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
3590e9a4a3bSRiver Riddle // BufferOptimizationPasses
3600e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
3610e9a4a3bSRiver Riddle
3620e9a4a3bSRiver Riddle /// The buffer hoisting pass that hoists allocation nodes into dominating
3630e9a4a3bSRiver Riddle /// blocks.
3640e9a4a3bSRiver Riddle struct BufferHoistingPass : BufferHoistingBase<BufferHoistingPass> {
3650e9a4a3bSRiver Riddle
runOnOperation__anon68c8f7d70211::BufferHoistingPass3660e9a4a3bSRiver Riddle void runOnOperation() override {
3670e9a4a3bSRiver Riddle // Hoist all allocations into dominator blocks.
3680e9a4a3bSRiver Riddle BufferAllocationHoisting<BufferAllocationHoistingState> optimizer(
3690e9a4a3bSRiver Riddle getOperation());
3700e9a4a3bSRiver Riddle optimizer.hoist();
3710e9a4a3bSRiver Riddle }
3720e9a4a3bSRiver Riddle };
3730e9a4a3bSRiver Riddle
3740e9a4a3bSRiver Riddle /// The buffer loop hoisting pass that hoists allocation nodes out of loops.
3750e9a4a3bSRiver Riddle struct BufferLoopHoistingPass : BufferLoopHoistingBase<BufferLoopHoistingPass> {
3760e9a4a3bSRiver Riddle
runOnOperation__anon68c8f7d70211::BufferLoopHoistingPass3770e9a4a3bSRiver Riddle void runOnOperation() override {
3780e9a4a3bSRiver Riddle // Hoist all allocations out of loops.
3790e9a4a3bSRiver Riddle BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(
3800e9a4a3bSRiver Riddle getOperation());
3810e9a4a3bSRiver Riddle optimizer.hoist();
3820e9a4a3bSRiver Riddle }
3830e9a4a3bSRiver Riddle };
3840e9a4a3bSRiver Riddle
3850e9a4a3bSRiver Riddle /// The promote buffer to stack pass that tries to convert alloc nodes into
3860e9a4a3bSRiver Riddle /// alloca nodes.
3870e9a4a3bSRiver Riddle class PromoteBuffersToStackPass
3880e9a4a3bSRiver Riddle : public PromoteBuffersToStackBase<PromoteBuffersToStackPass> {
3890e9a4a3bSRiver Riddle public:
PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,unsigned maxRankOfAllocatedMemRef)3900e9a4a3bSRiver Riddle PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
3910e9a4a3bSRiver Riddle unsigned maxRankOfAllocatedMemRef) {
3920e9a4a3bSRiver Riddle this->maxAllocSizeInBytes = maxAllocSizeInBytes;
3930e9a4a3bSRiver Riddle this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
3940e9a4a3bSRiver Riddle }
3950e9a4a3bSRiver Riddle
PromoteBuffersToStackPass(std::function<bool (Value)> isSmallAlloc)3960e9a4a3bSRiver Riddle explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
3970e9a4a3bSRiver Riddle : isSmallAlloc(std::move(isSmallAlloc)) {}
3980e9a4a3bSRiver Riddle
initialize(MLIRContext * context)3990e9a4a3bSRiver Riddle LogicalResult initialize(MLIRContext *context) override {
4000e9a4a3bSRiver Riddle if (isSmallAlloc == nullptr) {
4010e9a4a3bSRiver Riddle isSmallAlloc = [=](Value alloc) {
4020e9a4a3bSRiver Riddle return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
4030e9a4a3bSRiver Riddle maxRankOfAllocatedMemRef);
4040e9a4a3bSRiver Riddle };
4050e9a4a3bSRiver Riddle }
4060e9a4a3bSRiver Riddle return success();
4070e9a4a3bSRiver Riddle }
4080e9a4a3bSRiver Riddle
runOnOperation()4090e9a4a3bSRiver Riddle void runOnOperation() override {
4100e9a4a3bSRiver Riddle // Move all allocation nodes and convert candidates into allocas.
4110e9a4a3bSRiver Riddle BufferPlacementPromotion optimizer(getOperation());
4120e9a4a3bSRiver Riddle optimizer.promote(isSmallAlloc);
4130e9a4a3bSRiver Riddle }
4140e9a4a3bSRiver Riddle
4150e9a4a3bSRiver Riddle private:
4160e9a4a3bSRiver Riddle std::function<bool(Value)> isSmallAlloc;
4170e9a4a3bSRiver Riddle };
4180e9a4a3bSRiver Riddle
4190e9a4a3bSRiver Riddle } // namespace
4200e9a4a3bSRiver Riddle
createBufferHoistingPass()4210e9a4a3bSRiver Riddle std::unique_ptr<Pass> mlir::bufferization::createBufferHoistingPass() {
4220e9a4a3bSRiver Riddle return std::make_unique<BufferHoistingPass>();
4230e9a4a3bSRiver Riddle }
4240e9a4a3bSRiver Riddle
createBufferLoopHoistingPass()4250e9a4a3bSRiver Riddle std::unique_ptr<Pass> mlir::bufferization::createBufferLoopHoistingPass() {
4260e9a4a3bSRiver Riddle return std::make_unique<BufferLoopHoistingPass>();
4270e9a4a3bSRiver Riddle }
4280e9a4a3bSRiver Riddle
createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,unsigned maxRankOfAllocatedMemRef)4290e9a4a3bSRiver Riddle std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
430b70366c9SBenjamin Kramer unsigned maxAllocSizeInBytes, unsigned maxRankOfAllocatedMemRef) {
431b70366c9SBenjamin Kramer return std::make_unique<PromoteBuffersToStackPass>(maxAllocSizeInBytes,
432b70366c9SBenjamin Kramer maxRankOfAllocatedMemRef);
4330e9a4a3bSRiver Riddle }
4340e9a4a3bSRiver Riddle
createPromoteBuffersToStackPass(std::function<bool (Value)> isSmallAlloc)4350e9a4a3bSRiver Riddle std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
4360e9a4a3bSRiver Riddle std::function<bool(Value)> isSmallAlloc) {
4370e9a4a3bSRiver Riddle return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
4380e9a4a3bSRiver Riddle }
439