//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Return the owner of the given value.
static Operation *getOwnerOfValue(Value value) {
  if (auto opResult = value.dyn_cast<OpResult>())
    return opResult.getDefiningOp();
  return value.cast<BlockArgument>().getOwner()->getParentOp();
}

bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif // NDEBUG

  Operation *op = opResult.getDefiningOp();
  // If there is no 'escape' attribute, we cannot say for sure.
  if (!op->hasAttr(BufferizationDialect::kEscapeAttrName))
    return false;
  auto attr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
}
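
// Example (illustrative sketch, not part of the upstream file; assumes the
// escape attribute spells as `bufferization.escape`): an allocating op whose
// only result may not escape its block would look roughly like
//   %0 = bufferization.alloc_tensor() {bufferization.escape = [false]}
//       : tensor<4xf32>
// and allocationDoesNotEscape(%0) would return true.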

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue, bool escape,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (shapedValue.getType().isa<RankedTensorType>()) {
    tensor = shapedValue;
  } else if (shapedValue.getType().isa<MemRefType>()) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (shapedValue.getType().isa<RankedTensorType>() &&
        shapedValue.isa<OpResult>()) {
      if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
              shapedValue.getDefiningOp())) {
        ReifiedRankedShapedTypeDims resultDims;
        if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
          reifiedShapes = true;
          auto &shape =
              resultDims[shapedValue.cast<OpResult>().getResultNumber()];
          for (const auto &dim : enumerate(tensorType.getShape()))
            if (ShapedType::isDynamic(dim.value()))
              dynamicSizes.push_back(shape[dim.index()]);
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
                         b.getBoolArrayAttr({escape}));

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  allocTensorOp.setMemorySpaceAttr(
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/false),
                       copyBufferType->getMemorySpaceAsInt()));
  return allocTensorOp.getResult();
}
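
// Usage sketch (illustrative only; the builder, location and values are
// assumed to exist in the caller):
//   FailureOr<Value> tensorCopy = allocateTensorForShapedValue(
//       rewriter, loc, someTensor, /*escape=*/false, options, /*copy=*/true);
// With `copy` set, this materializes roughly
//   %0 = bufferization.alloc_tensor() copy(%some_tensor) ...
// so no dynamic sizes need to be computed for the allocation.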

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // Is the result yielded from a block? Or are deallocations turned off
    // entirely? In either case, mark the allocation as "escaping", so that it
    // will not be deallocated.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source).
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.count(opResult));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}
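
// Illustrative effect (sketch, not taken from a test): if the analysis marks
// a tensor OpOperand %t of "some.op" as out-of-place, the rewrite above
// inserts a copy before the op, roughly
//   %copy = bufferization.alloc_tensor() copy(%t) ...
//   "some.op"(%copy) ...
// In the single-alias, non-writing case (e.g. extract_slice), the copy is
// created on the aliasing OpResult after the op instead, and all other uses
// of that result are redirected to the copy.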

bool bufferization::shouldDeallocateOpResult(
    OpResult opResult, const BufferizationOptions &options) {
  Operation *op = opResult.getOwner();
  assert(options.dynCastBufferizableOp(op).bufferizesToAllocation(opResult) &&
         "expected that op allocates");

  AnalysisState analysisState(options);
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    return !escapeAttr[0].cast<BoolAttr>().getValue();
  }

  // No "escape" annotation found.
  if (options.createDeallocs) {
    // Perform an ad-hoc analysis.
    return !analysisState.isTensorYielded(opResult);
  }

  return false;
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}
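
// Usage sketch (illustrative; assumes the allow/deny helpers declared in the
// interface header, e.g. allowDialect and denyOperation):
//   OpFilter filter;
//   filter.allowDialect<tensor::TensorDialect>(); // installs an ALLOW rule
//   filter.denyOperation<tensor::ExtractOp>();    // DENY always wins
// Once any ALLOW rule exists, ops that match no rule are not allowed.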

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
                                             memorySpace);
}

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}
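
// Example (illustrative IR sketch): given
//   %slice = tensor.extract_slice %t ...  // alias-only, neither read nor write
//   %r = tensor.extract %slice[...]       // bufferizes to a memory read
// isValueRead(%t) returns true, because the read of %slice is reached by
// walking through the alias-only extract_slice.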

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}
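
// Usage sketch (illustrative; `state`, `v` and `region` are assumed to exist
// in the caller): collect all values in the reverse use-def chain of %v that
// are defined outside a given region:
//   SetVector<Value> roots = state.findValueInReverseUseDefChain(
//       v, [&](Value w) { return w.getParentRegion() != region; });
// Traversal also stops at BlockArguments and at values whose defining op is
// not bufferizable or not allowed by the options.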

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults,
                    [&](OpResult opResult) { return isValueRead(opResult); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}
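
// Example (illustrative): for `%f = linalg.fill ins(%c) outs(%t)`, the "outs"
// operand %t is overwritten in full and never read, so a copy of %t can be
// omitted. For an operand that is both read and written (e.g. the output of a
// reduction), none of the rules above fire and the copy stays.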

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

bool AnalysisState::isTensorYielded(Value tensor) const {
  // In the absence of analysis information, the conservative answer is "true".
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  // For AllocTensorOp results, we can do better: They do not alias with any
  // preceding value, so we can follow SSA use-def chains and do a simple
  // analysis.
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : tensor.getUses())
    worklist.push_back(&use);

  while (!worklist.empty()) {
    OpOperand *operand = worklist.pop_back_val();
    Operation *op = operand->getOwner();

    // If the op is not bufferizable, we can safely assume that the value is
    // not yielded. (When bufferizing that op, it must handle such cases.)
    if (!options.dynCastBufferizableOp(op))
      continue;

    // We cannot analyze through ToMemrefOps, so we have to conservatively
    // assume that the value is yielded.
    if (isa<ToMemrefOp>(op))
      return true;

    // Check if the op is returning/yielding.
    if (isRegionReturnLike(op))
      return true;

    // Add all aliasing OpResults to the worklist.
    // Note: In the absence of detailed analysis information (e.g., there may
    // be no function call analysis information), this `getAliasingOpResult`
    // is conservative and may report additional OpResults as potentially
    // aliasing.
    for (OpResult opResult : getAliasingOpResult(*operand))
      for (OpOperand &use : opResult.getUses())
        worklist.push_back(&use);
  }

  // No ReturnLike op found: The value is not yielded.
  return false;
}
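
// Example (illustrative IR sketch):
//   %t = bufferization.alloc_tensor() : tensor<4xf32>
//   %i = tensor.insert %x into %t[%c0] : tensor<4xf32>
//   scf.yield %i : tensor<4xf32>
// isTensorYielded(%t) returns true: the aliasing value %i reaches a
// region-terminating yield. If no use chain reaches a ReturnLike op, the
// function returns false.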

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}
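
// Usage sketch (illustrative; typically called from a bufferize()
// implementation with an existing rewriter and tensor operand):
//   FailureOr<Value> buffer = getBuffer(rewriter, tensorOperand, options);
//   if (failed(buffer))
//     return failure();
// If the tensor was produced by a to_tensor op, its memref is reused;
// otherwise a bufferization.to_memref op is inserted right after the value.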

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
  Operation *op = getOwnerOfValue(value);

  // ToTensorOp: Take buffer type directly from the op.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  // If value is a bbArg of a bufferizable op: query op interface.
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
      return bufferizableOp.getBufferType(bbArg, options);

  // Check if value is a new buffer allocation with a memory space attribute.
  // In that case we can at least infer the memory space.
  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    if (auto bufferizableOp =
            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
      if (bufferizableOp.bufferizesToAllocation(opResult)) {
        FailureOr<unsigned> queriedMemorySpace =
            bufferizableOp.getMemorySpace(opResult);
        if (!failed(queriedMemorySpace))
          memorySpace = *queriedMemorySpace;
      }
    }
  }

  // If we still do not know the memory space, use the default memory space
  // (if any).
  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;

  // If we still do not know the memory space, report a failure.
  if (!memorySpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}
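
// Usage sketch (illustrative; typically the last step of a bufferize()
// implementation):
//   replaceOpWithBufferizedValues(rewriter, op, {resultBuffer});
// Tensor results are wrapped in bufferization.to_tensor ops so that remaining
// tensor users keep valid IR until they are bufferized themselves.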

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpaceAttr);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpaceAttr);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}
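
// Example (illustrative): for a value of type tensor<4x?xf32> and no explicit
// layout, the result is whatever the options' unknownTypeConverterFn picks;
// with the default converter that is a memref<4x?xf32, ...> with a fully
// dynamic strided layout (see below).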

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpaceAttr);
}
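
// Example (illustrative): tensor<4x?xf32> maps to a memref<4x?xf32, #map>
// where #map is a strided layout with dynamic offset and strides, roughly
//   affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// so the memref can view an arbitrary strided slice of a larger buffer.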

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}
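
// Example (illustrative): tensor<4x?xf32> maps to plain memref<4x?xf32>
// (identity layout), i.e. the buffer is assumed to be contiguous and cannot
// represent a strided view.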