//===- BufferizableOpInterface.cpp - Bufferizable Ops  -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
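///
/// The dynamic sizes of the allocation are derived from the shaped value,
/// either by reifying the defining op's result shape or by creating dim ops.
/// The new op is tagged with the bufferization "escape" attribute so that
/// deallocation placement knows whether the buffer may outlive its block.
/// Illustrative example (assumed alloc_tensor syntax): for a
/// `tensor<?x10xf32>` value %t with `copy` set, this creates
/// `%0 = bufferization.alloc_tensor() copy(%t) : tensor<?x10xf32>`.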
Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                                  Value shapedValue,
                                                  bool escape, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));
  return allocTensorOp;
}

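/// For each out-of-place tensor OpOperand of this op, insert a
/// bufferization.alloc_tensor copy. Depending on the analysis, the copy is
/// made of the OpOperand (inserted before the op) or of the aliasing OpResult
/// (inserted after the op), and all uses are redirected to the copy.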
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can be
      // smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    Value copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand),
        copiedOpOperands.contains(opOperand));
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    Value copy =
        allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
                                     escapingOpResultCopies.contains(opResult),
                                     copiedOpResults.contains(opResult));
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy.getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow or disallow the op according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters match.)
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

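/// Set the builder's insertion point to the place right after the definition
/// of `value`: after its defining op, or at the beginning of the owning block
/// if `value` is a block argument.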
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Also return true
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Also return true
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
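//
// Illustrative example (hypothetical IR): starting from %2 below with a
// condition that only matches block arguments, the traversal follows the
// aliasing destination operands of the insert_slice ops and returns {%arg0}.
//   %1 = tensor.insert_slice %0 into %arg0 ...
//   %2 = tensor.insert_slice %t into %1 ...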
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
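// Note: Values with no defining op, or whose defining op is not a
// bufferizable (or allowed) op, are conservatively treated as writes and
// terminate the traversal.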
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

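/// Return the buffer (memref) value for the given tensor value. If the value
/// is the result of a to_tensor op, return that op's memref operand.
/// Otherwise, insert a to_memref op right after the definition of the value.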
Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
                               const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(value, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                    value);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
BaseMemRefType
bufferization::getBufferType(Value value,
                             const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, options);
}

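/// Replace an op with the given bufferized values, one per OpResult. For
/// tensor results, a to_tensor op is inserted so that remaining users still
/// see a tensor value; these to_tensor ops are expected to fold away as the
/// surrounding ops are bufferized.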
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

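/// Return true if the given value is an argument of a block that is directly
/// nested in a func::FuncOp (e.g., a function argument).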
bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

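/// Return the memref type to which the given tensor type can be bufferized.
/// If `layout` is specified, it is used for the ranked case; otherwise, the
/// layout is chosen according to `options.unknownTypeConversion`.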
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

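/// Return a MemRef type with a fully dynamic layout (dynamic strides and a
/// dynamic offset). If the given tensor type is unranked, return an unranked
/// MemRef type.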
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}