1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/LoopLikeInterface.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/SetOperations.h"
22 
23 using namespace mlir;
24 using namespace mlir::bufferization;
25 
26 //===----------------------------------------------------------------------===//
27 // BufferPlacementAllocs
28 //===----------------------------------------------------------------------===//
29 
30 /// Get the start operation to place the given alloc value withing the
31 // specified placement block.
32 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
33                                                     Block *placementBlock,
34                                                     const Liveness &liveness) {
35   // We have to ensure that we place the alloc before its first use in this
36   // block.
37   const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
38   Operation *startOperation = livenessInfo.getStartOperation(allocValue);
39   // Check whether the start operation lies in the desired placement block.
40   // If not, we will use the terminator as this is the last operation in
41   // this block.
42   if (startOperation->getBlock() != placementBlock) {
43     Operation *opInPlacementBlock =
44         placementBlock->findAncestorOpInBlock(*startOperation);
45     startOperation = opInPlacementBlock ? opInPlacementBlock
46                                         : placementBlock->getTerminator();
47   }
48 
49   return startOperation;
50 }
51 
52 /// Initializes the internal list by discovering all supported allocation
53 /// nodes.
54 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
55 
56 /// Searches for and registers all supported allocation entries.
57 void BufferPlacementAllocs::build(Operation *op) {
58   op->walk([&](MemoryEffectOpInterface opInterface) {
59     // Try to find a single allocation result.
60     SmallVector<MemoryEffects::EffectInstance, 2> effects;
61     opInterface.getEffects(effects);
62 
63     SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
64     llvm::copy_if(
65         effects, std::back_inserter(allocateResultEffects),
66         [=](MemoryEffects::EffectInstance &it) {
67           Value value = it.getValue();
68           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
69                  value.isa<OpResult>() &&
70                  it.getResource() !=
71                      SideEffects::AutomaticAllocationScopeResource::get();
72         });
73     // If there is one result only, we will be able to move the allocation and
74     // (possibly existing) deallocation ops.
75     if (allocateResultEffects.size() != 1)
76       return;
77     // Get allocation result.
78     Value allocValue = allocateResultEffects[0].getValue();
79     // Find the associated dealloc value and register the allocation entry.
80     llvm::Optional<Operation *> dealloc = findDealloc(allocValue);
81     // If the allocation has > 1 dealloc associated with it, skip handling it.
82     if (!dealloc.hasValue())
83       return;
84     allocs.push_back(std::make_tuple(allocValue, *dealloc));
85   });
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // BufferPlacementTransformationBase
90 //===----------------------------------------------------------------------===//
91 
92 /// Constructs a new transformation base using the given root operation.
93 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
94     Operation *op)
95     : aliases(op), allocs(op), liveness(op) {}
96 
97 /// Returns true if the given operation represents a loop by testing whether it
98 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
99 /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
100 /// flow edges for cycles.
101 bool BufferPlacementTransformationBase::isLoop(Operation *op) {
102   // If the operation implements the `LoopLikeOpInterface` it can be considered
103   // a loop.
104   if (isa<LoopLikeOpInterface>(op))
105     return true;
106 
107   // If the operation does not implement the `RegionBranchOpInterface`, it is
108   // (currently) not possible to detect a loop.
109   RegionBranchOpInterface regionInterface;
110   if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
111     return false;
112 
113   // Recurses into a region using the current region interface to find potential
114   // cycles.
115   SmallPtrSet<Region *, 4> visitedRegions;
116   std::function<bool(Region *)> recurse = [&](Region *current) {
117     if (!current)
118       return false;
119     // If we have found a back edge, the parent operation induces a loop.
120     if (!visitedRegions.insert(current).second)
121       return true;
122     // Recurses into all region successors.
123     SmallVector<RegionSuccessor, 2> successors;
124     regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
125     for (RegionSuccessor &regionEntry : successors)
126       if (recurse(regionEntry.getSuccessor()))
127         return true;
128     return false;
129   };
130 
131   // Start with all entry regions and test whether they induce a loop.
132   SmallVector<RegionSuccessor, 2> successorRegions;
133   regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
134   for (RegionSuccessor &regionEntry : successorRegions) {
135     if (recurse(regionEntry.getSuccessor()))
136       return true;
137     visitedRegions.clear();
138   }
139 
140   return false;
141 }
142 
143 //===----------------------------------------------------------------------===//
// getGlobalFor
145 //===----------------------------------------------------------------------===//
146 
147 FailureOr<memref::GlobalOp>
148 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
149   auto type = constantOp.getType().cast<RankedTensorType>();
150   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
151   if (!moduleOp)
152     return failure();
153 
154   // If we already have a global for this constant value, no need to do
155   // anything else.
156   for (Operation &op : moduleOp.getRegion().getOps()) {
157     auto globalOp = dyn_cast<memref::GlobalOp>(&op);
158     if (!globalOp)
159       continue;
160     if (!globalOp.initial_value().hasValue())
161       continue;
162     uint64_t opAlignment =
163         globalOp.alignment().hasValue() ? globalOp.alignment().getValue() : 0;
164     Attribute initialValue = globalOp.initial_value().getValue();
165     if (opAlignment == alignment && initialValue == constantOp.getValue())
166       return globalOp;
167   }
168 
169   // Create a builder without an insertion point. We will insert using the
170   // symbol table to guarantee unique names.
171   OpBuilder globalBuilder(moduleOp.getContext());
172   SymbolTable symbolTable(moduleOp);
173 
174   // Create a pretty name.
175   SmallString<64> buf;
176   llvm::raw_svector_ostream os(buf);
177   interleave(type.getShape(), os, "x");
178   os << "x" << type.getElementType();
179 
180   // Add an optional alignment to the global memref.
181   IntegerAttr memrefAlignment =
182       alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
183                     : IntegerAttr();
184 
185   BufferizeTypeConverter typeConverter;
186   auto global = globalBuilder.create<memref::GlobalOp>(
187       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
188       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
189       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
190       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
191       /*constant=*/true,
192       /*alignment=*/memrefAlignment);
193   symbolTable.insert(global);
194   // The symbol table inserts at the end of the module, but globals are a bit
195   // nicer if they are at the beginning.
196   global->moveBefore(&moduleOp.front());
197   return global;
198 }
199