//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

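/// Try to cast the given ranked MemRef value to the given ranked MemRef type.
/// If the cast is not guaranteed to succeed at runtime (e.g., a dynamic offset
/// or stride would have to become static), a buffer of the destination type is
/// allocated and the value is copied into it instead.
///
/// Rough sketch of the two outcomes, with illustrative types (`#strided`
/// stands for a hypothetical layout with a dynamic offset or stride):
///
///   %r = memref.cast %value : memref<?xf32> to memref<4xf32>
///
/// versus
///
///   %r = memref.alloc() : memref<4xf32>
///   memref.copy %value, %r : memref<4xf32, #strided> to memref<4xf32>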
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` holds, a cast is valid, but it may still fail
  // at runtime. To ensure that we only generate casts that always succeed at
  // runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
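///
/// Schematic example (the exact memref types are illustrative only):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %0 = bufferization.to_memref %t : memref<4xf32>
///
/// is rewritten to
///
///   %0 = memref.cast %m : memref<?xf32> to memref<4xf32>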
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
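///
/// Rough sketch of the rewrite (assuming the clone's buffer is deallocated in
/// the clone's block and no other deallocation happens in between):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ... uses of %1 ...
///   memref.dealloc %1 : memref<?xf32>
///
/// becomes
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
///   ... uses of %1 ...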
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either the clone or its source has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

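/// Fold to_tensor(to_memref(x)) to x when the to_memref immediately precedes
/// this op in the same block. Schematic sketch (types are illustrative):
///
///   %m = bufferization.to_memref %x : memref<?xf32>
///   %t = bufferization.to_tensor %m : memref<?xf32>
///
/// folds to a direct use of %x.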
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

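/// Rewrite tensor.dim of a to_tensor op into memref.dim of the underlying
/// memref. Rough sketch (types are illustrative):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes
///
///   %d = memref.dim %m, %c0 : memref<?xf32>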
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
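///
/// Rough sketch (types are illustrative):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///
/// becomes
///
///   %0 = bufferization.to_memref %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>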
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when a type mismatch prevents `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
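///
/// Rough sketch (types are illustrative):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>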
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold a memref.dim of a to_memref operation into a tensor.dim on the
/// corresponding tensor.
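///
/// Rough sketch (types are illustrative):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>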
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  return foldToMemrefToTensorPair(rewriter, *this);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"