//===- BufferizableOpInterface.cpp - Bufferizable Ops  ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif // NDEBUG

  Operation *op = opResult.getDefiningOp();
  // If there is no 'escape' attribute, we cannot say for sure. Conservatively
  // assume that the allocation may escape.
  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
    return false;
  auto attr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
}
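
// Example for `allocationDoesNotEscape` (an illustrative sketch, assuming
// kEscapeAttrName prints as "bufferization.escape"): with the attribute below,
// the allocation is known not to escape and may be deallocated:
//
//   %0 = bufferization.alloc_tensor() {bufferization.escape = [false]}
//       : tensor<10xf32>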
63 
64 /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
65 /// shaped value is copied. Otherwise, a tensor with undefined contents is
66 /// allocated.
allocateTensorForShapedValue(OpBuilder & b,Location loc,Value shapedValue,bool escape,const BufferizationOptions & options,bool copy)67 FailureOr<Value> bufferization::allocateTensorForShapedValue(
68     OpBuilder &b, Location loc, Value shapedValue, bool escape,
69     const BufferizationOptions &options, bool copy) {
70   Value tensor;
71   if (shapedValue.getType().isa<RankedTensorType>()) {
72     tensor = shapedValue;
73   } else if (shapedValue.getType().isa<MemRefType>()) {
74     tensor = b.create<ToTensorOp>(loc, shapedValue);
75   } else {
76     llvm_unreachable("expected RankedTensorType or MemRefType");
77   }
78   RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
79   SmallVector<Value> dynamicSizes;
80   if (!copy) {
81     // Compute the dynamic part of the shape.
82     // First try to query the shape via ReifyRankedShapedTypeOpInterface.
83     bool reifiedShapes = false;
84     if (shapedValue.getType().isa<RankedTensorType>() &&
85         shapedValue.isa<OpResult>()) {
86       if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
87               shapedValue.getDefiningOp())) {
88         ReifiedRankedShapedTypeDims resultDims;
89         if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
90           reifiedShapes = true;
91           auto &shape =
92               resultDims[shapedValue.cast<OpResult>().getResultNumber()];
93           for (const auto &dim : enumerate(tensorType.getShape()))
94             if (ShapedType::isDynamic(dim.value()))
95               dynamicSizes.push_back(shape[dim.index()]);
96         }
97       }
98     }
99 
100     // If the shape could not be reified, create DimOps.
101     if (!reifiedShapes)
102       populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
103   }
104 
105   // Create AllocTensorOp.
106   auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
107                                                copy ? tensor : Value());
108   allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
109                          b.getBoolArrayAttr({escape}));
110 
111   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
112   if (copy)
113     return allocTensorOp.getResult();
114   FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
115   if (failed(copyBufferType))
116     return failure();
117   allocTensorOp.setMemorySpaceAttr(
118       b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
119                        copyBufferType->getMemorySpaceAsInt()));
120   return allocTensorOp.getResult();
121 }
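
// Example for `allocateTensorForShapedValue` (a sketch; the exact output
// depends on the options and the value's type): copying a tensor %t
// materializes as an alloc_tensor with a `copy` operand and the escape
// attribute set accordingly:
//
//   %alloc = bufferization.alloc_tensor() copy(%t)
//       {bufferization.escape = [false]} : tensor<10xf32>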

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}
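
// Example for `resolveTensorOpOperandConflicts` (an illustrative sketch): if
// the analysis marked the destination operand of a tensor.insert_slice as
// out-of-place, a copy of that operand is inserted and the operand is
// replaced:
//
//   %copy = bufferization.alloc_tensor() copy(%dest) : tensor<10xf32>
//   %0 = tensor.insert_slice %src into %copy[0] [5] [1]
//       : tensor<5xf32> into tensor<10xf32>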

bool bufferization::shouldDeallocateOpResult(
    OpResult opResult, const BufferizationOptions &options) {
  Operation *op = opResult.getOwner();
  assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) &&
         "expected that op allocates");

  AnalysisState analysisState(options);
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    return !escapeAttr[0].cast<BoolAttr>().getValue();
  }

  // No "escape" annotation found.
  if (options.createDeallocs) {
    // Perform an ad-hoc analysis.
    return !analysisState.isTensorYielded(opResult);
  }

  return false;
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // If no ALLOW rule was specified, ops are allowed by default. Otherwise,
  // an op is allowed only if an ALLOW rule matches it.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
      break;
    }
  }
  return isAllowed;
}
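
// Usage sketch for OpFilter (assuming the templated allowDialect /
// denyOperation helpers declared in BufferizableOpInterface.h):
//
//   OpFilter filter;
//   filter.allowDialect<tensor::TensorDialect>();  // ALLOW rule
//   filter.denyOperation<tensor::ExtractOp>();     // DENY rule
//   // DENY wins over ALLOW, so tensor.extract is rejected:
//   bool allowed = filter.isOpAllowed(op);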

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
                                             memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

/// Set the builder's insertion point right after the definition of `value`:
/// after the defining op, or at the beginning of the block in case of a
/// BlockArgument.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Also return `true`
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Also return `true`
/// if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}
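
// Example for `isValueRead` (illustrative): %t below is considered "read"
// even though extract_slice itself is alias-only, because the aliasing slice
// is read by tensor.extract:
//
//   %s = tensor.extract_slice %t[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
//   %e = tensor.extract %s[%c0] : tensor<5xf32>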

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Value(s) of the last preceding memory write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}
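
// Example for `findLastPrecedingWrite` (illustrative): starting from %2, the
// traversal skips the alias-only extract_slice and stops at %1, the result of
// the last memory write:
//
//   %0 = bufferization.alloc_tensor() : tensor<10xf32>
//   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<10xf32>)
//       -> tensor<10xf32>
//   %2 = tensor.extract_slice %1[0] [5] [1] : tensor<10xf32> to tensor<5xf32>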

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}
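
// Example for `isTensorYielded` (illustrative): the alloc_tensor result below
// is passed to a ReturnLike op (scf.yield), so it is considered yielded:
//
//   %0 = scf.if %cond -> (tensor<10xf32>) {
//     %1 = bufferization.alloc_tensor() : tensor<10xf32>
//     scf.yield %1 : tensor<10xf32>
//   } else {
//     scf.yield %t : tensor<10xf32>
//   }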

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert a to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}
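
// Example for `getBuffer` (illustrative): if the value is defined by a
// to_tensor op, its memref is reused; otherwise a to_memref op is created:
//
//   %t = bufferization.to_tensor %m : memref<10xf32>
//   // getBuffer(%t) returns %m directly, without creating new IR.
//
//   // For any other tensor value %u, roughly:
//   %b = bufferization.to_memref %u : memref<10xf32>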

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take the buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If the value is a bbArg of a bufferizable op: query the op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if the value is a new buffer allocation with a memory space
  // attribute. In that case, we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (succeeded(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space
  // (if any).
  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}
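
// Example for `replaceOpWithBufferizedValues` (illustrative): replacing a
// tensor result %r with a bufferized value %m keeps the existing uses typed
// by inserting a to_tensor op:
//
//   Before: %r = "some.op"() : () -> tensor<10xf32>
//   After:  all uses of %r are replaced with
//           %0 = bufferization.to_tensor %m : memref<10xf32>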

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific type and value helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with a specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  // Case 3: No layout specified. Let the configured type converter decide.
  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with a fully dynamic strided layout.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}
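
// Example for `getMemRefTypeWithFullyDynamicLayout` (illustrative; the exact
// affine-map pretty-printing may differ): tensor<4x?xf32> maps to a memref
// with dynamic offset and strides, e.g.
//
//   memref<4x?xf32,
//       affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>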

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}