//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

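/// Illustrative sketch (the attribute name is assumed to match
/// BufferizationDialect::kEscapeAttrName): an allocation op annotated as
///   %0 = bufferization.alloc_tensor() {bufferization.escape = [false]}
///       : tensor<4xf32>
/// is reported as non-escaping, i.e., its buffer may be deallocated within the
/// enclosing block.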
bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif // NDEBUG

  Operation *op = opResult.getDefiningOp();
  // If there is no 'escape' attribute, we cannot say for sure; conservatively
  // report that the allocation may escape.
  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
    return false;
  auto attr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
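///
/// Illustrative sketch of the generated IR (names are not significant): with
/// `copy` set and `%t : tensor<?xf32>`, roughly
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>
/// is created; without `copy`, the dynamic sizes are computed first, e.g.
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
/// The escape attribute is always set from `escape`; the memory space
/// attribute is only set when no `copy` operand is given.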
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
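      // Illustrative sketch: for
      //   %0 = tensor.extract_slice %t[0] [4] [1]
      //       : tensor<8xf32> to tensor<4xf32>
      // copying the (smaller) result %0 instead of the source %t avoids
      // copying data that the op does not use.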
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

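// Illustrative usage sketch (assuming the allow/deny registration helpers
// declared next to OpFilter in the interface header):
//   OpFilter filter;
//   filter.allowDialect<tensor::TensorDialect>();
//   filter.denyOperation<tensor::ExtractOp>();
// With at least one ALLOW rule present, ops are disallowed by default, and a
// matching DENY rule always wins.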
bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow/disallow the op according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches: the op is not allowed, even if an ALLOW filter
        // also matches.
        return false;
    }
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
                                             memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands alias with `result` if the op is bufferized in
/// place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults alias with `opOperand` if the op is bufferized in
/// place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return `true` if `opOperand` bufferizes to a memory read. Also return
/// `true` if the op is not bufferizable (conservative answer).
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return `true` if `opOperand` bufferizes to a memory write. Also return
/// `true` if the op is not bufferizable (conservative answer).
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return `true` if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return `false` if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
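// Illustrative sketch: with a chain %2 -> %1 -> %0 (each value aliasing an
// operand of the op that defines the next one), a call such as
//   findValueInReverseUseDefChain(%2, matchesV1)
// returns {%1} and does not traverse past %1. `matchesV1` is a hypothetical
// predicate that is true only for %1.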
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
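// Illustrative sketch (assuming tensor.insert bufferizes to a memory write and
// tensor.extract_slice only creates an alias): given
//   %1 = tensor.insert %f into %0[%idx] : tensor<8xf32>
//   %2 = tensor.extract_slice %1[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
// findLastPrecedingWrite(%2) returns {%1}.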
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an allocation and a copy are inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know whether the values
  // may alias. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is "false".
  return false;
}

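// Note on the default (analysis-free) behavior below: only the results of
// AllocTensorOp are traced through their SSA use-def chains; any other tensor
// is conservatively reported as yielded.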
bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is not
    // yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may be
    // no function call analysis information), this `getAliasingOpResult` is
    // conservative and may report additional OpResults as potentially aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

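/// Sketch of the behavior implemented below: if `value` is produced by a
/// to_tensor op, its memref operand is reused directly; otherwise a
/// bufferization.to_memref op is inserted, e.g.
///   %m = bufferization.to_memref %t : memref<4xf32>
/// where the memref type is computed by getBufferType.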
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if the value is a new buffer allocation with a memory space
  // attribute. In that case, we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space (if
  // any).
  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

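/// Sketch of the rewrite performed below (illustrative): a tensor-producing op
/// such as
///   %0 = "some.op"() : () -> tensor<4xf32>
/// is replaced with its bufferized value %m wrapped in a to_tensor op,
///   %0 = bufferization.to_tensor %m : memref<4xf32>
/// so that the remaining tensor users stay type-correct until they are
/// bufferized themselves.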
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually be DCE'd away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
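/// E.g. (illustrative), with the default implementation and a non-zero
/// `bufferAlignment`, this produces roughly
///   %0 = memref.alloc(%dyn) {alignment = 64 : i64} : memref<?xf32>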
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IR helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

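/// Return a MemRef type with a fully dynamic layout, i.e., a dynamic offset
/// and fully dynamic strides. E.g. (illustrative; the printed layout form
/// depends on the MLIR version), tensor<?x8xf32> maps to a memref<?x8xf32>
/// with a strided layout whose offset and strides are all dynamic.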
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
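/// E.g., with the default memory space, tensor<4x?xf32> maps to
/// memref<4x?xf32> (identity layout) and tensor<*xf32> maps to memref<*xf32>.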
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}