//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/AsmState.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/IR/Value.h"
19 #include "llvm/Support/Debug.h"
20 
21 namespace mlir {
22 namespace bufferization {
23 
24 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
25 
26 } // namespace bufferization
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "bufferizable-op-interface"
30 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
31 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
32 
33 using namespace mlir;
34 using namespace bufferization;
35 
36 /// Attribute name used to mark region arguments that can be bufferized
37 /// in-place during linalg comprehensive bufferization.
38 constexpr const ::llvm::StringLiteral
39     bufferization::BufferizableOpInterface::kInplaceableAttrName;
40 
41 /// Attribute name used to mark allocs that are created by the bufferization.
42 static const char *kBufferAllocationAttr = "bufferization.allocation";
43 
44 /// Attribute name used to mark allocs that should not be deallocated.
45 static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
46 
47 //===----------------------------------------------------------------------===//
48 // OpFilter
49 //===----------------------------------------------------------------------===//
50 
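/// Return whether `op` is allowed by this filter. A rough sketch of the
/// semantics implemented below: a matching DENY entry rejects the op
/// immediately; otherwise the op is allowed if at least one ALLOW entry
/// matches, or if there are no ALLOW entries at all. For example
/// (hypothetical configuration), a filter with an ALLOW entry for the
/// `tensor` dialect and a DENY entry for `tensor.extract` allows all `tensor`
/// ops except `tensor.extract`.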
51 bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow or disallow the op according to the filter's ALLOW/DENY entries.
53   bool isAllowed = !hasAllowRule();
54   for (const Entry &entry : entries) {
55     bool filterResult = entry.fn(op);
56     switch (entry.type) {
57     case Entry::ALLOW:
58       isAllowed |= filterResult;
59       break;
60     case Entry::DENY:
61       if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
66   }
67   return isAllowed;
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // BufferizationOptions
72 //===----------------------------------------------------------------------===//
73 
74 // Default constructor for BufferizationOptions.
75 BufferizationOptions::BufferizationOptions() = default;
76 
77 bool BufferizationOptions::isOpAllowed(Operation *op) const {
78   // Special case: If function boundary bufferization is deactivated, do not
79   // allow ops that belong to the `func` dialect.
80   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
81   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
82     return false;
83 
84   return opFilter.isOpAllowed(op);
85 }
86 
87 BufferizableOpInterface
88 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
89   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
90   if (!bufferizableOp)
91     return nullptr;
92   if (!isOpAllowed(op))
93     return nullptr;
94   return bufferizableOp;
95 }
96 
97 BufferizableOpInterface
98 BufferizationOptions::dynCastBufferizableOp(Value value) const {
99   if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
100     if (isOpAllowed(bufferizableOp.getOperation()))
101       return bufferizableOp;
102   return nullptr;
103 }
104 
105 void BufferizationOptions::addDialectStateInitializer(
106     StringRef name, const DialectStateInitFn &fn) {
107   stateInitializers.push_back(
108       [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Helper functions for BufferizableOpInterface
113 //===----------------------------------------------------------------------===//
114 
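/// Set the insertion point of the builder right after the definition of
/// `value`: at the beginning of the owning block if `value` is a
/// BlockArgument, otherwise right after the defining op.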
115 static void setInsertionPointAfter(OpBuilder &b, Value value) {
116   if (auto bbArg = value.dyn_cast<BlockArgument>()) {
117     b.setInsertionPointToStart(bbArg.getOwner());
118   } else {
119     b.setInsertionPointAfter(value.getDefiningOp());
120   }
121 }
122 
123 /// Determine which OpOperand* will alias with `result` if the op is bufferized
124 /// in place. Return an empty vector if the op is not bufferizable.
125 SmallVector<OpOperand *>
126 AnalysisState::getAliasingOpOperand(OpResult result) const {
127   if (Operation *op = result.getDefiningOp())
128     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
129       return bufferizableOp.getAliasingOpOperand(result, *this);
130   return {};
131 }
132 
133 /// Determine which OpResult will alias with `opOperand` if the op is bufferized
134 /// in place. Return an empty vector if the op is not bufferizable.
135 SmallVector<OpResult>
136 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
137   if (auto bufferizableOp =
138           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
139     return bufferizableOp.getAliasingOpResult(opOperand, *this);
140   return {};
141 }
142 
143 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
144 /// op is not bufferizable.
145 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
146   if (auto bufferizableOp =
147           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
148     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
149 
  // Unknown op with a tensor operand. The inplace analysis does not support
  // it. Conservatively return true.
152   return true;
153 }
154 
155 /// Return true if `opOperand` bufferizes to a memory write. Return
156 /// `true` if the op is not bufferizable.
157 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
158   if (auto bufferizableOp =
159           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
160     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
161 
  // Unknown op with a tensor operand. The inplace analysis does not support
  // it. Conservatively return true.
164   return true;
165 }
166 
167 /// Return true if `opOperand` does neither read nor write but bufferizes to an
168 /// alias. Return false if the op is not bufferizable.
169 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
170   if (auto bufferizableOp =
171           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
172     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
173 
  // Unknown op with a tensor operand. The inplace analysis does not support
  // it. Conservatively return false.
176   return false;
177 }
178 
179 /// Return true if the given value is read by an op that bufferizes to a memory
180 /// read. Also takes into account ops that create an alias but do not read by
181 /// themselves (e.g., ExtractSliceOp).
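///
/// Illustrative example: given
///   %0 = tensor.extract_slice %t [...]
///   %1 = tensor.extract %0 [...]
/// the value %t is considered read: the extract_slice itself only creates an
/// alias, but that alias is eventually read by tensor.extract.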
182 bool AnalysisState::isValueRead(Value value) const {
183   assert(value.getType().isa<TensorType>() && "expected TensorType");
184   SmallVector<OpOperand *> workingSet;
185   for (OpOperand &use : value.getUses())
186     workingSet.push_back(&use);
187 
188   while (!workingSet.empty()) {
189     OpOperand *uMaybeReading = workingSet.pop_back_val();
190     // Skip over all ops that neither read nor write (but create an alias).
191     if (bufferizesToAliasOnly(*uMaybeReading))
192       for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
193         for (OpOperand &use : opResult.getUses())
194           workingSet.push_back(&use);
195     if (bufferizesToMemoryRead(*uMaybeReading))
196       return true;
197   }
198 
199   return false;
200 }
201 
202 // Starting from `value`, follow the use-def chain in reverse, always selecting
203 // the aliasing OpOperands. Find and return Values for which `condition`
204 // evaluates to true. OpOperands of such matching Values are not traversed any
205 // further.
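//
// Illustrative example: if `condition` matches values that are defined by ops
// bufferizing to a memory write, then starting from a value %1 that merely
// aliases a written value %0 (e.g., via an extract_slice-like op), the
// traversal steps through the aliasing OpOperand of the aliasing op and
// returns {%0}.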
206 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
207     Value value, llvm::function_ref<bool(Value)> condition) const {
208   llvm::SetVector<Value> result, workingSet;
209   workingSet.insert(value);
210 
211   while (!workingSet.empty()) {
212     Value value = workingSet.pop_back_val();
213     if (condition(value) || value.isa<BlockArgument>()) {
214       result.insert(value);
215       continue;
216     }
217 
218     OpResult opResult = value.cast<OpResult>();
219     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
220     if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
221       result.insert(value);
222       continue;
223     }
224 
225     for (OpOperand *o : opOperands)
226       workingSet.insert(o->get());
227   }
228 
229   return result;
230 }
231 
232 // Find the Values of the last preceding write of a given Value.
233 llvm::SetVector<Value>
234 AnalysisState::findLastPrecedingWrite(Value value) const {
235   return findValueInReverseUseDefChain(value, [&](Value value) {
236     Operation *op = value.getDefiningOp();
237     if (!op)
238       return true;
239     auto bufferizableOp = options.dynCastBufferizableOp(op);
240     if (!bufferizableOp)
241       return true;
242     return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
243   });
244 }
245 
246 AnalysisState::AnalysisState(const BufferizationOptions &options)
247     : options(options) {
248   for (const BufferizationOptions::AnalysisStateInitFn &fn :
249        options.stateInitializers)
250     fn(*this);
251 }
252 
253 // bufferization.to_memref is not allowed to change the rank.
254 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
255 #ifndef NDEBUG
256   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
257   assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
258                                    rankedTensorType.getRank()) &&
259          "to_memref would be invalid: mismatching ranks");
260 #endif
261 }
262 
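/// Return the buffer (memref) that holds the contents of the given tensor
/// value. If the tensor is produced by a bufferization.to_tensor op, its
/// memref operand is returned directly; otherwise a bufferization.to_memref
/// op is inserted right after the definition of the tensor.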
263 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
264                                         const BufferizationOptions &options) {
265   auto tensorType = tensor.getType().dyn_cast<TensorType>();
266   assert(tensorType && "unexpected non-tensor type");
267 
268   // Replace "%t = to_tensor %m" with %m.
269   if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
270     return toTensorOp.memref();
271 
272   // Insert to_memref op.
273   OpBuilder::InsertionGuard g(rewriter);
274   setInsertionPointAfter(rewriter, tensor);
275   Type memrefType = getMemRefType(tensorType, options);
276   ensureToMemrefOpIsValid(tensor, memrefType);
277   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
278                                                     tensor);
279 }
280 
281 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
282 /// a new buffer and copy over data from the existing buffer if out-of-place
283 /// bufferization was decided.
284 FailureOr<Value>
285 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
286                               Optional<ForceInPlacability> overrideInPlace,
287                               Optional<Operation *> customCopyInsertionPoint) {
288   const BufferizationOptions &options = analysisState.getOptions();
289   OpBuilder::InsertionGuard guard(rewriter);
290   Operation *op = opOperand.getOwner();
291   Location loc = op->getLoc();
292   SmallVector<OpResult> aliasingOpResults =
293       analysisState.getAliasingOpResult(opOperand);
294   Value operand = opOperand.get();
295   Value operandBuffer = lookupBuffer(rewriter, operand, options);
296 
297   // Can `operandBuffer` be used directly or do we need a copy?
298   bool inplace =
299       overrideInPlace != FORCE_OUT_OF_PLACE &&
300       (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
301   if (inplace)
302     return operandBuffer;
303 
304   // Bufferizing out-of-place: Allocate a new buffer.
305   // Move insertion point right after `operandBuffer`. That is where the
306   // allocation should be inserted (in the absence of allocation hoisting).
307   setInsertionPointAfter(rewriter, operandBuffer);
308   // Allocate the result buffer. The buffer should be deallocated if the tensor
309   // is not yielded and deallocs are enabled in general.
310   bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
311     return getAnalysisState().isTensorYielded(v);
312   });
313   FailureOr<Value> resultBuffer = createAlloc(
314       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
315   if (failed(resultBuffer))
316     return failure();
317   // Do not copy the buffer if its contents are undefined.
318   if (analysisState.hasUndefinedContents(&opOperand))
319     return resultBuffer;
320   // Do not copy if the copied data is never read.
321   if (!aliasingOpResults.empty() &&
322       !analysisState.bufferizesToMemoryRead(opOperand) &&
323       llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
324         return analysisState.isValueRead(opResult);
325       }))
326     return resultBuffer;
327   // Do not copy if this op does not read the data, but writes it.
328   if (analysisState.bufferizesToMemoryWrite(opOperand) &&
329       !analysisState.bufferizesToMemoryRead(opOperand))
330     return resultBuffer;
331 
332   if (customCopyInsertionPoint) {
333     rewriter.setInsertionPoint(*customCopyInsertionPoint);
334   } else {
335     // The copy happens right before the op that is bufferized.
336     rewriter.setInsertionPoint(op);
337   }
338   if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
339     return failure();
340 
341   return resultBuffer;
342 }
343 
344 /// Return the buffer type for a given Value (tensor) after bufferization.
345 BaseMemRefType BufferizationState::getBufferType(Value value) const {
346   auto tensorType = value.getType().dyn_cast<TensorType>();
347   assert(tensorType && "unexpected non-tensor type");
348 
349   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
350     return toTensorOp.memref().getType().cast<BaseMemRefType>();
351 
352   return getMemRefType(tensorType, getOptions());
353 }
354 
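/// Replace an op with the given bufferized values (one value per OpResult).
/// For tensor results the replacement must be a memref; a
/// bufferization.to_tensor op is inserted so that remaining users, which still
/// expect tensors, stay type-correct until they are bufferized themselves.
/// Illustrative example: replacing the result %0 of a tensor-producing op with
/// a memref %m rewrites the users of %0 to use `bufferization.to_tensor %m`.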
355 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
356                                                   Operation *op,
357                                                   ValueRange values) {
358   assert(values.size() == op->getNumResults() &&
359          "expected one value per OpResult");
360   OpBuilder::InsertionGuard g(rewriter);
361 
362   // Replace all OpResults with the given values.
363   SmallVector<Value> replacements;
364   for (OpResult opResult : op->getOpResults()) {
365     Value replacement = values[opResult.getResultNumber()];
366     if (opResult.getType().isa<TensorType>()) {
367       // The OpResult is a tensor. Such values are replaced with memrefs during
368       // bufferization.
369       assert((replacement.getType().isa<MemRefType>() ||
370               replacement.getType().isa<UnrankedMemRefType>()) &&
371              "tensor op result should be replaced with a memref value");
372       // The existing uses of the OpResult still expect a tensor. Insert a
373       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
375       rewriter.setInsertionPointAfter(op);
376       replacement = rewriter.create<bufferization::ToTensorOp>(
377           replacement.getLoc(), replacement);
378     }
379     replacements.push_back(replacement);
380   }
381 
382   rewriter.replaceOp(op, replacements);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Bufferization-specific scoped alloc/dealloc insertion support.
387 //===----------------------------------------------------------------------===//
388 
389 /// Create a memref allocation with the given type and dynamic extents.
390 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
391                                                    MemRefType type,
392                                                    ValueRange dynShape) const {
393   if (allocationFn)
394     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
395 
  // Default buffer allocation via AllocOp.
397   Value allocated = b.create<memref::AllocOp>(
398       loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
399   return allocated;
400 }
401 
402 /// Creates a memref deallocation. The given memref buffer must have been
403 /// allocated using `createAlloc`.
404 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
405                                                   Value allocatedBuffer) const {
406   if (deallocationFn)
407     return (*deallocationFn)(b, loc, allocatedBuffer);
408 
409   // Default buffer deallocation via DeallocOp.
410   b.create<memref::DeallocOp>(loc, allocatedBuffer);
411   return success();
412 }
413 
414 static MemRefType getContiguousMemRefType(ShapedType shapedType,
415                                           Attribute memorySpace = {}) {
416   MemRefLayoutAttrInterface layout = {};
417   return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
418                          layout, memorySpace);
419 }
420 
421 /// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`) the values for the
423 /// dynamic dimensions in the returned `memref` type.
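///
/// Illustrative example: for a `shapedValue` of type tensor<4x?xf32>, the
/// returned allocation type is memref<4x?xf32> and `dynShape` receives one
/// Value holding the runtime extent of the dynamic dimension.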
424 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
425                                             Value shapedValue,
426                                             SmallVectorImpl<Value> &dynShape) {
427   MemRefType allocMemRefType =
428       getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
429 
430   // Compute the dynamic part of the shape.
431   bool reifiedShapes = false;
432   if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
433           shapedValue.getDefiningOp())) {
434     ReifiedRankedShapedTypeDims resultDims;
435     if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
436       reifiedShapes = true;
437       OpResult resultValue = shapedValue.dyn_cast<OpResult>();
438       auto &shape = resultDims[resultValue.getResultNumber()];
439       for (const auto &dim : enumerate(allocMemRefType.getShape()))
440         if (ShapedType::isDynamic(dim.value()))
441           dynShape.push_back(shape[dim.index()]);
442     }
443   }
444 
445   if (!reifiedShapes) {
446     for (const auto &dim : enumerate(allocMemRefType.getShape()))
447       if (ShapedType::isDynamic(dim.value())) {
448         assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
449                 shapedValue.getType().isa<MemRefType>()) &&
450                "expected MemRef type");
451         dynShape.push_back(
452             b.create<memref::DimOp>(loc, shapedValue, dim.index()));
453       }
454   }
455 
456   return allocMemRefType;
457 }
458 
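/// Create a memref.alloca with the given type and dynamic extents and mark it
/// as a bufferization-created allocation. If `skipDealloc` is set, the op is
/// additionally marked so that no dealloc is created for it later (see
/// createAllocDeallocOps).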
459 static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
460                                     ValueRange dynShape, bool skipDealloc) {
461   auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
462   allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
463   if (skipDealloc)
464     allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
465   return allocaOp.getResult();
466 }
467 
/// Create an allocation for `shapedValue` at the current insertion point of
/// `b`. Whether a matching deallocation will be created later is controlled by
/// `dealloc` (or, for tensor values, derived from the bufferization options
/// and the analysis).
470 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
471                                                  Value shapedValue,
472                                                  Optional<bool> dealloc) {
473   // Take a guard before anything else.
474   OpBuilder::InsertionGuard g(b);
475 
476   // Compute allocation memref type.
477   assert(shapedValue.getType().isa<ShapedType>());
478   SmallVector<Value> dynShape;
479   MemRefType allocMemRefType =
480       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
481 
  // Should the buffer be deallocated again or should we let it leak?
483   bool skipDealloc;
484   if (dealloc) {
485     skipDealloc = !dealloc.getValue();
486   } else {
487     assert(shapedValue.getType().isa<TensorType>() &&
488            "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
491     skipDealloc = !getOptions().createDeallocs ||
492                   getAnalysisState().isTensorYielded(shapedValue);
493   }
494 
495   // Create the buffer allocation.
496   return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
497 }
498 
499 /// Create a memory copy between two memref buffers.
500 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
501                                                  Value from, Value to) const {
502   if (memCpyFn)
503     return (*memCpyFn)(b, loc, from, to);
504 
505   b.create<memref::CopyOp>(loc, from, to);
506   return success();
507 }
508 
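/// Rewrite all bufferization-created memref.alloca ops nested under `op` into
/// allocations as configured in `options` and, unless an alloca is marked
/// "skip dealloc", insert a matching deallocation before the terminator of the
/// alloca's block. If `onlyLeakingAllocs` is set, only allocas that are marked
/// "skip dealloc" are rewritten.
///
/// Rough sketch of the default rewrite (illustrative):
///   %0 = memref.alloca(%sz) {bufferization.allocation} : memref<?xf32>
/// becomes
///   %0 = memref.alloc(%sz) {alignment = ...} : memref<?xf32>
///   ...
///   memref.dealloc %0 : memref<?xf32>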
509 LogicalResult
510 bufferization::createAllocDeallocOps(Operation *op,
511                                      const BufferizationOptions &options,
512                                      bool onlyLeakingAllocs, bool *changed) {
513   IRRewriter rewriter(op->getContext());
514   if (changed)
515     *changed = false;
516 
517   // Bufferization creates memref.alloca ops. After bufferization, these must be
518   // rewritten to alloc/dealloc ops as specified in the bufferization options.
519   WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
520     // Ignore memref.alloca ops that were not created by the bufferization.
521     if (!allocaOp->hasAttr(kBufferAllocationAttr))
522       return WalkResult::skip();
523     // If `onlyLeakingAllocs`, process only ops that are marked as
524     // "skip dealloc".
525     bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
526     if (onlyLeakingAllocs && !skipDealloc)
527       return WalkResult::skip();
528 
529     // Create alloc.
530     Block *block = allocaOp->getBlock();
531     rewriter.setInsertionPoint(allocaOp);
532     FailureOr<Value> alloc =
533         options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
534                             allocaOp.dynamicSizes());
535     if (failed(alloc))
536       return WalkResult::interrupt();
537     rewriter.replaceOp(allocaOp, *alloc);
538     if (changed)
539       *changed = true;
540 
541     // Stop here if the buffer should not be deallocated.
542     if (skipDealloc)
543       return WalkResult::advance();
544 
545     // Create dealloc.
546     rewriter.setInsertionPoint(block->getTerminator());
547     if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc)))
548       return WalkResult::interrupt();
549 
550     return WalkResult::advance();
551   });
552 
553   return success(!status.wasInterrupted());
554 }
555 
556 //===----------------------------------------------------------------------===//
// Bufferization-specific helpers: function arguments and memref type
// conversion.
558 //===----------------------------------------------------------------------===//
559 
560 bool bufferization::isFunctionArgument(Value value) {
561   auto bbArg = value.dyn_cast<BlockArgument>();
562   if (!bbArg)
563     return false;
564   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
565 }
566 
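/// Compute the memref type that the given tensor type bufferizes to. An
/// explicitly provided `layout` takes precedence; otherwise the options'
/// `unknownTypeConversion` setting selects between a fully dynamic layout map
/// and a static identity layout. Illustrative example: with the identity
/// layout option, tensor<4xf32> maps to memref<4xf32>.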
567 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
568                                             const BufferizationOptions &options,
569                                             MemRefLayoutAttrInterface layout,
570                                             Attribute memorySpace) {
571   // Case 1: Unranked memref type.
572   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
573     assert(!layout && "UnrankedTensorType cannot have a layout map");
574     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
575                                    memorySpace);
576   }
577 
578   // Case 2: Ranked memref type with specified layout.
579   auto rankedTensorType = tensorType.cast<RankedTensorType>();
580   if (layout) {
581     return MemRefType::get(rankedTensorType.getShape(),
582                            rankedTensorType.getElementType(), layout,
583                            memorySpace);
584   }
585 
586   // Case 3: Configured with "fully dynamic layout maps".
587   if (options.unknownTypeConversion ==
588       BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
589     return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
590 
591   // Case 4: Configured with "static identity layout maps".
592   if (options.unknownTypeConversion ==
593       BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
594     return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
595 
596   llvm_unreachable("InferLayoutMap is an invalid option");
597 }
598 
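/// Return a MemRef type with a fully dynamic layout map: the offset and all
/// strides are dynamic. If the given tensor type is unranked, return an
/// unranked MemRef type. Illustrative example: tensor<4x?xf32> maps to a
/// memref<4x?xf32> whose layout is a strided affine map with a dynamic offset
/// and dynamic strides.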
599 BaseMemRefType
600 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
601                                                    Attribute memorySpace) {
602   // Case 1: Unranked memref type.
603   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
604     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
605                                    memorySpace);
606   }
607 
608   // Case 2: Ranked memref type.
609   auto rankedTensorType = tensorType.cast<RankedTensorType>();
610   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
611   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
612                                       ShapedType::kDynamicStrideOrOffset);
613   AffineMap stridedLayout = makeStridedLinearLayoutMap(
614       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
615   return MemRefType::get(rankedTensorType.getShape(),
616                          rankedTensorType.getElementType(), stridedLayout,
617                          memorySpace);
618 }
619 
620 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
621 /// the given tensor type is unranked, return an unranked MemRef type.
622 BaseMemRefType
623 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
624                                                      Attribute memorySpace) {
625   // Case 1: Unranked memref type.
626   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
627     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
628                                    memorySpace);
629   }
630 
631   // Case 2: Ranked memref type.
632   auto rankedTensorType = tensorType.cast<RankedTensorType>();
633   MemRefLayoutAttrInterface layout = {};
634   return MemRefType::get(rankedTensorType.getShape(),
635                          rankedTensorType.getElementType(), layout,
636                          memorySpace);
637 }
638