//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

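/// Try to cast the given ranked memref value to the given, equally ranked
/// memref type. If a plain `memref.cast` is not guaranteed to succeed at
/// runtime, allocate a buffer of the destination type and copy the data over
/// instead.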
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
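  // For example (illustrative layouts), casting from
  //   memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>   (dynamic offset)
  // to
  //   memref<4xf32>                                       (static offset 0)
  // is `areCastCompatible`, but the cast may fail at runtime if the actual
  // offset is not 0; in that case we must allocate and copy instead.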
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
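/// Example (illustrative types):
///
///   %0 = bufferization.to_tensor %m : memref<4xf32>
///   %1 = bufferization.to_memref %0 : memref<4xf32>
///
/// folds to a direct use of %m. If the two memref types differ, a memref.cast
/// (or, if the cast is not guaranteed to succeed, an allocation plus copy) is
/// inserted instead.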
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

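/// Populate `dynamicDims` with the dynamic dimension sizes of `shapedValue`,
/// which is expected to be a ranked memref or a ranked tensor.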
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
  unsigned memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  bool dealloc =
      shouldDeallocateOpResult(getResult().cast<OpResult>(), options);

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was defined by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

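// Custom assembly format handled by the parser/printer below, e.g.
// (illustrative examples; see the op's TableGen definition for the
// authoritative format):
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<?x?xf32>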
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
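/// For example (illustrative; assumes no conflicting deallocation in between):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %1 : memref<?xf32>
///
/// becomes a `memref.cast` of %0, and the now redundant dealloc is erased.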
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  if (failed(options.createDealloc(rewriter, getLoc(), *buffer)))
    return failure();
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
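/// Fold `tensor.dim` of a `to_tensor` op to `memref.dim` of the underlying
/// memref.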
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
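/// For example (illustrative types):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///
/// is rewritten to:
///
///   %0 = bufferization.to_memref %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>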
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
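/// For example (illustrative types):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes:
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>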
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
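/// For example (illustrative types):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes:
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>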
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"