1 //===- BufferUtils.cpp - buffer transformation utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for buffer optimization passes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/LoopLikeInterface.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23
24 using namespace mlir;
25 using namespace mlir::bufferization;
26
27 //===----------------------------------------------------------------------===//
28 // BufferPlacementAllocs
29 //===----------------------------------------------------------------------===//
30
31 /// Get the start operation to place the given alloc value withing the
32 // specified placement block.
getStartOperation(Value allocValue,Block * placementBlock,const Liveness & liveness)33 Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
34 Block *placementBlock,
35 const Liveness &liveness) {
36 // We have to ensure that we place the alloc before its first use in this
37 // block.
38 const LivenessBlockInfo &livenessInfo = *liveness.getLiveness(placementBlock);
39 Operation *startOperation = livenessInfo.getStartOperation(allocValue);
40 // Check whether the start operation lies in the desired placement block.
41 // If not, we will use the terminator as this is the last operation in
42 // this block.
43 if (startOperation->getBlock() != placementBlock) {
44 Operation *opInPlacementBlock =
45 placementBlock->findAncestorOpInBlock(*startOperation);
46 startOperation = opInPlacementBlock ? opInPlacementBlock
47 : placementBlock->getTerminator();
48 }
49
50 return startOperation;
51 }
52
53 /// Initializes the internal list by discovering all supported allocation
54 /// nodes.
BufferPlacementAllocs(Operation * op)55 BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
56
57 /// Searches for and registers all supported allocation entries.
build(Operation * op)58 void BufferPlacementAllocs::build(Operation *op) {
59 op->walk([&](MemoryEffectOpInterface opInterface) {
60 // Try to find a single allocation result.
61 SmallVector<MemoryEffects::EffectInstance, 2> effects;
62 opInterface.getEffects(effects);
63
64 SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
65 llvm::copy_if(
66 effects, std::back_inserter(allocateResultEffects),
67 [=](MemoryEffects::EffectInstance &it) {
68 Value value = it.getValue();
69 return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
70 value.isa<OpResult>() &&
71 it.getResource() !=
72 SideEffects::AutomaticAllocationScopeResource::get();
73 });
74 // If there is one result only, we will be able to move the allocation and
75 // (possibly existing) deallocation ops.
76 if (allocateResultEffects.size() != 1)
77 return;
78 // Get allocation result.
79 Value allocValue = allocateResultEffects[0].getValue();
80 // Find the associated dealloc value and register the allocation entry.
81 llvm::Optional<Operation *> dealloc = memref::findDealloc(allocValue);
82 // If the allocation has > 1 dealloc associated with it, skip handling it.
83 if (!dealloc)
84 return;
85 allocs.push_back(std::make_tuple(allocValue, *dealloc));
86 });
87 }
88
89 //===----------------------------------------------------------------------===//
90 // BufferPlacementTransformationBase
91 //===----------------------------------------------------------------------===//
92
93 /// Constructs a new transformation base using the given root operation.
BufferPlacementTransformationBase(Operation * op)94 BufferPlacementTransformationBase::BufferPlacementTransformationBase(
95 Operation *op)
96 : aliases(op), allocs(op), liveness(op) {}
97
98 /// Returns true if the given operation represents a loop by testing whether it
99 /// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
100 /// the case of a `RegionBranchOpInterface`, it checks all region-based control-
101 /// flow edges for cycles.
isLoop(Operation * op)102 bool BufferPlacementTransformationBase::isLoop(Operation *op) {
103 // If the operation implements the `LoopLikeOpInterface` it can be considered
104 // a loop.
105 if (isa<LoopLikeOpInterface>(op))
106 return true;
107
108 // If the operation does not implement the `RegionBranchOpInterface`, it is
109 // (currently) not possible to detect a loop.
110 RegionBranchOpInterface regionInterface;
111 if (!(regionInterface = dyn_cast<RegionBranchOpInterface>(op)))
112 return false;
113
114 // Recurses into a region using the current region interface to find potential
115 // cycles.
116 SmallPtrSet<Region *, 4> visitedRegions;
117 std::function<bool(Region *)> recurse = [&](Region *current) {
118 if (!current)
119 return false;
120 // If we have found a back edge, the parent operation induces a loop.
121 if (!visitedRegions.insert(current).second)
122 return true;
123 // Recurses into all region successors.
124 SmallVector<RegionSuccessor, 2> successors;
125 regionInterface.getSuccessorRegions(current->getRegionNumber(), successors);
126 for (RegionSuccessor ®ionEntry : successors)
127 if (recurse(regionEntry.getSuccessor()))
128 return true;
129 return false;
130 };
131
132 // Start with all entry regions and test whether they induce a loop.
133 SmallVector<RegionSuccessor, 2> successorRegions;
134 regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
135 for (RegionSuccessor ®ionEntry : successorRegions) {
136 if (recurse(regionEntry.getSuccessor()))
137 return true;
138 visitedRegions.clear();
139 }
140
141 return false;
142 }
143
144 //===----------------------------------------------------------------------===//
// getGlobalFor
146 //===----------------------------------------------------------------------===//
147
148 FailureOr<memref::GlobalOp>
getGlobalFor(arith::ConstantOp constantOp,uint64_t alignment)149 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
150 auto type = constantOp.getType().cast<RankedTensorType>();
151 auto moduleOp = constantOp->getParentOfType<ModuleOp>();
152 if (!moduleOp)
153 return failure();
154
155 // If we already have a global for this constant value, no need to do
156 // anything else.
157 for (Operation &op : moduleOp.getRegion().getOps()) {
158 auto globalOp = dyn_cast<memref::GlobalOp>(&op);
159 if (!globalOp)
160 continue;
161 if (!globalOp.getInitialValue().has_value())
162 continue;
163 uint64_t opAlignment = globalOp.getAlignment().value_or(0);
164 Attribute initialValue = globalOp.getInitialValue().value();
165 if (opAlignment == alignment && initialValue == constantOp.getValue())
166 return globalOp;
167 }
168
169 // Create a builder without an insertion point. We will insert using the
170 // symbol table to guarantee unique names.
171 OpBuilder globalBuilder(moduleOp.getContext());
172 SymbolTable symbolTable(moduleOp);
173
174 // Create a pretty name.
175 SmallString<64> buf;
176 llvm::raw_svector_ostream os(buf);
177 interleave(type.getShape(), os, "x");
178 os << "x" << type.getElementType();
179
180 // Add an optional alignment to the global memref.
181 IntegerAttr memrefAlignment =
182 alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
183 : IntegerAttr();
184
185 BufferizeTypeConverter typeConverter;
186 auto global = globalBuilder.create<memref::GlobalOp>(
187 constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
188 /*sym_visibility=*/globalBuilder.getStringAttr("private"),
189 /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
190 /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
191 /*constant=*/true,
192 /*alignment=*/memrefAlignment);
193 symbolTable.insert(global);
194 // The symbol table inserts at the end of the module, but globals are a bit
195 // nicer if they are at the beginning.
196 global->moveBefore(&moduleOp.front());
197 return global;
198 }
199