1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/Matchers.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 
20 //===----------------------------------------------------------------------===//
21 // Helper functions
22 //===----------------------------------------------------------------------===//
23 
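/// Try to cast the given ranked memref value to the given ranked memref type.
/// If the cast cannot be guaranteed to succeed at runtime, allocate a buffer
/// of the destination type and copy the data over instead. Fails if the
/// element type, rank or memory space of the two types differ.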
24 FailureOr<Value>
25 mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
26                                               MemRefType destType) {
27   auto srcType = value.getType().cast<MemRefType>();
28 
29   // Element type, rank and memory space must match.
30   if (srcType.getElementType() != destType.getElementType())
31     return failure();
32   if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
33     return failure();
34   if (srcType.getRank() != destType.getRank())
35     return failure();
36 
37   // In case the affine maps are different, we may need to use a copy if we go
38   // from dynamic to static offset or stride (the canonicalization cannot know
39   // at this point that it is really cast compatible).
40   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
41     int64_t sourceOffset, targetOffset;
42     SmallVector<int64_t, 4> sourceStrides, targetStrides;
43     if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
44         failed(getStridesAndOffset(target, targetStrides, targetOffset)))
45       return false;
46     auto dynamicToStatic = [](int64_t a, int64_t b) {
47       return a == MemRefType::getDynamicStrideOrOffset() &&
48              b != MemRefType::getDynamicStrideOrOffset();
49     };
50     if (dynamicToStatic(sourceOffset, targetOffset))
51       return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
53       if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
54         return false;
55     return true;
56   };
57 
  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
61   if (memref::CastOp::areCastCompatible(srcType, destType) &&
62       isGuaranteedCastCompatible(srcType, destType)) {
63     Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
64     return casted;
65   }
66 
67   auto loc = value.getLoc();
68   SmallVector<Value, 4> dynamicOperands;
69   for (int i = 0; i < destType.getRank(); ++i) {
70     if (destType.getShape()[i] != ShapedType::kDynamicSize)
71       continue;
72     auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
73     Value size = b.create<memref::DimOp>(loc, value, index);
74     dynamicOperands.push_back(size);
75   }
76   // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
77   // BufferizableOpInterface impl of ToMemrefOp.
78   Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
79   b.create<memref::CopyOp>(loc, value, copy);
80   return copy;
81 }
82 
83 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
84 /// to_memref op are different, a memref.cast is needed.
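///
/// Illustrative example (value names and types are made up for exposition):
///
///   %t = bufferization.to_tensor %x : memref<4xf32>
///   %m = bufferization.to_memref %t : memref<?xf32>
///
/// folds to a cast of `x` (or to an alloc + copy if the cast is not
/// guaranteed to succeed at runtime):
///
///   %m = memref.cast %x : memref<4xf32> to memref<?xf32>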
85 LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
86     RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
87   auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
88   if (!memrefToTensor)
89     return failure();
90 
91   Type srcType = memrefToTensor.memref().getType();
92   Type destType = toMemref.getType();
93 
94   // Directly rewrite if the type did not change.
95   if (srcType == destType) {
96     // Function can be configured to only handle cases where a cast is needed.
97     if (!allowSameType)
98       return failure();
99     rewriter.replaceOp(toMemref, memrefToTensor.memref());
100     return success();
101   }
102 
103   auto rankedSrcType = srcType.dyn_cast<MemRefType>();
104   auto rankedDestType = destType.dyn_cast<MemRefType>();
105   auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
106 
107   // Ranked memref -> Ranked memref cast.
108   if (rankedSrcType && rankedDestType) {
109     FailureOr<Value> replacement = castOrReallocMemRefValue(
110         rewriter, memrefToTensor.memref(), rankedDestType);
111     if (failed(replacement))
112       return failure();
113 
114     rewriter.replaceOp(toMemref, *replacement);
115     return success();
116   }
117 
118   // Unranked memref -> Ranked memref cast: May require a copy.
119   // TODO: Not implemented at the moment.
120   if (unrankedSrcType && rankedDestType)
121     return failure();
122 
123   // Unranked memref -> unranked memref cast
124   // Ranked memref -> unranked memref cast: No copy needed.
125   assert(memref::CastOp::areCastCompatible(srcType, destType) &&
126          "expected that types are cast compatible");
127   rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
128                                               memrefToTensor.memref());
129   return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // AllocTensorOp
134 //===----------------------------------------------------------------------===//
135 
136 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
137                                        BufferizationState &state) {
138   // Nothing to do for dead AllocTensorOps.
139   if (getOperation()->getUses().empty())
140     return success();
141 
142   FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
143   if (failed(alloc))
144     return failure();
145   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
146   return success();
147 }
148 
149 LogicalResult AllocTensorOp::verify() {
150   if (getType().getNumDynamicDims() !=
151       static_cast<int64_t>(dynamicSizes().size()))
152     return emitError("expected ")
153            << getType().getNumDynamicDims() << " dynamic sizes";
154   return success();
155 }
156 
157 namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined by a
/// `constant` op. For example:
162 ///
163 ///  %c5 = arith.constant 5: index
164 ///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
165 ///
166 ///  to
167 ///
168 ///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
169 struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
170   using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
171 
172   LogicalResult matchAndRewrite(AllocTensorOp op,
173                                 PatternRewriter &rewriter) const override {
174     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
175     SmallVector<Value> newDynamicSizes;
176     unsigned int dynValCounter = 0;
177     for (int64_t i = 0; i < op.getType().getRank(); ++i) {
178       if (!op.isDynamicDim(i))
179         continue;
180       Value value = op.dynamicSizes()[dynValCounter++];
181       APInt intVal;
182       if (matchPattern(value, m_ConstantInt(&intVal))) {
183         newShape[i] = intVal.getSExtValue();
184       } else {
185         newDynamicSizes.push_back(value);
186       }
187     }
188     RankedTensorType newType = RankedTensorType::get(
189         newShape, op.getType().getElementType(), op.getType().getEncoding());
190     if (newType == op.getType())
191       return failure();
192     auto newOp =
193         rewriter.create<AllocTensorOp>(op.getLoc(), newType, newDynamicSizes);
194     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
195     return success();
196   }
197 };
198 
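/// Fold `tensor.dim` of an `alloc_tensor` op with a constant index into the
/// corresponding dynamic size operand of the `alloc_tensor` op.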
199 struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
200   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
201 
202   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
203                                 PatternRewriter &rewriter) const override {
204     Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
205     auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
206     if (!allocTensorOp || !maybeConstantIndex)
207       return failure();
208     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
209       return failure();
210     rewriter.replaceOp(dimOp,
211                        allocTensorOp.getDynamicSize(*maybeConstantIndex));
212     return success();
213   }
214 };
215 } // namespace
216 
217 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
218                                                 MLIRContext *ctx) {
219   results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
220 }
221 
222 LogicalResult AllocTensorOp::reifyResultShapes(
223     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
224   auto shapes = llvm::to_vector<4>(llvm::map_range(
225       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
226         if (isDynamicDim(dim))
227           return getDynamicSize(dim);
228         return builder.create<arith::ConstantIndexOp>(getLoc(),
229                                                       getStaticSize(dim));
230       }));
231   reifiedReturnShapes.emplace_back(std::move(shapes));
232   return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // CloneOp
237 //===----------------------------------------------------------------------===//
238 
239 void CloneOp::getEffects(
240     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
241         &effects) {
242   effects.emplace_back(MemoryEffects::Read::get(), input(),
243                        SideEffects::DefaultResource::get());
244   effects.emplace_back(MemoryEffects::Write::get(), output(),
245                        SideEffects::DefaultResource::get());
246   effects.emplace_back(MemoryEffects::Allocate::get(), output(),
247                        SideEffects::DefaultResource::get());
248 }
249 
250 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
251   return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
252 }
253 
254 namespace {
255 
256 /// Merge the clone and its source (by converting the clone to a cast) when
257 /// possible.
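///
/// Illustrative example (value names and types are made up for exposition;
/// assumes the source buffer is not deallocated before the last use of the
/// clone):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ... // uses of %1, no other deallocations in between
///   memref.dealloc %1 : memref<?xf32>
///
/// is rewritten to a cast of the source (which folds away if the types
/// match), and the now redundant dealloc is erased:
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>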
258 struct SimplifyClones : public OpRewritePattern<CloneOp> {
259   using OpRewritePattern<CloneOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(CloneOp cloneOp,
262                                 PatternRewriter &rewriter) const override {
263     if (cloneOp.use_empty()) {
264       rewriter.eraseOp(cloneOp);
265       return success();
266     }
267 
268     Value source = cloneOp.input();
269 
270     // This only finds dealloc operations for the immediate value. It should
271     // also consider aliases. That would also make the safety check below
272     // redundant.
273     llvm::Optional<Operation *> maybeCloneDeallocOp =
274         memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
276     if (!maybeCloneDeallocOp.hasValue())
277       return failure();
278     llvm::Optional<Operation *> maybeSourceDeallocOp =
279         memref::findDealloc(source);
280     if (!maybeSourceDeallocOp.hasValue())
281       return failure();
282     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
283     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
284 
285     // If both are deallocated in the same block, their in-block lifetimes
286     // might not fully overlap, so we cannot decide which one to drop.
287     if (cloneDeallocOp && sourceDeallocOp &&
288         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
289       return failure();
290 
291     Block *currentBlock = cloneOp->getBlock();
292     Operation *redundantDealloc = nullptr;
293     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
294       redundantDealloc = cloneDeallocOp;
295     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
296       redundantDealloc = sourceDeallocOp;
297     }
298 
299     if (!redundantDealloc)
300       return failure();
301 
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
307     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
308          pos = pos->getNextNode()) {
309       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
310       if (!effectInterface)
311         continue;
312       if (effectInterface.hasEffect<MemoryEffects::Free>())
313         return failure();
314     }
315 
316     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
317                                                 source);
318     rewriter.eraseOp(redundantDealloc);
319     return success();
320   }
321 };
322 
323 } // namespace
324 
325 void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
326                                           MLIRContext *context) {
327   results.add<SimplifyClones>(context);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // ToTensorOp
332 //===----------------------------------------------------------------------===//
333 
334 OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
335   if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
338     if (toMemref->getBlock() == this->getOperation()->getBlock() &&
339         toMemref->getNextNode() == this->getOperation())
340       return toMemref.tensor();
341   return {};
342 }
343 
344 namespace {
345 
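/// Fold `tensor.dim` of a `to_tensor` op into a `memref.dim` on the
/// corresponding memref.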
346 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
347   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
348 
349   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
350                                 PatternRewriter &rewriter) const override {
351     auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
352     if (!memrefToTensorOp)
353       return failure();
354 
355     rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
356                                                dimOp.index());
357     return success();
358   }
359 };
360 
361 } // namespace
362 
363 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
364                                              MLIRContext *context) {
365   results.add<DimOfToTensorFolder>(context);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ToMemrefOp
370 //===----------------------------------------------------------------------===//
371 
372 OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
373   if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
374     if (memrefToTensor.memref().getType() == getType())
375       return memrefToTensor.memref();
376   return {};
377 }
378 
379 namespace {
380 
381 /// Replace tensor.cast + to_memref by to_memref + memref.cast.
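///
/// Illustrative example (value names and types are made up for exposition):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///
/// is rewritten to
///
///   %0 = bufferization.to_memref %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>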
382 struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
383   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
384 
385   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
386                                 PatternRewriter &rewriter) const final {
387     auto tensorCastOperand =
388         toMemref.getOperand().getDefiningOp<tensor::CastOp>();
389     if (!tensorCastOperand)
390       return failure();
391     auto srcTensorType =
392         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
393     if (!srcTensorType)
394       return failure();
395     auto memrefType = MemRefType::get(srcTensorType.getShape(),
396                                       srcTensorType.getElementType());
397     Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
398                                                tensorCastOperand.getOperand());
399     rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
400                                                 memref);
401     return success();
402   }
403 };
404 
/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when a type mismatch prevents `ToMemrefOp::fold` from kicking
/// in.
407 struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
408   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
409 
410   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
411                                 PatternRewriter &rewriter) const final {
412     // Only handle cases where a cast is needed. The other case is handled by
413     // the folder.
414     return foldToMemrefToTensorPair(rewriter, toMemref,
415                                     /*allowSameType=*/false);
416   }
417 };
418 
/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
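///
/// Illustrative example (value names and types are made up for exposition):
///
///   %0 = bufferization.to_memref %t : memref<4xf32>
///   %1 = memref.load %0[%i] : memref<4xf32>
///
/// is rewritten to
///
///   %1 = tensor.extract %t[%i] : tensor<4xf32>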
421 struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
422   using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
423 
424   LogicalResult matchAndRewrite(memref::LoadOp load,
425                                 PatternRewriter &rewriter) const override {
426     auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
427     if (!toMemref)
428       return failure();
429 
430     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
431                                                    load.indices());
432     return success();
433   }
434 };
435 
436 /// Fold dim of a to_memref into the dim of the tensor.
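///
/// Illustrative example (value names and types are made up for exposition):
///
///   %0 = bufferization.to_memref %t : memref<?xf32>
///   %1 = memref.dim %0, %c0 : memref<?xf32>
///
/// is rewritten to
///
///   %1 = tensor.dim %t, %c0 : tensor<?xf32>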
437 struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
438   using OpRewritePattern<memref::DimOp>::OpRewritePattern;
439 
440   LogicalResult matchAndRewrite(memref::DimOp dimOp,
441                                 PatternRewriter &rewriter) const override {
442     auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
443     if (!castOp)
444       return failure();
445     Value newSource = castOp.getOperand();
446     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
447     return success();
448   }
449 };
450 
451 } // namespace
452 
453 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
454                                              MLIRContext *context) {
455   results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
456       context);
457 }
458 
459 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
460                                     BufferizationState &state) {
461   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
462   (void)foldToMemrefToTensorPair(rewriter, *this);
463   // Note: The return value of `bufferize` indicates whether there was an error
464   // or not. (And not whether the pattern matched or not.)
465   return success();
466 }
467 
468 Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
469   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
470       .getOperation();
471 }
472 
473 Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
474   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // TableGen'd op method definitions
479 //===----------------------------------------------------------------------===//
480 
481 #define GET_OP_CLASSES
482 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
483