//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

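/// Illustrative example (SSA value names are hypothetical): casting
/// `memref<?xf32>` to `memref<4xf32>` only turns a dynamic size into a static
/// one, and both layouts have static strides and offset, so a single op is
/// emitted:
///
///   %0 = memref.cast %value : memref<?xf32> to memref<4xf32>
///
/// If the destination type requires a static offset or stride where the source
/// type only provides a dynamic one, a cast is not guaranteed to succeed at
/// runtime, so the helper emits a `memref.alloc` of the destination type plus
/// a `memref.copy` instead.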
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from a dynamic to a static offset or stride (the canonicalization cannot
  // know at this point that the cast is really compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If the types are cast compatible (`areCastCompatible`), the cast is
  // valid but may still fail at runtime. To ensure that we only generate casts
  // that always succeed at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
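/// Illustrative example (SSA value names are hypothetical): with matching
/// types, the pair folds away entirely, i.e.
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %0 = bufferization.to_memref %t : memref<?xf32>
///
/// is replaced by just `%m`. If the to_memref result type differs (e.g.
/// `memref<4xf32>`), the rewrite keeps `%m` and inserts a `memref.cast`, or an
/// alloc + copy, via `castOrReallocMemRefValue`.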
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast.
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
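/// Illustrative example (SSA value names are hypothetical): assuming the clone
/// result is deallocated in the same block and no other deallocating op sits
/// in between,
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ... (uses of %1; no other deallocations in between) ...
///   memref.dealloc %1 : memref<?xf32>
///
/// is rewritten to a `memref.cast` of %0, and the dealloc of %1 is erased.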
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either value has more than one dealloc operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this to fail only if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

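/// Illustrative example (SSA value names are hypothetical): the fold below
/// turns
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %0 = bufferization.to_tensor %m : memref<?xf32>
///
/// into just `%t`, but only when the to_memref op immediately precedes the
/// to_tensor op in the same block.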
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

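/// Fold `tensor.dim` of a `to_tensor` op into `memref.dim` on the source
/// memref. Illustrative example (SSA value names are hypothetical):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes `%d = memref.dim %m, %c0 : memref<?xf32>`.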
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

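/// Illustrative example (SSA value names are hypothetical): the fold below
/// turns
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %0 = bufferization.to_memref %t : memref<?xf32>
///
/// into just `%m`, provided the memref types match exactly; type mismatches
/// are left to `foldToMemrefToTensorPair`.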
OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
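/// Illustrative example (SSA value names are hypothetical):
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_memref %1 : memref<?xf32>
///
/// becomes
///
///   %m0 = bufferization.to_memref %0 : memref<4xf32>
///   %m = memref.cast %m0 : memref<4xf32> to memref<?xf32>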
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
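/// Illustrative example (SSA value names are hypothetical):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %0 = bufferization.to_memref %t : memref<?xf32>
///
/// becomes `%0 = memref.cast %m : memref<4xf32> to memref<?xf32>`.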
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
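/// Illustrative example (SSA value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes `%v = tensor.extract %t[%i] : tensor<?xf32>`.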
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
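/// Illustrative example (SSA value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes `%d = tensor.dim %t, %c0 : tensor<?xf32>`.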
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"