//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

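// Illustrative sketch of what `castOrReallocMemRefValue` builds (value names
// and the `#strided_map` layout placeholder are made up for illustration, not
// taken from a test):
//
//   // Casting towards a *more* dynamic layout is guaranteed to succeed at
//   // runtime, so a single op suffices:
//   %0 = memref.cast %src : memref<?xf32> to memref<?xf32, #strided_map>
//
//   // Casting a dynamic offset/stride to a static one cannot be verified
//   // statically, so a new buffer is allocated and the data is copied:
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %src, %c0 : memref<?xf32, #strided_map>
//   %alloc = memref.alloc(%d) : memref<?xf32>
//   memref.copy %src, %alloc : memref<?xf32, #strided_map> to memref<?xf32>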
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };
  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
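///
/// Illustrative example (op syntax sketched, not copied from a test):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<?xf32>
///
/// is rewritten to
///
///   %r = memref.cast %m : memref<4xf32> to memref<?xf32>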
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast.
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

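// Illustrative sketch of AllocTensorOp bufferization (the exact allocation op
// depends on the allocation callback configured in the BufferizationOptions;
// memref.alloc is only the default):
//
//   %0 = bufferization.alloc_tensor [%d] : tensor<?xf32>
//
// bufferizes to roughly
//
//   %m = memref.alloc(%d) : memref<?xf32>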
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       BufferizationState &state) {
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty())
    return success();

  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
  if (failed(alloc))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}

void AllocTensorOp::build(OpBuilder &b, OperationState &result,
                          ArrayRef<OpFoldResult> sizes, Type elementType,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value, 4> dynamicSizes;
  SmallVector<int64_t, 4> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult AllocTensorOp::verify() {
  RankedTensorType resultType = getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(
          *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
          ShapedType::isDynamic)))
    return failure();

  if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return emitError("expected ") << resultType.getRank() << " sizes values";

  Type expectedType = AllocTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type AllocTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                    Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}

SmallVector<OpFoldResult> AllocTensorOp::getMixedSizes() {
  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(getType().getRank());
  unsigned dynamicValIndex = 0;
  for (Attribute attr : static_sizes()) {
    auto intAttr = attr.cast<IntegerAttr>();
    if (!ShapedType::isDynamic(intAttr.getInt())) {
      mixedSizes.push_back(intAttr);
      continue;
    }
    mixedSizes.push_back(sizes()[dynamicValIndex++]);
  }
  return mixedSizes;
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor [%arg0, %c5] : tensor<?x?xf32>
///
///  is rewritten to
///
///  %0 = bufferization.alloc_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> dynamicSizes;
    SmallVector<int64_t, 4> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<AllocTensorOp>(op.getLoc(), newType, dynamicSizes,
                                       rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

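/// Fold `tensor.dim` of a dynamic dimension of a `bufferization.alloc_tensor`
/// to the corresponding dynamic size operand. Illustrative example (sketched,
/// not copied from a test):
///
///   %0 = bufferization.alloc_tensor [%d] : tensor<?xf32>
///   %1 = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %1 to %d.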
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
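///
/// Illustrative example (sketched, not copied from a test): when the clone's
/// buffer is deallocated in the same block, the pair
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %1 : memref<?xf32>
///
/// can be rewritten to a cast of %0, with the now-redundant dealloc erased:
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>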
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Bail out if either the clone or the source has more than one
    // deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check: ensure that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of the source before the uses of the clone. With alias information, we
    // could restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

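/// Fold `tensor.dim` of a `bufferization.to_tensor` result into a `memref.dim`
/// on the underlying memref. Illustrative example (sketched, not copied from a
/// test):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// is rewritten to
///
///   %d = memref.dim %m, %c0 : memref<?xf32>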
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
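///
/// Illustrative example (sketched, not copied from a test):
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///   %2 = bufferization.to_memref %1 : memref<?xf32>
///
/// is rewritten to
///
///   %m = bufferization.to_memref %0 : memref<4xf32>
///   %2 = memref.cast %m : memref<4xf32> to memref<?xf32>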
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to a
/// memref.cast when a type mismatch prevents `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
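///
/// Illustrative example (sketched, not copied from a test):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// is rewritten to
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>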
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
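///
/// Illustrative example (sketched, not copied from a test):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// is rewritten to
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>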
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an
  // error, not whether the pattern matched.
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"