//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // If no ALLOW rule was registered, all ops are allowed by default. Otherwise,
  // only ops that match at least one ALLOW rule (and no DENY rule) are allowed.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if an ALLOW
        // filter also matches.)
        return false;
      break;
    }
  }
  return isAllowed;
}
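
// A minimal usage sketch of the ALLOW/DENY semantics (hypothetical filter
// setup; assumes the allowDialect/denyOperation helpers declared in
// BufferizableOpInterface.h):
//
//   OpFilter filter;
//   filter.allowDialect<tensor::TensorDialect>(); // ALLOW entry
//   filter.denyOperation<tensor::ExtractOp>();    // DENY entry
//
// Because an ALLOW rule is present, only `tensor` dialect ops are allowed,
// except for tensor.extract, which is rejected by the DENY rule. DENY rules
// always take precedence over ALLOW rules.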

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the insertion point of `b` right after the definition of `value`: at
/// the beginning of the block in case of a BlockArgument, right after the
/// defining op otherwise.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // The op is not bufferizable (or not allowed), so the inplace analysis does
  // not support it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // The op is not bufferizable (or not allowed), so the inplace analysis does
  // not support it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // The op is not bufferizable (or not allowed), so the inplace analysis does
  // not support it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}
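
// Example (illustrative IR): %0 itself is only used by tensor.extract_slice,
// which creates an alias but does not read. The alias %1 is read by
// tensor.extract, so isValueRead(%0) returns true.
//
//   %1 = tensor.extract_slice %0[%a][%b][1] : tensor<?xf32> to tensor<?xf32>
//   %2 = tensor.extract %1[%c] : tensor<?xf32>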

/// Starting from `value`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}
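
// Example (illustrative IR): with a `condition` that matches only
// "writing_op", the traversal steps over the alias-only extract_slice ops, so
// findValueInReverseUseDefChain(%2, condition) returns {%0}:
//
//   %0 = "writing_op"() : () -> tensor<?xf32>
//   %1 = tensor.extract_slice %0[%a][%b][1] : tensor<?xf32> to tensor<?xf32>
//   %2 = tensor.extract_slice %1[%c][%d][1] : tensor<?xf32> to tensor<?xf32>
//
// Values that do not satisfy `condition` but are BlockArguments, have no
// aliasing OpOperands, or are defined by filtered-out ops also terminate the
// traversal and are included in the result.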

/// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}
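
// Continuing the example above: the last preceding write of %2 is %0,
// assuming "writing_op" bufferizes to a memory write. BlockArguments and
// values defined by non-bufferizable ops conservatively count as writes and
// terminate the search.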

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

/// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}
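
// Example (illustrative IR): a to_tensor operand is folded away; any other
// tensor gets a materialized to_memref op.
//
//   %t = bufferization.to_tensor %m : memref<?xf32>
//   // lookupBuffer(%t) returns %m directly.
//
//   // For any other tensor %u, the following op is inserted and returned:
//   %b = bufferization.to_memref %u : memref<?xf32>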

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}
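
// Example (illustrative IR): for a tensor.insert whose destination operand
// was decided to bufferize out-of-place, getBuffer allocates a fresh buffer
// and copies the original contents right before the op; dynamic sizes and
// deallocation are omitted here:
//
//   %m = bufferization.to_memref %t : memref<10xf32>
//   %alloc = memref.alloc() : memref<10xf32>
//   memref.copy %m, %alloc : memref<10xf32> to memref<10xf32>
//   memref.store %f, %alloc[%idx] : memref<10xf32>  // the bufferized insert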

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(Value value) const {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}
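
// Example (illustrative IR): with the default hooks, a buffer of type
// memref<?x10xf32> is created and destroyed as follows, where %d0 supplies
// the dynamic extent and the configured `bufferAlignment` becomes the
// alignment attribute (the alignment value shown is illustrative):
//
//   %alloc = memref.alloc(%d0) {alignment = 64 : i64} : memref<?x10xf32>
//   ...
//   memref.dealloc %alloc : memref<?x10xf32>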

static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return (by reference in `dynShape`) the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}
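
// Example (illustrative IR): for a value %v of type memref<?x8x?xf32> whose
// defining op cannot reify its result shape, the dynamic extents are
// recovered with memref.dim ops and the allocation type is the contiguous
// memref<?x8x?xf32>:
//
//   %d0 = memref.dim %v, %c0 : memref<?x8x?xf32>
//   %d2 = memref.dim %v, %c2 : memref<?x8x?xf32>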

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Create the buffer allocation.
  FailureOr<Value> buffer =
      getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
  if (failed(buffer))
    return failure();

  // Should the buffer be deallocated again or should we let it leak?
  if (dealloc) {
    if (!dealloc.getValue())
      return *buffer;
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    if (!getOptions().createDeallocs ||
        getAnalysisState().isTensorYielded(shapedValue))
      return *buffer;
  }

  // Create buffer deallocation.
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  if (failed(getOptions().createDealloc(b, loc, *buffer)))
    return failure();

  return *buffer;
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific value and type helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

/// Return a MemRef type with a fully dynamic strided layout. If the given
/// tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}
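
// Example: tensor<4x?xf32> is converted to a memref with fully dynamic
// strides and a dynamic offset, roughly:
//
//   memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2]
//                                  -> (d0 * s1 + s0 + d1 * s2)>>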

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
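
// Example: tensor<4x?xf32> is converted to the layout-free memref<4x?xf32>,
// and tensor<*xf32> to memref<*xf32>.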