//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/AsmState.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/IR/Value.h"
19 #include "llvm/Support/Debug.h"
20 
21 //===----------------------------------------------------------------------===//
22 // BufferizableOpInterface
23 //===----------------------------------------------------------------------===//
24 
25 namespace mlir {
26 namespace bufferization {
27 
28 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
29 
30 } // namespace bufferization
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "bufferizable-op-interface"
34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
36 
37 using namespace mlir;
38 using namespace bufferization;
39 
40 /// Attribute name used to mark region arguments that can be bufferized
41 /// in-place during linalg comprehensive bufferization.
42 constexpr const ::llvm::StringLiteral
43     bufferization::BufferizableOpInterface::kInplaceableAttrName;
44 
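/// For each tensor OpOperand of this op that was decided to bufferize
/// out-of-place, insert a `bufferization.alloc_tensor` that copies the operand
/// and use the copy as the operand instead. The copy is marked as escaping if
/// any aliasing OpResult is yielded from a block. Copies of unranked tensors
/// are not supported.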
45 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
46     RewriterBase &rewriter, const AnalysisState &state) {
47   Operation *op = getOperation();
48   for (OpOperand &opOperand : op->getOpOperands()) {
49     Type operandType = opOperand.get().getType();
50     if (!operandType.isa<TensorType>())
51       continue;
52     if (state.isInPlace(opOperand))
53       continue;
54     if (operandType.isa<UnrankedTensorType>())
55       return op->emitError("copies of unranked tensors are not supported");
56     auto tensorType = operandType.dyn_cast<RankedTensorType>();
57     if (!tensorType)
58       continue;
59     SmallVector<OpResult> aliasingOpResults =
60         state.getAliasingOpResult(opOperand);
61     bool escape = llvm::any_of(
62         aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
63     Value copy = rewriter.create<AllocTensorOp>(
64         op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
65     rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
66   }
67   return success();
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // OpFilter
72 //===----------------------------------------------------------------------===//
73 
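/// Return whether `op` passes this filter: an op is allowed if it matches at
/// least one ALLOW rule (or if no ALLOW rule exists at all) and matches no
/// DENY rule. For example, a filter with an ALLOW rule for the `tensor`
/// dialect and a DENY rule for `tensor.extract` would allow all `tensor` ops
/// except `tensor.extract`.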
74 bool OpFilter::isOpAllowed(Operation *op) const {
  // If there is no ALLOW rule, ops are allowed by default. Otherwise, an op
  // must match at least one ALLOW rule and no DENY rule to be allowed.
76   bool isAllowed = !hasAllowRule();
77   for (const Entry &entry : entries) {
78     bool filterResult = entry.fn(op);
79     switch (entry.type) {
80     case Entry::ALLOW:
81       isAllowed |= filterResult;
82       break;
83     case Entry::DENY:
84       if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
86         // filters may match.)
87         return false;
    }
89   }
90   return isAllowed;
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // BufferizationOptions
95 //===----------------------------------------------------------------------===//
96 
97 // Default constructor for BufferizationOptions.
98 BufferizationOptions::BufferizationOptions() = default;
99 
100 bool BufferizationOptions::isOpAllowed(Operation *op) const {
101   // Special case: If function boundary bufferization is deactivated, do not
102   // allow ops that belong to the `func` dialect.
103   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
104   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
105     return false;
106 
107   return opFilter.isOpAllowed(op);
108 }
109 
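/// Try to cast `op` to BufferizableOpInterface. Return a null interface if the
/// op does not implement the interface or is not allowed by these options.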
110 BufferizableOpInterface
111 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
112   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
113   if (!bufferizableOp)
114     return nullptr;
115   if (!isOpAllowed(op))
116     return nullptr;
117   return bufferizableOp;
118 }
119 
120 BufferizableOpInterface
121 BufferizationOptions::dynCastBufferizableOp(Value value) const {
122   if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
123     if (isOpAllowed(bufferizableOp.getOperation()))
124       return bufferizableOp;
125   return nullptr;
126 }
127 
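/// Register an initializer: every AnalysisState created with these options is
/// seeded with a piece of dialect-specific state produced by `fn`, keyed by
/// the dialect `name`.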
128 void BufferizationOptions::addDialectStateInitializer(
129     StringRef name, const DialectStateInitFn &fn) {
130   stateInitializers.push_back(
131       [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Helper functions for BufferizableOpInterface
136 //===----------------------------------------------------------------------===//
137 
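/// Set the insertion point of `b` right after the definition of `value`: after
/// the defining op, or at the beginning of the owning block if `value` is a
/// BlockArgument.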
138 static void setInsertionPointAfter(OpBuilder &b, Value value) {
139   if (auto bbArg = value.dyn_cast<BlockArgument>()) {
140     b.setInsertionPointToStart(bbArg.getOwner());
141   } else {
142     b.setInsertionPointAfter(value.getDefiningOp());
143   }
144 }
145 
/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
148 SmallVector<OpOperand *>
149 AnalysisState::getAliasingOpOperand(OpResult result) const {
150   if (Operation *op = result.getDefiningOp())
151     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
152       return bufferizableOp.getAliasingOpOperand(result, *this);
153   return {};
154 }
155 
/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
158 SmallVector<OpResult>
159 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
160   if (auto bufferizableOp =
161           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
162     return bufferizableOp.getAliasingOpResult(opOperand, *this);
163   return {};
164 }
165 
166 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
167 /// op is not bufferizable.
168 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
169   if (auto bufferizableOp =
170           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
171     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
172 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return true.
175   return true;
176 }
177 
178 /// Return true if `opOperand` bufferizes to a memory write. Return
179 /// `true` if the op is not bufferizable.
180 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
181   if (auto bufferizableOp =
182           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
183     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
184 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return true.
187   return true;
188 }
189 
/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
192 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
193   if (auto bufferizableOp =
194           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
195     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
196 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return false.
199   return false;
200 }
201 
202 /// Return true if the given value is read by an op that bufferizes to a memory
203 /// read. Also takes into account ops that create an alias but do not read by
204 /// themselves (e.g., ExtractSliceOp).
205 bool AnalysisState::isValueRead(Value value) const {
206   assert(value.getType().isa<TensorType>() && "expected TensorType");
207   SmallVector<OpOperand *> workingSet;
208   for (OpOperand &use : value.getUses())
209     workingSet.push_back(&use);
210 
211   while (!workingSet.empty()) {
212     OpOperand *uMaybeReading = workingSet.pop_back_val();
213     // Skip over all ops that neither read nor write (but create an alias).
214     if (bufferizesToAliasOnly(*uMaybeReading))
215       for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
216         for (OpOperand &use : opResult.getUses())
217           workingSet.push_back(&use);
218     if (bufferizesToMemoryRead(*uMaybeReading))
219       return true;
220   }
221 
222   return false;
223 }
224 
225 // Starting from `value`, follow the use-def chain in reverse, always selecting
226 // the aliasing OpOperands. Find and return Values for which `condition`
227 // evaluates to true. OpOperands of such matching Values are not traversed any
228 // further.
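//
// Values that are BlockArguments or that are defined by an op that is not
// bufferizable (or not allowed) also terminate the traversal and are included
// in the result. For example, in a chain %a -> %b -> %c where only the op
// defining %a satisfies `condition` and each op has exactly one aliasing
// OpOperand, starting from %c returns { %a }.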
229 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
230     Value value, llvm::function_ref<bool(Value)> condition) const {
231   llvm::SetVector<Value> result, workingSet;
232   workingSet.insert(value);
233 
234   while (!workingSet.empty()) {
235     Value value = workingSet.pop_back_val();
236     if (condition(value) || value.isa<BlockArgument>()) {
237       result.insert(value);
238       continue;
239     }
240 
241     OpResult opResult = value.cast<OpResult>();
242     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
243     if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
244       result.insert(value);
245       continue;
246     }
247 
248     for (OpOperand *o : opOperands)
249       workingSet.insert(o->get());
250   }
251 
252   return result;
253 }
254 
255 // Find the Values of the last preceding write of a given Value.
256 llvm::SetVector<Value>
257 AnalysisState::findLastPrecedingWrite(Value value) const {
258   return findValueInReverseUseDefChain(value, [&](Value value) {
259     Operation *op = value.getDefiningOp();
260     if (!op)
261       return true;
262     auto bufferizableOp = options.dynCastBufferizableOp(op);
263     if (!bufferizableOp)
264       return true;
265     return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
266   });
267 }
268 
269 AnalysisState::AnalysisState(const BufferizationOptions &options)
270     : options(options) {
271   for (const BufferizationOptions::AnalysisStateInitFn &fn :
272        options.stateInitializers)
273     fn(*this);
274 }
275 
276 // bufferization.to_memref is not allowed to change the rank.
277 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
278 #ifndef NDEBUG
279   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
280   assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
281                                    rankedTensorType.getRank()) &&
282          "to_memref would be invalid: mismatching ranks");
283 #endif
284 }
285 
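/// Return the buffer (memref) that holds the contents of the given tensor
/// value. If the tensor is produced by a `bufferization.to_tensor` op, return
/// that op's memref operand; otherwise, insert a `bufferization.to_memref` op
/// right after the definition of the tensor.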
286 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
287                                         const BufferizationOptions &options) {
288   auto tensorType = tensor.getType().dyn_cast<TensorType>();
289   assert(tensorType && "unexpected non-tensor type");
290 
291   // Replace "%t = to_tensor %m" with %m.
292   if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
293     return toTensorOp.memref();
294 
295   // Insert to_memref op.
296   OpBuilder::InsertionGuard g(rewriter);
297   setInsertionPointAfter(rewriter, tensor);
298   Type memrefType = getMemRefType(tensorType, options);
299   ensureToMemrefOpIsValid(tensor, memrefType);
300   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
301                                                     tensor);
302 }
303 
304 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
305 /// a new buffer and copy over data from the existing buffer if out-of-place
306 /// bufferization was decided.
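/// If `overrideInPlace` is given, it takes precedence over the analysis when
/// deciding whether a copy is needed. If a copy is inserted, it is placed at
/// `customCopyInsertionPoint` if provided, and right before the op that is
/// being bufferized otherwise.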
307 FailureOr<Value>
308 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
309                               Optional<ForceInPlacability> overrideInPlace,
310                               Optional<Operation *> customCopyInsertionPoint) {
311   const BufferizationOptions &options = analysisState.getOptions();
312   OpBuilder::InsertionGuard guard(rewriter);
313   Operation *op = opOperand.getOwner();
314   Location loc = op->getLoc();
315   SmallVector<OpResult> aliasingOpResults =
316       analysisState.getAliasingOpResult(opOperand);
317   Value operand = opOperand.get();
318   Value operandBuffer = lookupBuffer(rewriter, operand, options);
319 
320   // Can `operandBuffer` be used directly or do we need a copy?
321   bool inplace =
322       overrideInPlace != FORCE_OUT_OF_PLACE &&
323       (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
324   if (inplace)
325     return operandBuffer;
326 
327   // Bufferizing out-of-place: Allocate a new buffer.
328   // Move insertion point right after `operandBuffer`. That is where the
329   // allocation should be inserted (in the absence of allocation hoisting).
330   setInsertionPointAfter(rewriter, operandBuffer);
331   // Allocate the result buffer. The buffer should be deallocated if the tensor
332   // is not yielded and deallocs are enabled in general.
333   bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
334     return getAnalysisState().isTensorYielded(v);
335   });
336   FailureOr<Value> resultBuffer = createAlloc(
337       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
338   if (failed(resultBuffer))
339     return failure();
340   // Do not copy the buffer if its contents are undefined.
341   if (analysisState.hasUndefinedContents(&opOperand))
342     return resultBuffer;
343   // Do not copy if the copied data is never read.
344   if (!aliasingOpResults.empty() &&
345       !analysisState.bufferizesToMemoryRead(opOperand) &&
346       llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
347         return analysisState.isValueRead(opResult);
348       }))
349     return resultBuffer;
350   // Do not copy if this op does not read the data, but writes it.
351   if (analysisState.bufferizesToMemoryWrite(opOperand) &&
352       !analysisState.bufferizesToMemoryRead(opOperand))
353     return resultBuffer;
354 
355   if (customCopyInsertionPoint) {
356     rewriter.setInsertionPoint(*customCopyInsertionPoint);
357   } else {
358     // The copy happens right before the op that is bufferized.
359     rewriter.setInsertionPoint(op);
360   }
361   if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
362     return failure();
363 
364   return resultBuffer;
365 }
366 
367 /// Return the buffer type for a given Value (tensor) after bufferization.
368 BaseMemRefType BufferizationState::getBufferType(Value value) const {
369   auto tensorType = value.getType().dyn_cast<TensorType>();
370   assert(tensorType && "unexpected non-tensor type");
371 
372   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
373     return toTensorOp.memref().getType().cast<BaseMemRefType>();
374 
375   return getMemRefType(tensorType, getOptions());
376 }
377 
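/// Replace `op` with the given bufferized values, one per OpResult. For every
/// tensor result, a `bufferization.to_tensor` op is inserted so that the
/// remaining tensor uses stay valid; these to_tensor ops gradually lose their
/// users and are removed as the rest of the program is bufferized.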
378 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
379                                                   Operation *op,
380                                                   ValueRange values) {
381   assert(values.size() == op->getNumResults() &&
382          "expected one value per OpResult");
383   OpBuilder::InsertionGuard g(rewriter);
384 
385   // Replace all OpResults with the given values.
386   SmallVector<Value> replacements;
387   for (OpResult opResult : op->getOpResults()) {
388     Value replacement = values[opResult.getResultNumber()];
389     if (opResult.getType().isa<TensorType>()) {
390       // The OpResult is a tensor. Such values are replaced with memrefs during
391       // bufferization.
392       assert((replacement.getType().isa<MemRefType>() ||
393               replacement.getType().isa<UnrankedMemRefType>()) &&
394              "tensor op result should be replaced with a memref value");
395       // The existing uses of the OpResult still expect a tensor. Insert a
396       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually be removed by DCE.
398       rewriter.setInsertionPointAfter(op);
399       replacement = rewriter.create<bufferization::ToTensorOp>(
400           replacement.getLoc(), replacement);
401     }
402     replacements.push_back(replacement);
403   }
404 
405   rewriter.replaceOp(op, replacements);
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // Bufferization-specific scoped alloc/dealloc insertion support.
410 //===----------------------------------------------------------------------===//
411 
412 /// Create a memref allocation with the given type and dynamic extents.
413 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
414                                                    MemRefType type,
415                                                    ValueRange dynShape) const {
416   if (allocationFn)
417     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
418 
  // Default buffer allocation via AllocOp.
420   Value allocated = b.create<memref::AllocOp>(
421       loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
422   return allocated;
423 }
424 
425 /// Creates a memref deallocation. The given memref buffer must have been
426 /// allocated using `createAlloc`.
427 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
428                                                   Value allocatedBuffer) const {
429   if (deallocationFn)
430     return (*deallocationFn)(b, loc, allocatedBuffer);
431 
432   // Default buffer deallocation via DeallocOp.
433   b.create<memref::DeallocOp>(loc, allocatedBuffer);
434   return success();
435 }
436 
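/// Return a contiguous MemRefType (i.e., with an identity layout) that has the
/// same shape and element type as `shapedType`.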
437 static MemRefType getContiguousMemRefType(ShapedType shapedType,
438                                           Attribute memorySpace = {}) {
439   MemRefLayoutAttrInterface layout = {};
440   return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
441                          layout, memorySpace);
442 }
443 
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions of the returned `memref` type.
447 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
448                                             Value shapedValue,
449                                             SmallVectorImpl<Value> &dynShape) {
450   MemRefType allocMemRefType =
451       getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
452 
453   // Compute the dynamic part of the shape.
454   bool reifiedShapes = false;
455   if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
456           shapedValue.getDefiningOp())) {
457     ReifiedRankedShapedTypeDims resultDims;
458     if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
459       reifiedShapes = true;
460       OpResult resultValue = shapedValue.dyn_cast<OpResult>();
461       auto &shape = resultDims[resultValue.getResultNumber()];
462       for (const auto &dim : enumerate(allocMemRefType.getShape()))
463         if (ShapedType::isDynamic(dim.value()))
464           dynShape.push_back(shape[dim.index()]);
465     }
466   }
467 
468   if (!reifiedShapes) {
469     for (const auto &dim : enumerate(allocMemRefType.getShape()))
470       if (ShapedType::isDynamic(dim.value())) {
471         assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
472                 shapedValue.getType().isa<MemRefType>()) &&
473                "expected MemRef type");
474         dynShape.push_back(
475             b.create<memref::DimOp>(loc, shapedValue, dim.index()));
476       }
477   }
478 
479   return allocMemRefType;
480 }
481 
482 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
483 /// block in case of a bbArg).
484 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
485                                                  Value shapedValue,
486                                                  Optional<bool> dealloc) {
487   // Take a guard before anything else.
488   OpBuilder::InsertionGuard g(b);
489 
490   // Compute allocation memref type.
491   assert(shapedValue.getType().isa<ShapedType>());
492   SmallVector<Value> dynShape;
493   MemRefType allocMemRefType =
494       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
495 
496   // Create the buffer allocation.
497   FailureOr<Value> buffer =
498       getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
499   if (failed(buffer))
500     return failure();
501 
  // Should the buffer be deallocated again or should we let it leak?
503   if (dealloc) {
504     if (!dealloc.getValue())
505       return *buffer;
506   } else {
507     assert(shapedValue.getType().isa<TensorType>() &&
508            "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
511     if (!getOptions().createDeallocs ||
512         getAnalysisState().isTensorYielded(shapedValue))
513       return *buffer;
514   }
515 
516   // Create buffer deallocation.
517   b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
518   if (failed(getOptions().createDealloc(b, loc, *buffer)))
519     return failure();
520 
521   return *buffer;
522 }
523 
524 /// Create a memory copy between two memref buffers.
525 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
526                                                  Value from, Value to) const {
527   if (memCpyFn)
528     return (*memCpyFn)(b, loc, from, to);
529 
530   b.create<memref::CopyOp>(loc, from, to);
531   return success();
532 }
533 
//===----------------------------------------------------------------------===//
// Bufferization-specific type conversion and other IR helpers.
//===----------------------------------------------------------------------===//
537 
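/// Return true if the given value is a BlockArgument of a block that belongs
/// directly to a `func::FuncOp`.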
538 bool bufferization::isFunctionArgument(Value value) {
539   auto bbArg = value.dyn_cast<BlockArgument>();
540   if (!bbArg)
541     return false;
542   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
543 }
544 
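/// Return a memref type to which the given tensor type can be bufferized. If
/// `layout` is specified, it is used for ranked tensors; otherwise the layout
/// is chosen according to `options.unknownTypeConversion` (fully dynamic or
/// static identity layout). Unranked tensors always map to unranked memrefs.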
545 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
546                                             const BufferizationOptions &options,
547                                             MemRefLayoutAttrInterface layout,
548                                             Attribute memorySpace) {
549   // Case 1: Unranked memref type.
550   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
551     assert(!layout && "UnrankedTensorType cannot have a layout map");
552     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
553                                    memorySpace);
554   }
555 
556   // Case 2: Ranked memref type with specified layout.
557   auto rankedTensorType = tensorType.cast<RankedTensorType>();
558   if (layout) {
559     return MemRefType::get(rankedTensorType.getShape(),
560                            rankedTensorType.getElementType(), layout,
561                            memorySpace);
562   }
563 
564   // Case 3: Configured with "fully dynamic layout maps".
565   if (options.unknownTypeConversion ==
566       BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
567     return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
568 
569   // Case 4: Configured with "static identity layout maps".
570   if (options.unknownTypeConversion ==
571       BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
572     return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
573 
574   llvm_unreachable("InferLayoutMap is an invalid option");
575 }
576 
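/// Return a MemRef type with a fully dynamic layout map (dynamic strides and
/// offset). If the given tensor type is unranked, return an unranked MemRef
/// type.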
577 BaseMemRefType
578 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
579                                                    Attribute memorySpace) {
580   // Case 1: Unranked memref type.
581   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
582     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
583                                    memorySpace);
584   }
585 
586   // Case 2: Ranked memref type.
587   auto rankedTensorType = tensorType.cast<RankedTensorType>();
588   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
589   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
590                                       ShapedType::kDynamicStrideOrOffset);
591   AffineMap stridedLayout = makeStridedLinearLayoutMap(
592       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
593   return MemRefType::get(rankedTensorType.getShape(),
594                          rankedTensorType.getElementType(), stridedLayout,
595                          memorySpace);
596 }
597 
598 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
599 /// the given tensor type is unranked, return an unranked MemRef type.
600 BaseMemRefType
601 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
602                                                      Attribute memorySpace) {
603   // Case 1: Unranked memref type.
604   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
605     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
606                                    memorySpace);
607   }
608 
609   // Case 2: Ranked memref type.
610   auto rankedTensorType = tensorType.cast<RankedTensorType>();
611   MemRefLayoutAttrInterface layout = {};
612   return MemRefType::get(rankedTensorType.getShape(),
613                          rankedTensorType.getElementType(), layout,
614                          memorySpace);
615 }
616