//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
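
// Sketch of the two forms produced by the helper above (illustrative IR; SSA
// names and types are made up). If the cast is guaranteed to succeed at
// runtime, a plain cast is emitted:
//   %0 = memref.cast %src : memref<5xf32> to memref<?xf32>
// Otherwise, e.g. when the source layout (#map below) has a dynamic offset but
// the requested type does not, a new buffer is allocated and filled by a copy:
//   %0 = memref.alloc() : memref<5xf32>
//   memref.copy %src, %0 : memref<5xf32, #map> to memref<5xf32>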

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       BufferizationState &state) {
  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty())
    return success();

  FailureOr<Value> alloc = state.createAlloc(rewriter, getLoc(), getResult());
  if (failed(alloc))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
  return success();
}

void AllocTensorOp::build(OpBuilder &b, OperationState &result,
                          ArrayRef<OpFoldResult> sizes, Type elementType,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value> dynamicSizes;
  SmallVector<int64_t> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult AllocTensorOp::verify() {
  RankedTensorType resultType = getType();
  SmallVector<int64_t> staticSizes = llvm::to_vector<4>(llvm::map_range(
      static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(
          *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
          ShapedType::isDynamic)))
    return failure();

  if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return emitError("expected ") << resultType.getRank() << " sizes values";

  Type expectedType = AllocTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type AllocTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                    Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}
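
/// Return the sizes of this op as a vector of `OpFoldResult`s: static sizes
/// are returned as `IntegerAttr`s and dynamic sizes as the corresponding SSA
/// values, in dimension order. For example (illustrative),
/// `bufferization.alloc_tensor [%d0, 4] : tensor<?x4xf32>` yields `{%d0, 4}`.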
SmallVector<OpFoldResult> AllocTensorOp::getMixedSizes() {
  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(getType().getRank());
  unsigned dynamicValIndex = 0;
  for (Attribute attr : static_sizes()) {
    auto intAttr = attr.cast<IntegerAttr>();
    if (!ShapedType::isDynamic(intAttr.getInt())) {
      mixedSizes.push_back(intAttr);
      continue;
    }
    mixedSizes.push_back(sizes()[dynamicValIndex++]);
  }
  return mixedSizes;
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor [%arg0, %c5] : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, dynamicSizes,
        rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       allocTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}
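
/// Fold `memref.cast`s feeding the clone's source into the clone itself (via
/// `memref::foldMemRefCast`). Sketch (illustrative): a cast that only erases
/// static shape information, as in `clone(memref.cast %x : memref<4xf32> to
/// memref<?xf32>)`, is absorbed so that the clone reads directly from `%x`.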
OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {
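/// Rewrite `tensor.dim` of a `to_tensor` result into `memref.dim` on the
/// underlying memref. For example (illustrative; SSA names are made up):
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// becomes
///   %d = memref.dim %m, %c0 : memref<?xf32>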
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.memref(), dimOp.index());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
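/// For example (illustrative; SSA names are made up):
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
/// becomes
///   %v = tensor.extract %t[%i] : tensor<?xf32>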
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"