//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated. If `escape` is set, the allocation is marked as escaping, i.e.,
/// it will not be deallocated by the bufferization.
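///
/// Rough sketch of the IR created when `copy` is set and `shapedValue` is a
/// ranked tensor (value names are hypothetical):
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?x10xf32>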
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

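/// Resolve all tensor OpOperands of this op that were determined to be
/// out-of-place: allocate a copy of the operand (or, in some cases, of the
/// aliasing OpResult) so that the op can bufferize in-place on that copy.
///
/// Rough sketch (value names hypothetical; assumes the destination operand
/// `%dest` was marked out-of-place):
///   %copy = bufferization.alloc_tensor() copy(%dest) : tensor<10xf32>
///   %r = tensor.insert_slice %src into %copy ...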
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.count(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

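/// Return whether the given op passes this filter: an op is allowed if no
/// DENY entry matches it and either there is no ALLOW entry at all or at
/// least one ALLOW entry matches.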
bool OpFilter::isOpAllowed(Operation *op) const {
  // If there is no ALLOW rule, ops are allowed by default.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
                                             memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
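///
/// E.g., for `%r = tensor.insert_slice %src into %dest ...`, the result `%r`
/// aliases the destination operand `%dest`.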
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that operates on tensors. The inplace analysis does not
  // support it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that operates on tensors. The inplace analysis does not
  // support it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that operates on tensors. The inplace analysis does not
  // support it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
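///
/// E.g., `tensor.extract_slice` by itself neither reads nor writes; but if its
/// result is subsequently read (e.g., by `tensor.extract`), then the sliced
/// value is considered read.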
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
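//
// Rough sketch (hypothetical IR): given
//   %a = "some.op"() ...              // condition(%a) == true
//   %b = tensor.extract_slice %a ...  // aliases %a
// findValueInReverseUseDefChain(%b, condition) returns {%a}.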
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

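/// Return true if a copy of the given tensor OpOperand can be omitted because
/// its current contents would never be read: the tensor has undefined
/// contents, it is entirely overwritten without being read, or neither the
/// operand nor any aliasing OpResult is read.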
bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult` is
    // conservative and may report additional OpResults as potentially aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

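/// Return the buffer (memref) for the given tensor value. If the value is the
/// result of a to_tensor op, simply return that op's memref operand.
/// Otherwise, insert a to_memref op, roughly:
///   %m = bufferization.to_memref %t : memref<10xf32>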
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if `value` is the result of a new buffer allocation with a memory
  // space attribute. In that case, we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space (if
  // any).
  if (!memorySpace.hasValue())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.hasValue())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

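/// Replace an op with the given bufferized values, one value per OpResult. For
/// tensor results, a bufferization.to_tensor op is inserted so that existing
/// users still receive a tensor-typed value; these to_tensor ops are expected
/// to fold away as the remaining ops are bufferized.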
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
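///
/// Rough sketch of the default lowering (assuming `bufferAlignment = 64` and a
/// hypothetical dynamic size `%d`):
///   %buf = memref.alloc(%d) {alignment = 64 : i64} : memref<?xf32>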
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers for memref types and function arguments.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

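/// Return a MemRef type with a fully dynamic layout (dynamic strides and
/// offset, expressed via an affine layout map). If the given tensor type is
/// unranked, return an unranked MemRef type.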
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}