//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
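///
/// Illustrative IR (a sketch, not taken from a test; assumes the clone source
/// %0 is deallocated in a different block, as required by the checks below):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ... uses of %1 ...
///   memref.dealloc %1 : memref<?xf32>
///
/// is rewritten to
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
///   ... uses of %1 ...
///
/// and the now-redundant dealloc of the clone result is erased.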
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

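/// Fold to_tensor(to_memref(x)) to x when the two ops are directly adjacent
/// in the same block. Illustrative IR (a sketch; names and types are
/// placeholders):
///
///   %m = bufferization.to_memref %t : memref<4xf32>
///   %0 = bufferization.to_tensor %m : memref<4xf32>
///
/// folds %0 to %t. Requiring adjacency is a conservative stand-in for alias
/// analysis: no interleaved operation can have modified the buffer.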
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

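/// Replace tensor.dim of a to_tensor with memref.dim of the underlying
/// memref. Illustrative IR (a sketch; names are placeholders):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes
///
///   %d = memref.dim %m, %c0 : memref<?xf32>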
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

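/// Fold to_memref(to_tensor(x)) to x when the memref types match exactly.
/// Illustrative IR (a sketch; names are placeholders):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<4xf32>
///
/// folds %r to %m. The mismatched-type case is handled by the
/// canonicalization patterns below, which may insert a memref.cast.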
OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
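///
/// Illustrative IR (a sketch; assumes the cast source is a ranked tensor):
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_memref %1 : memref<?xf32>
///
/// becomes
///
///   %m0 = bufferization.to_memref %0 : memref<4xf32>
///   %m = memref.cast %m0 : memref<4xf32> to memref<?xf32>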
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
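///
/// Illustrative IR for the cast case (a sketch; names are placeholders):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<?xf32>
///
/// is rewritten to
///
///   %r = memref.cast %m : memref<4xf32> to memref<?xf32>
///
/// If the cast cannot be guaranteed to be valid (e.g. a dynamic offset or
/// stride would become static), a new buffer is allocated and copied instead.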
static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref,
                                              bool allowSameType = true) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  // A to_memref(to_tensor(x)) pair with matching memref types can be folded
  // to x without inserting a cast.
  if (memrefToTensor.memref().getType() == toMemref.getType()) {
    if (!allowSameType)
      // Function can be configured to only handle cases where a cast is
      // needed.
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  // If types are definitely not cast-compatible, bail.
  if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
                                         toMemref.getType()))
    return failure();

  // We already know that the types are potentially cast-compatible. However
  // in case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  auto memrefToTensorType =
      memrefToTensor.memref().getType().dyn_cast<MemRefType>();
  auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
  if (memrefToTensorType && toMemrefType &&
      !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
    MemRefType resultType = toMemrefType;
    auto loc = toMemref.getLoc();
    SmallVector<Value, 4> dynamicOperands;
    for (int i = 0; i < resultType.getRank(); ++i) {
      if (resultType.getShape()[i] != ShapedType::kDynamicSize)
        continue;
      auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
      dynamicOperands.push_back(size);
    }
    // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
    // BufferizableOpInterface impl of ToMemrefOp.
    auto copy =
        rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
    rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
    rewriter.replaceOp(toMemref, {copy});
  } else
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memrefToTensor.memref());
  return success();
}

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
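///
/// Illustrative IR (a sketch; names are placeholders):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>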
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
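///
/// Illustrative IR (a sketch; names are placeholders):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>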
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"