//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() {}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the insertion point of `b` right after the definition of `value`: at
/// the start of the block in case of a bbArg, otherwise right after the
/// defining op.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
BufferizationState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty OpResult if the op is not
/// bufferizable.
OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return OpResult();
}
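
// For example (illustrative IR): for a hypothetical
//   %r = tensor.insert_slice %src into %dest[0] [4] [1] ...
// the `dest` OpOperand aliases with %r: if the op bufferizes in place, the
// buffer of %r is the buffer of %dest.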

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool BufferizationState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
        workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}
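
// Example (illustrative IR): isValueRead(%0) returns true below. The
// tensor.extract_slice itself neither reads nor writes, but it creates an
// alias that is subsequently read by tensor.extract:
//   %1 = tensor.extract_slice %0[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
//   %2 = tensor.extract %1[%c0] : tensor<4xf32>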

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}
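
// Example (illustrative IR): with a `condition` that matches only
// tensor.extract_slice results, starting the traversal from %2 returns {%1}:
//   %1 = tensor.extract_slice %0[0] [8] [1] : tensor<16xf32> to tensor<8xf32>
//   %2 = tensor.insert %f into %1[%c0] : tensor<8xf32>
// %2 does not match, so the traversal continues through the aliasing operand
// of tensor.insert (its `dest`, i.e., %1), which does match.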

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
BufferizationState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}
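
// Example (illustrative IR): findLastPrecedingWrite(%2) returns {%1}. The
// tensor.extract_slice merely creates an alias, whereas %1 (the linalg.fill
// result) is a memory write:
//   %0 = linalg.init_tensor [16] : tensor<16xf32>
//   %1 = linalg.fill(%cst, %0) : f32, tensor<16xf32> -> tensor<16xf32>
//   %2 = tensor.extract_slice %1[0] [8] [1] : tensor<16xf32> to tensor<8xf32>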

BufferizationState::BufferizationState(const BufferizationOptions &options)
    : options(options) {}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
  assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType;
  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
    memrefType = getDynamicMemRefType(rankedTensorType);
  } else {
    memrefType = getUnrankedMemRefType(
        tensor.getType().cast<TensorType>().getElementType());
  }
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}
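
// Example (illustrative IR): looking up the buffer of %t0 below folds to %m
// directly, whereas looking up the buffer of an arbitrary tensor %t1 inserts
// a new to_memref op (#map denoting a fully dynamic strided layout):
//   %t0 = bufferization.to_tensor %m : memref<?xf32>
//   %m1 = bufferization.to_memref %t1 : memref<?xf32, #map>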

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate a new
/// buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
FailureOr<Value> BufferizationState::getBuffer(
    RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
    Optional<Operation *> customCopyInsertionPoint) const {
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand);

  if (forceInPlace || isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer.
  FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
                                              options.createDeallocs, options);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              *this);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  OpResult aliasingOpResult = getAliasingOpResult(opOperand);
  if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
      !isValueRead(aliasingOpResult))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}
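
// Typical use, sketched from a hypothetical BufferizableOpInterface
// implementation (names and signature details are illustrative):
//   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
//                           const BufferizationState &state) {
//     FailureOr<Value> buffer = state.getBuffer(rewriter, op->getOpOperand(0));
//     if (failed(buffer))
//       return failure();
//     // ... create memref ops that operate on *buffer ...
//   }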

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  for (OpResult opResult : op->getOpResults()) {
    // Skip OpResult if it has no uses.
    if (opResult.getUses().empty())
      continue;

    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      setInsertionPointAfter(rewriter, replacement);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    opResult.replaceAllUsesWith(replacement);
  }

  rewriter.eraseOp(op);
}
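
// For example, a bufferization of a hypothetical one-result op would finish
// with (sketch):
//   Value resultBuffer = ...; // memref holding the bufferized result
//   replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
// Tensor uses of the op's result are then routed through the automatically
// inserted to_tensor op until that op loses all users and is DCE'd.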

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
  Operation *op = b.getInsertionBlock()->getParentOp();
  while (op) {
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      if (bufferizableOp.isAllocationHoistingBarrier())
        break;
    op = op->getParentOp();
  }

  if (!op) {
    // No allocation hoisting barrier found. Hoist to FuncOp.
    op = b.getInsertionBlock()->getParentOp();
    if (!isa<FuncOp>(op))
      op = op->getParentOfType<FuncOp>();
    assert(op && "could not find enclosing FuncOp");
  }

  // TODO: Handle cases where the allocation hoisting barrier has more than one
  // region or block.
  assert(op->getNumRegions() == 1 &&
         "allocation hoisting barriers with >1 regions not supported");
  assert(op->getRegion(0).getBlocks().size() == 1 &&
         "allocation hoisting barriers with >1 blocks not supported");
  b.setInsertionPointToStart(&(op->getRegion(0).front()));
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions in the returned `memref` type. The function may also set
/// the insertion point to an earlier location, where the allocation should
/// happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  // If the buffer is statically shaped, try to hoist it to the first enclosing
  // parallel region.
  // TODO: also hoist in the dynamic case. For now this relies on subsequent
  // calls to LICM and buffer hoisting, which will most likely not succeed.
  // TODO: when packing, allocate a static bounding box which will enable more
  // hoisting.
  if (dynShape.empty())
    moveInsertionPointToAllocationHoistingBarrier(b);

  return allocMemRefType;
}

/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
                           bool deallocMemref,
                           const BufferizationOptions &options) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // 1. Create memory allocation.
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  // Note: getAllocationTypeAndShape also sets the insertion point.
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
  FailureOr<Value> allocated =
      createAlloc(b, loc, allocMemRefType, dynShape, options);
  if (failed(allocated))
    return failure();
  Value casted = allocated.getValue();
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
                                             memRefType) &&
           "createAlloc: cast incompatible");
    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
  }

  if (deallocMemref) {
    // 2. Create memory deallocation.
    b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
    if (failed(createDealloc(b, loc, allocated.getValue(), options)))
      return failure();
  }

  return casted;
}
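
// Example (illustrative IR): with `deallocMemref` set, the resulting pair is
// placed as follows; the alloc right after the definition of `shapedValue`,
// the dealloc right before the block terminator:
//   %alloc = memref.alloc(%sz) : memref<?xf32>
//   ...
//   memref.dealloc %alloc : memref<?xf32>
//   return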

/// Create a memref allocation.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
                           ArrayRef<Value> dynShape,
                           const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
  return allocated;
}
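
// Clients can override the default with a custom allocation function, e.g.,
// to allocate on the stack (a sketch; assumes the `allocationFn` callback
// signature declared in BufferizationOptions):
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ArrayRef<Value> dynShape) -> FailureOr<Value> {
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };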

/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  // Default memory copy via CopyOp.
  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IR helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType
bufferization::getContiguousMemRefType(ShapedType shapedType,
                                       MemRefLayoutAttrInterface layout,
                                       Attribute memorySpace) {
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType,
                                                        Attribute memorySpace) {
  return UnrankedMemRefType::get(elementType, memorySpace);
}

MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType,
                                               unsigned addressSpace) {
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, tensorType.getContext());
  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                         stridedLayout, addressSpace);
}
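
// For example, a tensor<4x?xf32> maps to the following memref type with a
// dynamic offset and dynamic strides (illustrative rank-2 form):
//   memref<4x?xf32,
//          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>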