//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
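/// A rough sketch of the IR this produces (illustrative only; the exact
/// attribute spelling follows BufferizationDialect::kEscapeAttrName):
///
///   // copy == false: allocate a tensor with matching (dynamic) shape.
///   %0 = bufferization.alloc_tensor(%d0) {bufferization.escape = [false]}
///       : tensor<?x4xf32>
///
///   // copy == true: allocate a tensor and copy the given shaped value.
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>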
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

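/// Conceptually, an out-of-place OpOperand is resolved by materializing a
/// tensor copy. A minimal sketch of the rewrite (illustrative IR; the value
/// names are made up):
///
///   %r = tensor.insert_slice %x into %t[0][5][1]
///       : tensor<5xf32> into tensor<10xf32>   // %t must not be overwritten
///
/// becomes, roughly:
///
///   %copy = bufferization.alloc_tensor() copy(%t) : tensor<10xf32>
///   %r = tensor.insert_slice %x into %copy[0][5][1]
///       : tensor<5xf32> into tensor<10xf32>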
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

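/// A hypothetical filter configuration (sketch only, assuming the allow/deny
/// helpers declared in the header):
///
///   OpFilter filter;
///   filter.allowDialect<tensor::TensorDialect>();  // ALLOW rule
///   filter.denyOperation<tensor::GenerateOp>();    // DENY rule
///
/// With at least one ALLOW rule present, an op is allowed only if some ALLOW
/// rule matches and no DENY rule matches; DENY always wins.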
bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow or disallow according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters match.)
        return false;
      break;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
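/// For example (illustrative IR): given
///   %s = tensor.extract_slice %t[0][4][1] : tensor<8xf32> to tensor<4xf32>
///   %e = tensor.extract %s[%c0] : tensor<4xf32>
/// the extract_slice only creates an alias, but the traversal below follows it
/// to the tensor.extract, so `isValueRead(%t)` returns true.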
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
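//
// Example (illustrative IR): starting from %2 in
//   %1 = tensor.insert %cst into %0[%i] : tensor<8xf32>
//   %2 = tensor.extract_slice %1[0][4][1] : tensor<8xf32> to tensor<4xf32>
// with a `condition` that matches ops that write to memory, the traversal
// walks %2 -> %1 and returns {%1}, because tensor.insert is such a write.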
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may be
  // aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is not
    // yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may be
    // no function call analysis information), this `getAliasingOpResult` is
    // conservative and may report additional OpResults as potentially aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

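/// Note: Buffers for tensor values are materialized lazily. A minimal sketch
/// of what this produces (illustrative): for a value `%t` that is not defined
/// by a `to_tensor` op, a cast to a buffer is inserted right after `%t`:
///   %m = bufferization.to_memref %t : memref<8xf32>
/// If `%t` is the result of `bufferization.to_tensor %m`, the existing `%m` is
/// reused instead of creating a new op.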
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if the value is a new buffer allocation with a memory space
  // attribute. In that case, we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space (if
  // any).
  if (!memorySpace.hasValue())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.hasValue())
    return op->emitError("could not infer memory space");

  return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
}

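/// Sketch of the rewrite performed below (illustrative IR): replacing a tensor
/// OpResult %r with a bufferized value %buf turns
///   %r = "some.op"() : () -> tensor<4xf32>
/// into uses of
///   %r2 = bufferization.to_tensor %buf : memref<4xf32>
/// so that the remaining tensor-typed uses stay valid until they are
/// bufferized themselves.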
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and will eventually be removed by DCE.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}
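// A hypothetical override of the default allocation, assuming `allocationFn`
// keeps the signature used above (sketch only):
//
//   BufferizationOptions options;
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned alignment) -> FailureOr<Value> {
//     // E.g., allocate on the stack instead of the heap.
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };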

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers for function arguments and memref types.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

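/// Sketch of the type conversion implemented below (illustrative): a
/// `tensor<4x?xf32>` maps to
///   * `memref<4x?xf32, #layout, 2>` if an explicit layout and memory space 2
///     are given,
///   * a memref with a fully dynamic strided layout map for
///     `FullyDynamicLayoutMap`, or
///   * `memref<4x?xf32>` (identity layout) for `IdentityLayoutMap`.
/// An unranked `tensor<*xf32>` always maps to `memref<*xf32>`.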
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}