//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/IR/Value.h"
20 #include "llvm/Support/Debug.h"
21 
22 //===----------------------------------------------------------------------===//
23 // BufferizableOpInterface
24 //===----------------------------------------------------------------------===//
25 
26 namespace mlir {
27 namespace bufferization {
28 
29 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
30 
31 } // namespace bufferization
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "bufferizable-op-interface"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
37 
38 using namespace mlir;
39 using namespace bufferization;
40 
41 /// Attribute name used to mark region arguments that can be bufferized
42 /// in-place during linalg comprehensive bufferization.
43 constexpr const ::llvm::StringLiteral
44     bufferization::BufferizableOpInterface::kInplaceableAttrName;
45 
/// Create an AllocTensorOp for the given shaped value (tensor or memref). If
/// `copy` is set, the shaped value is copied into the allocation. Otherwise, a
/// tensor with undefined contents is allocated. If `escape` is set, the
/// allocation is marked as escaping, i.e., it will not be deallocated.
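///
/// Example (illustrative sketch; SSA names are hypothetical and the exact
/// printed form may differ): for a `tensor<?xf32>` value `%t` with dynamic
/// extent `%d`, this creates roughly
///   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
/// when `copy` is false, and
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
/// when `copy` is true.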
49 Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
50                                                   Value shapedValue,
51                                                   bool escape, bool copy) {
52   Value tensor;
53   if (shapedValue.getType().isa<RankedTensorType>()) {
54     tensor = shapedValue;
55   } else if (shapedValue.getType().isa<MemRefType>()) {
56     tensor = b.create<ToTensorOp>(loc, shapedValue);
57   } else {
58     llvm_unreachable("expected RankedTensorType or MemRefType");
59   }
60   RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
61   SmallVector<Value> dynamicSizes;
62   if (!copy) {
63     // Compute the dynamic part of the shape.
64     // First try to query the shape via ReifyRankedShapedTypeOpInterface.
65     bool reifiedShapes = false;
66     if (shapedValue.getType().isa<RankedTensorType>() &&
67         shapedValue.isa<OpResult>()) {
68       if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
69               shapedValue.getDefiningOp())) {
70         ReifiedRankedShapedTypeDims resultDims;
71         if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
72           reifiedShapes = true;
73           auto &shape =
74               resultDims[shapedValue.cast<OpResult>().getResultNumber()];
75           for (const auto &dim : enumerate(tensorType.getShape()))
76             if (ShapedType::isDynamic(dim.value()))
77               dynamicSizes.push_back(shape[dim.index()]);
78         }
79       }
80     }
81 
82     // If the shape could not be reified, create DimOps.
83     if (!reifiedShapes)
84       populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
85   }
86 
87   return b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
88                                  copy ? tensor : Value(), escape);
89 }
90 
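/// Resolve all out-of-place tensor OpOperands of this op by inserting
/// bufferization.alloc_tensor copies: the op is rewired to operate on the copy
/// (or, if the op only creates a single alias without writing, the aliasing
/// OpResult is copied instead, which may be cheaper).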
91 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
92     RewriterBase &rewriter, const AnalysisState &state) {
93   OpBuilder::InsertionGuard g(rewriter);
94   Operation *op = getOperation();
95   SmallVector<OpOperand *> outOfPlaceOpOperands;
96   DenseSet<OpOperand *> copiedOpOperands;
97   DenseSet<OpOperand *> escapingOpOperandCopies;
98   SmallVector<OpResult> outOfPlaceOpResults;
99   DenseSet<OpResult> copiedOpResults;
100   DenseSet<OpResult> escapingOpResultCopies;
101 
102   // Find all out-of-place OpOperands.
103   for (OpOperand &opOperand : op->getOpOperands()) {
104     Type operandType = opOperand.get().getType();
105     if (!operandType.isa<TensorType>())
106       continue;
107     if (state.isInPlace(opOperand))
108       continue;
109     if (operandType.isa<UnrankedTensorType>())
110       return op->emitError("copies of unranked tensors are not supported");
111 
112     SmallVector<OpResult> aliasingOpResults =
113         state.getAliasingOpResult(opOperand);
114     // Is the result yielded from a block? Or are deallocations turned off
115     // entirely? In either case, mark the allocation as "escaping", so that it
116     // will not be deallocated.
117     bool escape = !state.getOptions().createDeallocs ||
118                   llvm::any_of(aliasingOpResults, [&](Value v) {
119                     return state.isTensorYielded(v);
120                   });
121 
122     if (aliasingOpResults.size() == 1 &&
123         !state.bufferizesToMemoryWrite(opOperand) &&
124         state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
125       // The op itself does not write but may create exactly one alias. Instead
126       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
127       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
128       // where the result is usually a smaller part of the source).
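      // Illustrative sketch (hypothetical IR): for
      //   %0 = tensor.extract_slice %t[0] [5] [1]
      //            : tensor<10xf32> to tensor<5xf32>
      // copying the result %0 only requires a tensor<5xf32> allocation,
      // whereas copying the operand %t would require a tensor<10xf32> one.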
129       outOfPlaceOpResults.push_back(aliasingOpResults.front());
130       if (!state.canOmitTensorCopy(opOperand))
131         copiedOpResults.insert(aliasingOpResults.front());
132       if (escape)
133         escapingOpResultCopies.insert(aliasingOpResults.front());
134     } else {
135       // In all other cases, make a copy of the OpOperand.
136       outOfPlaceOpOperands.push_back(&opOperand);
137       if (!state.canOmitTensorCopy(opOperand))
138         copiedOpOperands.insert(&opOperand);
139       if (escape)
140         escapingOpOperandCopies.insert(&opOperand);
141     }
142   }
143 
144   // Insert copies of OpOperands.
145   rewriter.setInsertionPoint(op);
146   for (OpOperand *opOperand : outOfPlaceOpOperands) {
147     Value copy = allocateTensorForShapedValue(
148         rewriter, op->getLoc(), opOperand->get(),
149         escapingOpOperandCopies.contains(opOperand),
150         copiedOpOperands.contains(opOperand));
151     rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
152   }
153 
154   // Insert copies of OpResults.
155   rewriter.setInsertionPointAfter(op);
156   for (OpResult opResult : outOfPlaceOpResults) {
157     Value copy =
158         allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
159                                      escapingOpResultCopies.contains(opResult),
160                                      copiedOpResults.count(opResult));
161     SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
162         opResult.getUses(), [](OpOperand &use) { return &use; }));
163     for (OpOperand *use : uses) {
164       // Do not update the alloc_tensor op that we just created.
165       if (use->getOwner() != copy.getDefiningOp())
166         rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
167     }
168   }
169 
170   return success();
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // OpFilter
175 //===----------------------------------------------------------------------===//
176 
177 bool OpFilter::isOpAllowed(Operation *op) const {
  // If no ALLOW rule was specified, ops are allowed by default. Otherwise, an
  // op must match at least one ALLOW rule (and no DENY rule) to be allowed.
179   bool isAllowed = !hasAllowRule();
180   for (const Entry &entry : entries) {
181     bool filterResult = entry.fn(op);
182     switch (entry.type) {
183     case Entry::ALLOW:
184       isAllowed |= filterResult;
185       break;
186     case Entry::DENY:
187       if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
      break;
    }
192   }
193   return isAllowed;
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // BufferizationOptions
198 //===----------------------------------------------------------------------===//
199 
200 // Default constructor for BufferizationOptions.
201 BufferizationOptions::BufferizationOptions() = default;
202 
203 bool BufferizationOptions::isOpAllowed(Operation *op) const {
204   // Special case: If function boundary bufferization is deactivated, do not
205   // allow ops that belong to the `func` dialect.
206   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
207   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
208     return false;
209 
210   return opFilter.isOpAllowed(op);
211 }
212 
213 BufferizableOpInterface
214 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
215   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
216   if (!bufferizableOp)
217     return nullptr;
218   if (!isOpAllowed(op))
219     return nullptr;
220   return bufferizableOp;
221 }
222 
223 BufferizableOpInterface
224 BufferizationOptions::dynCastBufferizableOp(Value value) const {
225   if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
226     if (isOpAllowed(bufferizableOp.getOperation()))
227       return bufferizableOp;
228   return nullptr;
229 }
230 
231 void BufferizationOptions::addDialectStateInitializer(
232     StringRef name, const DialectStateInitFn &fn) {
233   stateInitializers.push_back(
234       [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // Helper functions for BufferizableOpInterface
239 //===----------------------------------------------------------------------===//
240 
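/// Set the insertion point of `b` right after the definition of `value`: after
/// the defining op if `value` is an OpResult, or at the beginning of the
/// owning block if `value` is a BlockArgument.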
241 static void setInsertionPointAfter(OpBuilder &b, Value value) {
242   if (auto bbArg = value.dyn_cast<BlockArgument>()) {
243     b.setInsertionPointToStart(bbArg.getOwner());
244   } else {
245     b.setInsertionPointAfter(value.getDefiningOp());
246   }
247 }
248 
/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
251 SmallVector<OpOperand *>
252 AnalysisState::getAliasingOpOperand(OpResult result) const {
253   if (Operation *op = result.getDefiningOp())
254     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
255       return bufferizableOp.getAliasingOpOperand(result, *this);
256   return {};
257 }
258 
/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
261 SmallVector<OpResult>
262 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
263   if (auto bufferizableOp =
264           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
265     return bufferizableOp.getAliasingOpResult(opOperand, *this);
266   return {};
267 }
268 
/// Return true if `opOperand` bufferizes to a memory read. Also return true
/// (conservatively) if the op is not bufferizable.
271 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
272   if (auto bufferizableOp =
273           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
274     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
275 
276   // Unknown op that returns a tensor. The inplace analysis does not support it.
277   // Conservatively return true.
278   return true;
279 }
280 
/// Return true if `opOperand` bufferizes to a memory write. Also return true
/// (conservatively) if the op is not bufferizable.
283 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
284   if (auto bufferizableOp =
285           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
286     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
287 
288   // Unknown op that returns a tensor. The inplace analysis does not support it.
289   // Conservatively return true.
290   return true;
291 }
292 
/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
295 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
296   if (auto bufferizableOp =
297           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
298     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
299 
300   // Unknown op that returns a tensor. The inplace analysis does not support it.
301   // Conservatively return false.
302   return false;
303 }
304 
305 /// Return true if the given value is read by an op that bufferizes to a memory
306 /// read. Also takes into account ops that create an alias but do not read by
307 /// themselves (e.g., ExtractSliceOp).
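///
/// Example (illustrative, hypothetical IR):
///   %0 = tensor.extract_slice %t[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
///   %1 = tensor.extract %0[%idx] : tensor<5xf32>
/// isValueRead(%t) returns true: the extract_slice only creates an alias, but
/// the subsequent tensor.extract reads the aliasing value %0.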
308 bool AnalysisState::isValueRead(Value value) const {
309   assert(value.getType().isa<TensorType>() && "expected TensorType");
310   SmallVector<OpOperand *> workingSet;
311   for (OpOperand &use : value.getUses())
312     workingSet.push_back(&use);
313 
314   while (!workingSet.empty()) {
315     OpOperand *uMaybeReading = workingSet.pop_back_val();
316     // Skip over all ops that neither read nor write (but create an alias).
317     if (bufferizesToAliasOnly(*uMaybeReading))
318       for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
319         for (OpOperand &use : opResult.getUses())
320           workingSet.push_back(&use);
321     if (bufferizesToMemoryRead(*uMaybeReading))
322       return true;
323   }
324 
325   return false;
326 }
327 
328 // Starting from `value`, follow the use-def chain in reverse, always selecting
329 // the aliasing OpOperands. Find and return Values for which `condition`
330 // evaluates to true. OpOperands of such matching Values are not traversed any
331 // further.
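//
// Example (illustrative, using a hypothetical writing op): in
//   %0 = "test.write_op"() : () -> tensor<10xf32>
//   %1 = tensor.extract_slice %0[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
// findValueInReverseUseDefChain(%1, condition) skips the alias-only
// extract_slice and returns {%0}, assuming `condition(%0)` holds.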
332 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
333     Value value, llvm::function_ref<bool(Value)> condition) const {
334   llvm::SetVector<Value> result, workingSet;
335   workingSet.insert(value);
336 
337   while (!workingSet.empty()) {
338     Value value = workingSet.pop_back_val();
339     if (condition(value) || value.isa<BlockArgument>()) {
340       result.insert(value);
341       continue;
342     }
343 
344     OpResult opResult = value.cast<OpResult>();
345     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
346     if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
347       result.insert(value);
348       continue;
349     }
350 
351     for (OpOperand *o : opOperands)
352       workingSet.insert(o->get());
353   }
354 
355   return result;
356 }
357 
358 // Find the Values of the last preceding write of a given Value.
359 llvm::SetVector<Value>
360 AnalysisState::findLastPrecedingWrite(Value value) const {
361   return findValueInReverseUseDefChain(value, [&](Value value) {
362     Operation *op = value.getDefiningOp();
363     if (!op)
364       return true;
365     auto bufferizableOp = options.dynCastBufferizableOp(op);
366     if (!bufferizableOp)
367       return true;
368     return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
369   });
370 }
371 
372 AnalysisState::AnalysisState(const BufferizationOptions &options)
373     : options(options) {
374   for (const BufferizationOptions::AnalysisStateInitFn &fn :
375        options.stateInitializers)
376     fn(*this);
377 }
378 
379 bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
380   // Do not copy if the tensor has undefined contents.
381   if (hasUndefinedContents(&opOperand))
382     return true;
383 
384   // Do not copy if the buffer of the tensor is entirely overwritten (with
385   // values that do not depend on the old tensor).
386   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
387     return true;
388 
389   // Do not copy if the tensor is never read.
390   SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
391   if (!bufferizesToMemoryRead(opOperand) &&
392       llvm::none_of(aliasingOpResults,
393                     [&](OpResult opResult) { return isValueRead(opResult); }))
394     return true;
395 
396   // Default: Cannot omit the copy.
397   return false;
398 }
399 
400 bool AnalysisState::isInPlace(OpOperand &opOperand) const {
401   // ToMemrefOps are always in-place.
402   if (isa<ToMemrefOp>(opOperand.getOwner()))
403     return true;
404 
  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an allocation and a copy are
  // inserted.
408 }
409 
410 bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
411   // In the absence of analysis information, we do not know if the values are
412   // equivalent. The conservative answer is "false".
413   return false;
414 }
415 
416 bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
417   // In the absence of analysis information, we do not know if the values may be
418   // aliasing. The conservative answer is "true".
419   return true;
420 }
421 
422 bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
423   // In the absence of analysis information, the conservative answer is "false".
424   return false;
425 }
426 
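/// Return true if the given tensor value may be yielded from a block or
/// returned from the enclosing region, in which case its allocation must not
/// be deallocated. Without precise analysis information, only AllocTensorOp
/// results are analyzed via their SSA use-def chains; all other values are
/// conservatively reported as yielded.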
427 bool AnalysisState::isTensorYielded(Value tensor) const {
428   // In the absence of analysis information, the conservative answer is "true".
429   if (!tensor.getDefiningOp<AllocTensorOp>())
430     return true;
431 
432   // For AllocTensorOp results, we can do better: They do not alias with any
433   // preceding value, so we can follow SSA use-def chains and do a simple
434   // analysis.
435   SmallVector<OpOperand *> worklist;
436   for (OpOperand &use : tensor.getUses())
437     worklist.push_back(&use);
438 
439   while (!worklist.empty()) {
440     OpOperand *operand = worklist.pop_back_val();
441     Operation *op = operand->getOwner();
442 
443     // If the op is not bufferizable, we can safely assume that the value is not
444     // yielded. (When bufferizing that op, it must handle such cases.)
445     if (!options.dynCastBufferizableOp(op))
446       continue;
447 
448     // We cannot analyze through ToMemrefOps, so we have to conservatively
449     // assume that the value is yielded.
450     if (isa<ToMemrefOp>(op))
451       return true;
452 
453     // Check if the op is returning/yielding.
454     if (isRegionReturnLike(op))
455       return true;
456 
457     // Add all aliasing OpResults to the worklist.
458     // Note: In the absence of detailed analysis information (e.g., there may be
459     // no function call analysis information), this `getAliasingOpResult` is
460     // conservative and may report additional OpResults as potentially aliasing.
461     for (OpResult opResult : getAliasingOpResult(*operand))
462       for (OpOperand &use : opResult.getUses())
463         worklist.push_back(&use);
464   }
465 
466   // No ReturnLike op found: The value is not yielded.
467   return false;
468 }
469 
470 // bufferization.to_memref is not allowed to change the rank.
471 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
472 #ifndef NDEBUG
473   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
474   assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
475                                    rankedTensorType.getRank()) &&
476          "to_memref would be invalid: mismatching ranks");
477 #endif
478 }
479 
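/// Return the buffer (memref) of the given tensor value. If the value is the
/// result of a ToTensorOp, simply return that op's memref operand. Otherwise,
/// insert a ToMemrefOp right after the definition of the value.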
480 Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
481                                const BufferizationOptions &options) {
482   auto tensorType = value.getType().dyn_cast<TensorType>();
483   assert(tensorType && "unexpected non-tensor type");
484 
485   // Replace "%t = to_tensor %m" with %m.
486   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
487     return toTensorOp.getMemref();
488 
489   // Insert to_memref op.
490   OpBuilder::InsertionGuard g(rewriter);
491   setInsertionPointAfter(rewriter, value);
492   Type memrefType = getMemRefType(tensorType, options);
493   ensureToMemrefOpIsValid(value, memrefType);
494   return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
495                                                     value);
496 }
497 
498 /// Return the buffer type for a given Value (tensor) after bufferization.
499 BaseMemRefType
500 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
501   auto tensorType = value.getType().dyn_cast<TensorType>();
502   assert(tensorType && "unexpected non-tensor type");
503 
504   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
505     return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
506 
507   return getMemRefType(tensorType, options);
508 }
509 
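/// Replace the given op with `values`, one replacement value per OpResult. For
/// tensor results, the replacement is expected to be a memref; a ToTensorOp is
/// inserted for each of them so that existing uses, which still expect
/// tensors, remain valid.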
510 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
511                                                   Operation *op,
512                                                   ValueRange values) {
513   assert(values.size() == op->getNumResults() &&
514          "expected one value per OpResult");
515   OpBuilder::InsertionGuard g(rewriter);
516 
517   // Replace all OpResults with the given values.
518   SmallVector<Value> replacements;
519   for (OpResult opResult : op->getOpResults()) {
520     Value replacement = values[opResult.getResultNumber()];
521     if (opResult.getType().isa<TensorType>()) {
522       // The OpResult is a tensor. Such values are replaced with memrefs during
523       // bufferization.
524       assert((replacement.getType().isa<MemRefType>() ||
525               replacement.getType().isa<UnrankedMemRefType>()) &&
526              "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
530       rewriter.setInsertionPointAfter(op);
531       replacement = rewriter.create<bufferization::ToTensorOp>(
532           replacement.getLoc(), replacement);
533     }
534     replacements.push_back(replacement);
535   }
536 
537   rewriter.replaceOp(op, replacements);
538 }
539 
540 //===----------------------------------------------------------------------===//
541 // Bufferization-specific scoped alloc/dealloc insertion support.
542 //===----------------------------------------------------------------------===//
543 
544 /// Create a memref allocation with the given type and dynamic extents.
545 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
546                                                    MemRefType type,
547                                                    ValueRange dynShape) const {
548   if (allocationFn)
549     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
550 
  // Default buffer allocation via AllocOp.
552   if (bufferAlignment != 0)
553     return b
554         .create<memref::AllocOp>(loc, type, dynShape,
555                                  b.getI64IntegerAttr(bufferAlignment))
556         .getResult();
557   return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
558 }
559 
/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
562 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
563                                                   Value allocatedBuffer) const {
564   if (deallocationFn)
565     return (*deallocationFn)(b, loc, allocatedBuffer);
566 
567   // Default buffer deallocation via DeallocOp.
568   b.create<memref::DeallocOp>(loc, allocatedBuffer);
569   return success();
570 }
571 
572 /// Create a memory copy between two memref buffers.
573 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
574                                                  Value from, Value to) const {
575   if (memCpyFn)
576     return (*memCpyFn)(b, loc, from, to);
577 
578   b.create<memref::CopyOp>(loc, from, to);
579   return success();
580 }
581 
//===----------------------------------------------------------------------===//
// Bufferization-specific helpers for function arguments and memref types.
//===----------------------------------------------------------------------===//
585 
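/// Return true if the given value is an argument of a block that belongs to a
/// func::FuncOp.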
586 bool bufferization::isFunctionArgument(Value value) {
587   auto bbArg = value.dyn_cast<BlockArgument>();
588   if (!bbArg)
589     return false;
590   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
591 }
592 
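/// Return a MemRef type for the given tensor type. If `layout` is specified,
/// it is used for the resulting ranked MemRef type. Otherwise, the layout is
/// chosen according to the unknown-type-conversion option in `options` (fully
/// dynamic or static identity layout). `memorySpace` is attached to the
/// resulting type in all cases.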
593 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
594                                             const BufferizationOptions &options,
595                                             MemRefLayoutAttrInterface layout,
596                                             Attribute memorySpace) {
597   // Case 1: Unranked memref type.
598   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
599     assert(!layout && "UnrankedTensorType cannot have a layout map");
600     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
601                                    memorySpace);
602   }
603 
604   // Case 2: Ranked memref type with specified layout.
605   auto rankedTensorType = tensorType.cast<RankedTensorType>();
606   if (layout) {
607     return MemRefType::get(rankedTensorType.getShape(),
608                            rankedTensorType.getElementType(), layout,
609                            memorySpace);
610   }
611 
612   // Case 3: Configured with "fully dynamic layout maps".
613   if (options.unknownTypeConversion ==
614       BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
615     return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
616 
617   // Case 4: Configured with "static identity layout maps".
618   if (options.unknownTypeConversion ==
619       BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
620     return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
621 
622   llvm_unreachable("InferLayoutMap is an invalid option");
623 }
624 
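/// Return a MemRef type with a fully dynamic layout map (dynamic strides and a
/// dynamic offset). If the given tensor type is unranked, return an unranked
/// MemRef type.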
625 BaseMemRefType
626 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
627                                                    Attribute memorySpace) {
628   // Case 1: Unranked memref type.
629   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
630     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
631                                    memorySpace);
632   }
633 
634   // Case 2: Ranked memref type.
635   auto rankedTensorType = tensorType.cast<RankedTensorType>();
636   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
637   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
638                                       ShapedType::kDynamicStrideOrOffset);
639   AffineMap stridedLayout = makeStridedLinearLayoutMap(
640       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
641   return MemRefType::get(rankedTensorType.getShape(),
642                          rankedTensorType.getElementType(), stridedLayout,
643                          memorySpace);
644 }
645 
646 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
647 /// the given tensor type is unranked, return an unranked MemRef type.
648 BaseMemRefType
649 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
650                                                      Attribute memorySpace) {
651   // Case 1: Unranked memref type.
652   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
653     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
654                                    memorySpace);
655   }
656 
657   // Case 2: Ranked memref type.
658   auto rankedTensorType = tensorType.cast<RankedTensorType>();
659   MemRefLayoutAttrInterface layout = {};
660   return MemRefType::get(rankedTensorType.getShape(),
661                          rankedTensorType.getElementType(), layout,
662                          memorySpace);
663 }
664