//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
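    // Illustrative sketch (hypothetical IR, not taken from a test case): with
    // the dealloc of the clone identified as redundant, e.g.
    //
    //   %clone = bufferization.clone %source : memref<?xf32> to memref<?xf32>
    //   "some.use"(%clone) : (memref<?xf32>) -> ()
    //   memref.dealloc %clone : memref<?xf32>
    //
    // the pattern rewrites the clone into
    //
    //   %cast = memref.cast %source : memref<?xf32> to memref<?xf32>
    //
    // and erases the dealloc, provided the scan below finds no other op with a
    // free effect between the clone and that dealloc.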
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
    // Bail unless we have a memref_to_tensor + tensor_to_memref with different
    // types. `ToMemrefOp::fold` handles the same type case.
    if (!memrefToTensor ||
        memrefToTensor.memref().getType() == toMemref.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
                                           toMemref.getType()))
      return failure();
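    // Illustrative sketch (hypothetical IR, not taken from a test case) of the
    // simple case handled by the `else` branch below, where a plain cast
    // suffices:
    //
    //   %tensor = bufferization.to_tensor %src : memref<4xf32>
    //   %memref = bufferization.to_memref %tensor : memref<?xf32>
    //
    // canonicalizes to
    //
    //   %memref = memref.cast %src : memref<4xf32> to memref<?xf32>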
    // We already know that the types are potentially cast-compatible. However,
    // in case the affine maps are different, we may need to use a copy if we
    // go from dynamic to static offset or stride (the canonicalization cannot
    // know at this point that it is really cast compatible).
    auto isGuaranteedCastCompatible = [](MemRefType source,
                                         MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : llvm::zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto memrefToTensorType =
        memrefToTensor.memref().getType().dyn_cast<MemRefType>();
    auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
    if (memrefToTensorType && toMemrefType &&
        !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
      MemRefType resultType = toMemrefType;
      auto loc = toMemref.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
      rewriter.replaceOp(toMemref, {copy});
    } else {
      rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                  memrefToTensor.memref());
    }
    return success();
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
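
// Illustrative note on TensorLoadToMemref above (hypothetical IR, not taken
// from a test case): when the target layout is "more static" than the source,
// e.g.
//
//   %tensor = bufferization.to_tensor %src
//       : memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
//   %memref = bufferization.to_memref %tensor : memref<?xf32>
//
// the offset goes from dynamic (s0) to static (0), so
// isGuaranteedCastCompatible returns false and the pattern materializes a
// memref.alloc of the result type plus a memref.copy instead of a bare
// memref.cast, whose validity cannot be proven at canonicalization time.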