//===- BufferizableOpInterface.cpp - Bufferizable Ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are
/// supported at the moment. If `copy` is set, the shaped value is copied.
/// Otherwise, a tensor with undefined contents is allocated.
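///
/// For illustration (the syntax shown is approximate and may differ from the
/// exact AllocTensorOp assembly format), the generated IR is roughly:
///   %alloc = bufferization.alloc_tensor(%d0) : tensor<?xf32>         // !copy
///   %alloc = bufferization.alloc_tensor() copy(%t) : tensor<10xf32>  // copy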
static Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                          Value shapedValue, bool escape,
                                          bool copy = true) {
  auto tensorType = shapedValue.getType().dyn_cast<RankedTensorType>();
  assert(tensorType && "only RankedTensorType supported at the moment");
  Value alloc;
  if (!copy) {
    // No copy needed: Just allocate.
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i)
      if (tensorType.isDynamicDim(i))
        dynamicSizes.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
    alloc = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                    /*copy=*/Value(), escape);
  } else {
    // Allocate and copy.
    alloc = b.create<AllocTensorOp>(loc, tensorType,
                                    /*dynamicSizes=*/ValueRange(), shapedValue,
                                    escape);
  }
  return alloc;
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
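      // Example (illustrative): for an in-place
      //   %0 = tensor.extract_slice %t[0] [5] [1]
      //       : tensor<10xf32> to tensor<5xf32>
      // copying the result requires only a tensor<5xf32> allocation, whereas
      // copying the operand would require a full tensor<10xf32> allocation.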
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand),
        copiedOpOperands.contains(opOperand));
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    Value copy =
        allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
                                     escapingOpResultCopies.contains(opResult),
                                     copiedOpResults.count(opResult));
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow or disallow the op according to the filter rules.
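  // Example (illustrative): with one ALLOW rule for the tensor dialect and one
  // DENY rule for tensor.insert, tensor.extract_slice is allowed while
  // tensor.insert is rejected, regardless of the ALLOW rule.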
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
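///
/// Example (illustrative): for `%r = tensor.insert_slice %src into %dest`, the
/// aliasing OpOperand of `%r` is the `%dest` operand.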
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}
/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
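//
// Example (illustrative): with a `condition` that only matches block
// arguments, and starting from %1 in
//   %0 = tensor.extract_slice %arg0[0] [5] [1]
//       : tensor<10xf32> to tensor<5xf32>
//   %1 = tensor.extract_slice %0[0] [2] [1] : tensor<5xf32> to tensor<2xf32>
// the traversal follows %1 -> %0 -> %arg0 and returns {%arg0}.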
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may be
  // aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
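  // Example (illustrative, assuming scf.yield is registered as bufferizable):
  // the alloc_tensor result below is yielded out of a block via scf.yield (a
  // ReturnLike op), so it is considered "yielded".
  //   %0 = bufferization.alloc_tensor() : tensor<10xf32>
  //   scf.yield %0 : tensor<10xf32>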
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is not
    // yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may be
    // no function call analysis information), this `getAliasingOpResult` is
    // conservative and may report additional OpResults as potentially aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

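/// Return the buffer (memref) that corresponds to the given `tensor` value. If
/// the tensor is the result of a `bufferization.to_tensor` op, return that op's
/// memref operand. Otherwise, insert a `bufferization.to_memref` op right after
/// the definition of the tensor value.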
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
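///
/// With the default alloc/copy functions, the IR generated for an out-of-place
/// operand is roughly of the form (illustrative; the actual memref types and
/// attributes depend on the bufferization options):
///   %0 = bufferization.to_memref %t : memref<10xf32>
///   %1 = memref.alloc() : memref<10xf32>
///   memref.copy %0, %1 : memref<10xf32> to memref<10xf32>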
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Create the buffer allocation.
  FailureOr<Value> buffer =
      getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
  if (failed(buffer))
    return failure();

  // Should the buffer be deallocated again, or should we let it leak?
  if (dealloc) {
    if (!dealloc.getValue())
      return *buffer;
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    if (!getOptions().createDeallocs ||
        getAnalysisState().isTensorYielded(shapedValue))
      return *buffer;
  }

  // Create buffer deallocation.
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  if (failed(getOptions().createDealloc(b, loc, *buffer)))
    return failure();

  return *buffer;
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

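/// Return a MemRef type that is a valid bufferized counterpart of the given
/// tensor type. If `layout` is specified, use it. Otherwise, choose the layout
/// according to `options.unknownTypeConversion` (fully dynamic or static
/// identity layout map).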
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

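/// Return a MemRef type with a fully dynamic layout map (dynamic offset and
/// strides). If the given tensor type is unranked, return an unranked MemRef
/// type. E.g., for `tensor<?xf32>`, the result is roughly (illustrative)
/// `memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>`.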
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}