//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

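/// Resolve all out-of-place tensor OpOperands/OpResults of this op by
/// inserting tensor copies in the form of bufferization.alloc_tensor ops.
/// E.g. (illustrative IR), if %t must be copied before being written to:
///   %r = tensor.insert %f into %t[%idx] : tensor<8xf32>
/// is rewritten to:
///   %copy = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>
///   %r = tensor.insert %f into %copy[%idx] : tensor<8xf32>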
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  SmallVector<OpResult> outOfPlaceOpResults;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
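      // E.g. (illustrative IR), for an out-of-place
      //   %0 = tensor.extract_slice %t[0] [4] [1]
      //       : tensor<8xf32> to tensor<4xf32>
      // copying %0 (4 elements) is cheaper than copying %t (8 elements).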
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    auto tensorType = opOperand->get().getType().cast<RankedTensorType>();
    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(*opOperand);
    bool escape = llvm::any_of(
        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
    Value copy = rewriter.create<AllocTensorOp>(
        op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape);
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    auto tensorType = opResult.getType().cast<RankedTensorType>();
    bool escape = state.isTensorYielded(opResult);
    Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
                                                ValueRange(), opResult, escape);
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

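/// E.g. (illustrative usage, assuming the templated allowDialect/denyOperation
/// helpers), with `allowDialect<TensorDialect>()` and
/// `denyOperation<tensor::CastOp>()`, tensor.insert is allowed but tensor.cast
/// is denied.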
bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow according to the filter entries. If there is no ALLOW rule,
  // all ops are allowed by default.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed, even if other ALLOW
        // filters match.
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
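/// E.g. (illustrative IR), %t is considered read here: the slice itself only
/// creates an alias, but its result is subsequently read:
///   %0 = tensor.extract_slice %t[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
///   %1 = tensor.extract %0[%idx] : tensor<4xf32>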
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
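// E.g. (illustrative IR), starting from %2 with a `condition` that matches
// memory writes, the traversal visits %2, %1, %0 and returns {%0}:
//   %0 = "writing_op"(%t) : (tensor<?xf32>) -> tensor<?xf32>
//   %1 = "aliasing_op"(%0) : (tensor<?xf32>) -> tensor<?xf32>
//   %2 = "aliasing_op"(%1) : (tensor<?xf32>) -> tensor<?xf32>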
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

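/// Return the buffer (memref) of the given `tensor` value. Folds away
/// to_tensor ops; otherwise inserts a to_memref op, e.g. (illustrative IR):
///   %m = bufferization.to_memref %t : memref<4xf32>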
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
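/// E.g. (illustrative IR), for an out-of-place OpOperand whose data must be
/// preserved, the default alloc/copy functions produce:
///   %alloc = memref.alloc(%d) : memref<?xf32>
///   memref.copy %operand_buffer, %alloc : memref<?xf32> to memref<?xf32>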
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

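/// Replace `op` with the given bufferized `values`. Tensor results are wrapped
/// in to_tensor ops so that remaining tensor uses stay valid, e.g.
/// (illustrative IR), a tensor result replaced by a memref %m becomes:
///   %r = bufferization.to_tensor %m : memref<4xf32>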
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
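/// E.g. (illustrative IR), the default allocation for a dynamic extent %d,
/// with the configured `bufferAlignment`:
///   %0 = memref.alloc(%d) {alignment = 64 : i64} : memref<?x16xf32>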
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions in the returned `memref` type.
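/// E.g. (illustrative), for a value of type memref<4x?xf32>, this returns
/// memref<4x?xf32> and appends one Value (the runtime size of dimension 1) to
/// `dynShape`.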
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
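/// E.g. (illustrative IR), when a deallocation is requested, it is inserted
/// right before the terminator of the insertion block:
///   %0 = memref.alloc() : memref<4xf32>
///   ...
///   memref.dealloc %0 : memref<4xf32>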
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Create the buffer allocation.
  FailureOr<Value> buffer =
      getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
  if (failed(buffer))
    return failure();

  // Should the buffer be deallocated again or should we let it leak?
  if (dealloc) {
    if (!dealloc.getValue())
      return *buffer;
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    if (!getOptions().createDeallocs ||
        getAnalysisState().isTensorYielded(shapedValue))
      return *buffer;
  }

  // Create buffer deallocation.
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  if (failed(getOptions().createDealloc(b, loc, *buffer)))
    return failure();

  return *buffer;
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Miscellaneous bufferization helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

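/// Return a MemRef type for the given tensor type, taking `layout` and,
/// absent one, `options.unknownTypeConversion` into account. E.g.,
/// tensor<*xf32> maps to memref<*xf32> (Case 1 below).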
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

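/// Return a MemRef type with a fully dynamic strided layout. E.g.
/// (illustrative; the exact map comes from makeStridedLinearLayoutMap),
/// tensor<4x?xf32> maps to:
///   memref<4x?xf32,
///          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>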
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
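/// E.g., tensor<4x?xf32> maps to memref<4x?xf32>.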
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}