//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                                  Value shapedValue,
                                                  bool escape, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  return b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                 copy ? tensor : Value(), escape);
}

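/// Resolve out-of-place tensor OpOperands of this op by inserting copies: for
/// each OpOperand (or single aliasing OpResult) that the analysis marked as
/// out-of-place, materialize an alloc_tensor (optionally copying the original
/// value) and rewire the IR to use it.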
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand),
        copiedOpOperands.contains(opOperand));
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult),
        copiedOpResults.contains(opResult));
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

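/// Return whether the given op is allowed by this filter: DENY entries take
/// precedence over ALLOW entries, and ops are allowed by default only if no
/// ALLOW rule is present.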
bool OpFilter::isOpAllowed(Operation *op) const {
  // If no ALLOW rule is specified, ops are allowed by default. Otherwise, an
  // op must match at least one ALLOW rule (and no DENY rule) to be allowed.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

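/// Return whether the given op may be bufferized with these options, taking
/// both the op filter and the function boundary bufferization setting into
/// account.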
bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

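/// Try to cast the given op to BufferizableOpInterface. Return a "null"
/// interface if the op does not implement the interface or is not allowed by
/// these options.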
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

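/// Try to cast the defining op of the given value to BufferizableOpInterface.
/// Return a "null" interface if there is no defining op, it does not
/// implement the interface, or it is not allowed by these options.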
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

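/// Set the insertion point of the given builder right after the definition of
/// `value`: after the defining op if there is one, otherwise at the start of
/// the block that owns the block argument.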
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Also return `true`
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Also return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The in-place analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a
/// memory read. Also takes into account ops that create an alias but do not
/// read by themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. OpOperands of such matching Values are not
// traversed any further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

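/// Return true if it is known that no copy of the tensor operand is needed
/// when bufferizing out-of-place: e.g., because the contents are undefined,
/// the buffer is entirely overwritten, or the tensor is never read.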
bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

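/// Return true if the given OpOperand may be bufferized in place. Without
/// further analysis information, only ToMemrefOps and OpOperands that do not
/// bufferize to a memory write are considered in-place.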
bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

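/// Return true if the given tensor value may be yielded from a block (e.g.,
/// returned from a function or region). For values that are not AllocTensorOp
/// results, this conservatively returns true.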
bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

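/// Return the buffer (memref) of the given tensor value. If the value is the
/// result of a to_tensor op, return that op's memref operand; otherwise,
/// insert a to_memref op right after the value's definition.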
Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  Type memrefType = getMemRefType(tensorType, getOptions());
  ensureToMemrefOpIsValid(value, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                    value);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

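/// Replace an op with its bufferized values: tensor results are wrapped in
/// to_tensor ops so that the remaining tensor users keep type-checking; these
/// to_tensor ops are expected to fold away as bufferization proceeds.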
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers.
//===----------------------------------------------------------------------===//

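/// Return true if the given value is a block argument belonging to a
/// func::FuncOp.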
bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

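/// Return a MemRef type for the given tensor type, using the specified layout
/// if provided and otherwise falling back to the unknown-type conversion
/// configured in the bufferization options.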
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

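/// Return a MemRef type with a fully dynamic layout (dynamic strides and
/// offset). If the given tensor type is unranked, return an unranked MemRef
/// type.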
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}