//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

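/// Try to cast the given ranked memref value to the given ranked memref type.
/// If the cast is not guaranteed to succeed at runtime, reallocate and copy
/// instead.
///
/// Illustrative example (value names and the `#src_layout` alias are
/// hypothetical). When the cast is guaranteed to be compatible, a single cast
/// is emitted:
///
///   %dest = memref.cast %src : memref<4xf32> to memref<?xf32>
///
/// Otherwise (e.g. a dynamic offset or stride in the source would have to
/// become static), a new buffer of the destination type is allocated and
/// filled:
///
///   %c0 = arith.constant 0 : index
///   %dim = memref.dim %src, %c0 : memref<?xf32, #src_layout>
///   %dest = memref.alloc(%dim) : memref<?xf32>
///   memref.copy %src, %dest : memref<?xf32, #src_layout> to memref<?xf32>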
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Casting to the same type, nothing to do.
  if (srcType == destType)
    return value;

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from a dynamic to a static offset or stride (the canonicalization cannot
  // know at this point that the cast is really compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid but may fail at
  // runtime. To ensure that we only generate casts that always succeed at
  // runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
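///
/// Illustrative example (value names are hypothetical): when the clone's
/// buffer is the one deallocated in the clone's block,
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %1 : memref<?xf32>
///
/// is rewritten into a memref.cast of %0, and the now-redundant dealloc is
/// erased.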
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

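/// Fold to_tensor(to_memref(x)) to x, but only when the to_memref op
/// immediately precedes this op in the same block (a conservative stand-in
/// for alias analysis).
///
/// Illustrative example (value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<4xf32>
///   %0 = bufferization.to_tensor %m : memref<4xf32>
///
/// folds %0 to %t.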
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

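/// Replace tensor.dim of a to_tensor op with memref.dim on the underlying
/// memref.
///
/// Illustrative example (value names are hypothetical):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes
///
///   %d = memref.dim %m, %c0 : memref<?xf32>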
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

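/// Fold to_memref(to_tensor(x)) to x when the types match exactly.
///
/// Illustrative example (value names are hypothetical):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %0 = bufferization.to_memref %t : memref<4xf32>
///
/// folds %0 to %m. Mismatching types are handled by the canonicalization
/// patterns below.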
OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
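///
/// Illustrative example (value names are hypothetical):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_memref %0 : memref<?xf32>
///
/// becomes
///
///   %0 = bufferization.to_memref %t : memref<4xf32>
///   %m = memref.cast %0 : memref<4xf32> to memref<?xf32>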
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
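///
/// Illustrative example (value names are hypothetical): with mismatching but
/// cast-compatible ranked types,
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %1 = bufferization.to_memref %t : memref<4xf32>
///
/// is rewritten so that %1 is produced directly from %m: a memref.cast when
/// the cast is guaranteed to succeed, otherwise an alloc + memref.copy (see
/// `castOrReallocMemRefValue`).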
static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref,
                                              bool allowSameType = true) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Function can be configured to only handle cases where a cast is needed.
  if (!allowSameType && srcType == destType)
    return failure();

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
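///
/// Illustrative example (value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<4xf32>
///   %v = memref.load %m[%i] : memref<4xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<4xf32>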
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold memref.dim of a to_memref op into tensor.dim on the source tensor.
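///
/// Illustrative example (value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>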
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"