10e9a4a3bSRiver Riddle //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
20e9a4a3bSRiver Riddle //
30e9a4a3bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40e9a4a3bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
50e9a4a3bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60e9a4a3bSRiver Riddle //
70e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
80e9a4a3bSRiver Riddle //
90e9a4a3bSRiver Riddle // This file implements utilities for buffer optimization passes.
100e9a4a3bSRiver Riddle //
110e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
120e9a4a3bSRiver Riddle 
130e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
140e9a4a3bSRiver Riddle #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
150e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
160e9a4a3bSRiver Riddle #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
170e9a4a3bSRiver Riddle #include "mlir/IR/Operation.h"
180e9a4a3bSRiver Riddle #include "mlir/Interfaces/ControlFlowInterfaces.h"
190e9a4a3bSRiver Riddle #include "mlir/Interfaces/LoopLikeInterface.h"
200e9a4a3bSRiver Riddle #include "mlir/Pass/Pass.h"
210e9a4a3bSRiver Riddle #include "llvm/ADT/SetOperations.h"
2206057248SRiver Riddle #include "llvm/ADT/SmallString.h"
230e9a4a3bSRiver Riddle 
240e9a4a3bSRiver Riddle using namespace mlir;
250e9a4a3bSRiver Riddle using namespace mlir::bufferization;
260e9a4a3bSRiver Riddle 
270e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
280e9a4a3bSRiver Riddle // BufferPlacementAllocs
290e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
300e9a4a3bSRiver Riddle 
310e9a4a3bSRiver Riddle /// Get the start operation to place the given alloc value withing the
320e9a4a3bSRiver Riddle // specified placement block.
getStartOperation(Value allocValue,Block * placementBlock,const Liveness & liveness)330e9a4a3bSRiver Riddle Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
340e9a4a3bSRiver Riddle                                                     Block *placementBlock,
350e9a4a3bSRiver Riddle                                                     const Liveness &liveness) {
360e9a4a3bSRiver Riddle   // We have to ensure that we place the alloc before its first use in this
370e9a4a3bSRiver Riddle   // block.
380e9a4a3bSRiver Riddle   const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
390e9a4a3bSRiver Riddle   Operation *startOperation = livenessInfo.getStartOperation(allocValue);
400e9a4a3bSRiver Riddle   // Check whether the start operation lies in the desired placement block.
410e9a4a3bSRiver Riddle   // If not, we will use the terminator as this is the last operation in
420e9a4a3bSRiver Riddle   // this block.
430e9a4a3bSRiver Riddle   if (startOperation->getBlock() != placementBlock) {
440e9a4a3bSRiver Riddle     Operation *opInPlacementBlock =
450e9a4a3bSRiver Riddle         placementBlock->findAncestorOpInBlock(*startOperation);
460e9a4a3bSRiver Riddle     startOperation = opInPlacementBlock ? opInPlacementBlock
470e9a4a3bSRiver Riddle                                         : placementBlock->getTerminator();
480e9a4a3bSRiver Riddle   }
490e9a4a3bSRiver Riddle 
500e9a4a3bSRiver Riddle   return startOperation;
510e9a4a3bSRiver Riddle }
520e9a4a3bSRiver Riddle 
530e9a4a3bSRiver Riddle /// Initializes the internal list by discovering all supported allocation
540e9a4a3bSRiver Riddle /// nodes.
BufferPlacementAllocs(Operation * op)550e9a4a3bSRiver Riddle BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
560e9a4a3bSRiver Riddle 
570e9a4a3bSRiver Riddle /// Searches for and registers all supported allocation entries.
build(Operation * op)580e9a4a3bSRiver Riddle void BufferPlacementAllocs::build(Operation *op) {
590e9a4a3bSRiver Riddle   op->walk([&](MemoryEffectOpInterface opInterface) {
600e9a4a3bSRiver Riddle     // Try to find a single allocation result.
610e9a4a3bSRiver Riddle     SmallVector<MemoryEffects::EffectInstance, 2> effects;
620e9a4a3bSRiver Riddle     opInterface.getEffects(effects);
630e9a4a3bSRiver Riddle 
640e9a4a3bSRiver Riddle     SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
650e9a4a3bSRiver Riddle     llvm::copy_if(
660e9a4a3bSRiver Riddle         effects, std::back_inserter(allocateResultEffects),
670e9a4a3bSRiver Riddle         [=](MemoryEffects::EffectInstance &it) {
680e9a4a3bSRiver Riddle           Value value = it.getValue();
690e9a4a3bSRiver Riddle           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
700e9a4a3bSRiver Riddle                  value.isa<OpResult>() &&
710e9a4a3bSRiver Riddle                  it.getResource() !=
720e9a4a3bSRiver Riddle                      SideEffects::AutomaticAllocationScopeResource::get();
730e9a4a3bSRiver Riddle         });
740e9a4a3bSRiver Riddle     // If there is one result only, we will be able to move the allocation and
750e9a4a3bSRiver Riddle     // (possibly existing) deallocation ops.
760e9a4a3bSRiver Riddle     if (allocateResultEffects.size() != 1)
770e9a4a3bSRiver Riddle       return;
780e9a4a3bSRiver Riddle     // Get allocation result.
790e9a4a3bSRiver Riddle     Value allocValue = allocateResultEffects[0].getValue();
800e9a4a3bSRiver Riddle     // Find the associated dealloc value and register the allocation entry.
81af9f7d31SUday Bondhugula     llvm::Optional<Operation *> dealloc = memref::findDealloc(allocValue);
820e9a4a3bSRiver Riddle     // If the allocation has > 1 dealloc associated with it, skip handling it.
83037f0995SKazu Hirata     if (!dealloc)
840e9a4a3bSRiver Riddle       return;
850e9a4a3bSRiver Riddle     allocs.push_back(std::make_tuple(allocValue, *dealloc));
860e9a4a3bSRiver Riddle   });
870e9a4a3bSRiver Riddle }
880e9a4a3bSRiver Riddle 
890e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
900e9a4a3bSRiver Riddle // BufferPlacementTransformationBase
910e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
920e9a4a3bSRiver Riddle 
930e9a4a3bSRiver Riddle /// Constructs a new transformation base using the given root operation.
BufferPlacementTransformationBase(Operation * op)940e9a4a3bSRiver Riddle BufferPlacementTransformationBase::BufferPlacementTransformationBase(
950e9a4a3bSRiver Riddle     Operation *op)
960e9a4a3bSRiver Riddle     : aliases(op), allocs(op), liveness(op) {}
970e9a4a3bSRiver Riddle 
980e9a4a3bSRiver Riddle /// Returns true if the given operation represents a loop by testing whether it
990e9a4a3bSRiver Riddle /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
1000e9a4a3bSRiver Riddle /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
1010e9a4a3bSRiver Riddle /// flow edges for cycles.
isLoop(Operation * op)1020e9a4a3bSRiver Riddle bool BufferPlacementTransformationBase::isLoop(Operation *op) {
1030e9a4a3bSRiver Riddle   // If the operation implements the `LoopLikeOpInterface` it can be considered
1040e9a4a3bSRiver Riddle   // a loop.
1050e9a4a3bSRiver Riddle   if (isa<LoopLikeOpInterface>(op))
1060e9a4a3bSRiver Riddle     return true;
1070e9a4a3bSRiver Riddle 
1080e9a4a3bSRiver Riddle   // If the operation does not implement the `RegionBranchOpInterface`, it is
1090e9a4a3bSRiver Riddle   // (currently) not possible to detect a loop.
1100e9a4a3bSRiver Riddle   RegionBranchOpInterface regionInterface;
1110e9a4a3bSRiver Riddle   if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
1120e9a4a3bSRiver Riddle     return false;
1130e9a4a3bSRiver Riddle 
1140e9a4a3bSRiver Riddle   // Recurses into a region using the current region interface to find potential
1150e9a4a3bSRiver Riddle   // cycles.
1160e9a4a3bSRiver Riddle   SmallPtrSet<Region *, 4> visitedRegions;
1170e9a4a3bSRiver Riddle   std::function<bool(Region *)> recurse = [&](Region *current) {
1180e9a4a3bSRiver Riddle     if (!current)
1190e9a4a3bSRiver Riddle       return false;
1200e9a4a3bSRiver Riddle     // If we have found a back edge, the parent operation induces a loop.
1210e9a4a3bSRiver Riddle     if (!visitedRegions.insert(current).second)
1220e9a4a3bSRiver Riddle       return true;
1230e9a4a3bSRiver Riddle     // Recurses into all region successors.
1240e9a4a3bSRiver Riddle     SmallVector<RegionSuccessor, 2> successors;
1250e9a4a3bSRiver Riddle     regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
1260e9a4a3bSRiver Riddle     for (RegionSuccessor &regionEntry : successors)
1270e9a4a3bSRiver Riddle       if (recurse(regionEntry.getSuccessor()))
1280e9a4a3bSRiver Riddle         return true;
1290e9a4a3bSRiver Riddle     return false;
1300e9a4a3bSRiver Riddle   };
1310e9a4a3bSRiver Riddle 
1320e9a4a3bSRiver Riddle   // Start with all entry regions and test whether they induce a loop.
1330e9a4a3bSRiver Riddle   SmallVector<RegionSuccessor, 2> successorRegions;
1340e9a4a3bSRiver Riddle   regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
1350e9a4a3bSRiver Riddle   for (RegionSuccessor &regionEntry : successorRegions) {
1360e9a4a3bSRiver Riddle     if (recurse(regionEntry.getSuccessor()))
1370e9a4a3bSRiver Riddle       return true;
1380e9a4a3bSRiver Riddle     visitedRegions.clear();
1390e9a4a3bSRiver Riddle   }
1400e9a4a3bSRiver Riddle 
1410e9a4a3bSRiver Riddle   return false;
1420e9a4a3bSRiver Riddle }
1430e9a4a3bSRiver Riddle 
1440e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
// Global memref creation for tensor constants
1460e9a4a3bSRiver Riddle //===----------------------------------------------------------------------===//
1470e9a4a3bSRiver Riddle 
148ab47418dSMatthias Springer FailureOr<memref::GlobalOp>
getGlobalFor(arith::ConstantOp constantOp,uint64_t alignment)149ab47418dSMatthias Springer bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
1500e9a4a3bSRiver Riddle   auto type = constantOp.getType().cast<RankedTensorType>();
151ab47418dSMatthias Springer   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
152ab47418dSMatthias Springer   if (!moduleOp)
153ab47418dSMatthias Springer     return failure();
1540e9a4a3bSRiver Riddle 
1550e9a4a3bSRiver Riddle   // If we already have a global for this constant value, no need to do
1560e9a4a3bSRiver Riddle   // anything else.
157ab47418dSMatthias Springer   for (Operation &op : moduleOp.getRegion().getOps()) {
158ab47418dSMatthias Springer     auto globalOp = dyn_cast<memref::GlobalOp>(&op);
159ab47418dSMatthias Springer     if (!globalOp)
160ab47418dSMatthias Springer       continue;
161491d2701SKazu Hirata     if (!globalOp.getInitialValue().has_value())
162ab47418dSMatthias Springer       continue;
163*3b0dce5bSKazu Hirata     uint64_t opAlignment = globalOp.getAlignment().value_or(0);
164c27d8152SKazu Hirata     Attribute initialValue = globalOp.getInitialValue().value();
165ab47418dSMatthias Springer     if (opAlignment == alignment && initialValue == constantOp.getValue())
166ab47418dSMatthias Springer       return globalOp;
167ab47418dSMatthias Springer   }
1680e9a4a3bSRiver Riddle 
1690e9a4a3bSRiver Riddle   // Create a builder without an insertion point. We will insert using the
1700e9a4a3bSRiver Riddle   // symbol table to guarantee unique names.
1710e9a4a3bSRiver Riddle   OpBuilder globalBuilder(moduleOp.getContext());
1720e9a4a3bSRiver Riddle   SymbolTable symbolTable(moduleOp);
1730e9a4a3bSRiver Riddle 
1740e9a4a3bSRiver Riddle   // Create a pretty name.
1750e9a4a3bSRiver Riddle   SmallString<64> buf;
1760e9a4a3bSRiver Riddle   llvm::raw_svector_ostream os(buf);
1770e9a4a3bSRiver Riddle   interleave(type.getShape(), os, "x");
1780e9a4a3bSRiver Riddle   os << "x" << type.getElementType();
1790e9a4a3bSRiver Riddle 
1800e9a4a3bSRiver Riddle   // Add an optional alignment to the global memref.
1810e9a4a3bSRiver Riddle   IntegerAttr memrefAlignment =
1820e9a4a3bSRiver Riddle       alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
1830e9a4a3bSRiver Riddle                     : IntegerAttr();
1840e9a4a3bSRiver Riddle 
185ab47418dSMatthias Springer   BufferizeTypeConverter typeConverter;
1860e9a4a3bSRiver Riddle   auto global = globalBuilder.create<memref::GlobalOp>(
1870e9a4a3bSRiver Riddle       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
1880e9a4a3bSRiver Riddle       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
1890e9a4a3bSRiver Riddle       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
1900e9a4a3bSRiver Riddle       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
1910e9a4a3bSRiver Riddle       /*constant=*/true,
1920e9a4a3bSRiver Riddle       /*alignment=*/memrefAlignment);
1930e9a4a3bSRiver Riddle   symbolTable.insert(global);
1940e9a4a3bSRiver Riddle   // The symbol table inserts at the end of the module, but globals are a bit
1950e9a4a3bSRiver Riddle   // nicer if they are at the beginning.
1960e9a4a3bSRiver Riddle   global->moveBefore(&moduleOp.front());
1970e9a4a3bSRiver Riddle   return global;
1980e9a4a3bSRiver Riddle }
199