149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1249e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1371bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1449e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1549e37000SMatthias Springer #include "mlir/IR/Dialect.h" 1649e37000SMatthias Springer #include "mlir/IR/Operation.h" 1749e37000SMatthias Springer 1849e37000SMatthias Springer using namespace mlir; 1949e37000SMatthias Springer using namespace mlir::bufferization; 2049e37000SMatthias Springer using namespace mlir::tensor; 2149e37000SMatthias Springer 2249e37000SMatthias Springer namespace mlir { 2349e37000SMatthias Springer namespace tensor { 2449e37000SMatthias Springer namespace { 2549e37000SMatthias Springer 2649e37000SMatthias Springer struct CastOpInterface 2749e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 2849e37000SMatthias Springer tensor::CastOp> { 2949e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 309597b16aSMatthias Springer const AnalysisState &state) const { 3149e37000SMatthias Springer return false; 3249e37000SMatthias Springer } 3349e37000SMatthias Springer 3449e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 359597b16aSMatthias Springer const AnalysisState &state) const { 3649e37000SMatthias Springer return false; 3749e37000SMatthias Springer } 3849e37000SMatthias Springer 399597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 409597b16aSMatthias Springer const AnalysisState &state) const { 41585a8a32SMatthias Springer return {op->getResult(0)}; 4249e37000SMatthias Springer } 4349e37000SMatthias Springer 4449e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 459597b16aSMatthias Springer const AnalysisState &state) const { 4649e37000SMatthias Springer return BufferRelation::Equivalent; 4749e37000SMatthias Springer } 4849e37000SMatthias Springer 4949e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 509597b16aSMatthias Springer BufferizationState &state) const { 5149e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 5249e37000SMatthias Springer 5349e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 5449e37000SMatthias Springer FailureOr<Value> resultBuffer = 5549e37000SMatthias Springer state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); 5649e37000SMatthias Springer if (failed(resultBuffer)) 5749e37000SMatthias Springer return failure(); 5849e37000SMatthias Springer auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); 5949e37000SMatthias Springer Attribute memorySpace = sourceMemRefType.getMemorySpace(); 6049e37000SMatthias Springer TensorType resultTensorType = 6149e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>(); 6249e37000SMatthias Springer MemRefLayoutAttrInterface layout; 6349e37000SMatthias Springer 6449e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) 6549e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) 6649e37000SMatthias Springer layout = rankedMemRefType.getLayout(); 6749e37000SMatthias Springer 6849e37000SMatthias Springer // Compute the new memref type. 6926852423SMatthias Springer Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(), 7026852423SMatthias Springer layout, memorySpace); 7149e37000SMatthias Springer 7249e37000SMatthias Springer // Replace the op with a memref.cast. 7349e37000SMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), 7449e37000SMatthias Springer resultMemRefType) && 7549e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 7649e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, 7749e37000SMatthias Springer *resultBuffer); 7849e37000SMatthias Springer 7949e37000SMatthias Springer return success(); 8049e37000SMatthias Springer } 8149e37000SMatthias Springer }; 8249e37000SMatthias Springer 83e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape. 84e6f69161SMatthias Springer struct CollapseShapeOpInterface 85e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface, 86e6f69161SMatthias Springer tensor::CollapseShapeOp> { 87e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 889597b16aSMatthias Springer const AnalysisState &state) const { 89e6f69161SMatthias Springer return false; 90e6f69161SMatthias Springer } 91e6f69161SMatthias Springer 92e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 939597b16aSMatthias Springer const AnalysisState &state) const { 94e6f69161SMatthias Springer return false; 95e6f69161SMatthias Springer } 96e6f69161SMatthias Springer 979597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 989597b16aSMatthias Springer const AnalysisState &state) const { 99e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 100e6f69161SMatthias Springer return {op->getOpResult(0)}; 101e6f69161SMatthias Springer return {}; 102e6f69161SMatthias Springer } 103e6f69161SMatthias Springer 104e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1059597b16aSMatthias Springer const AnalysisState &state) const { 106e6f69161SMatthias Springer return BufferRelation::Equivalent; 107e6f69161SMatthias Springer } 108e6f69161SMatthias Springer 109e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1109597b16aSMatthias Springer BufferizationState &state) const { 111e6f69161SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); 11251df6238SMatthias Springer RankedTensorType tensorResultType = collapseShapeOp.getResultType(); 113d7a9bf91SMatthias Springer OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/; 114d7a9bf91SMatthias Springer auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>(); 11551df6238SMatthias Springer 11651df6238SMatthias Springer if (tensorResultType.getRank() == 0) { 11751df6238SMatthias Springer // 0-d collapses must go through a different op builder. 118d7a9bf91SMatthias Springer Value buffer = *state.getBuffer(rewriter, srcOperand); 11973c0333dSMatthias Springer MemRefType resultType; 12073c0333dSMatthias Springer 12173c0333dSMatthias Springer if (bufferType.getLayout().isIdentity()) { 12273c0333dSMatthias Springer // Standard layout: result type has no offset. 12351df6238SMatthias Springer MemRefLayoutAttrInterface layout; 12473c0333dSMatthias Springer resultType = MemRefType::get({}, tensorResultType.getElementType(), 12551df6238SMatthias Springer layout, bufferType.getMemorySpace()); 12673c0333dSMatthias Springer } else { 12773c0333dSMatthias Springer // Source memref has a layout map: result type has the same offset as 12873c0333dSMatthias Springer // the source type. 12973c0333dSMatthias Springer SmallVector<int64_t> strides; 13073c0333dSMatthias Springer int64_t offset; 13173c0333dSMatthias Springer if (failed(getStridesAndOffset(bufferType, strides, offset))) 13273c0333dSMatthias Springer return failure(); 13373c0333dSMatthias Springer AffineMap resultLayout = 13473c0333dSMatthias Springer makeStridedLinearLayoutMap({}, offset, op->getContext()); 13573c0333dSMatthias Springer resultType = 13673c0333dSMatthias Springer MemRefType::get({}, tensorResultType.getElementType(), resultLayout, 13773c0333dSMatthias Springer bufferType.getMemorySpaceAsInt()); 13873c0333dSMatthias Springer } 13973c0333dSMatthias Springer 140e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 141e6f69161SMatthias Springer rewriter, op, resultType, buffer, collapseShapeOp.reassociation()); 142e6f69161SMatthias Springer return success(); 143e6f69161SMatthias Springer } 14451df6238SMatthias Springer 145d7a9bf91SMatthias Springer // If the dims are not collapsible (due to an incompatible source layout 146d7a9bf91SMatthias Springer // map), force an out-of-place bufferization, i.e., a buffer copy. This 147d7a9bf91SMatthias Springer // newly allocated buffer will have no layout map and thus be collapsible. 148a74e5a89SAdrian Kuegel bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible( 149d7a9bf91SMatthias Springer bufferType, collapseShapeOp.getReassociationIndices()); 150d7a9bf91SMatthias Springer Optional<BufferizationState::ForceInPlacability> overrideInPlace = 151d7a9bf91SMatthias Springer canBeCollapsed 152d7a9bf91SMatthias Springer ? None 153d7a9bf91SMatthias Springer : Optional<BufferizationState::ForceInPlacability>( 154d7a9bf91SMatthias Springer BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE); 155d7a9bf91SMatthias Springer Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace); 156d7a9bf91SMatthias Springer 15751df6238SMatthias Springer // Result type is inferred by the builder. 15851df6238SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 15951df6238SMatthias Springer rewriter, op, buffer, collapseShapeOp.getReassociationIndices()); 16051df6238SMatthias Springer return success(); 16151df6238SMatthias Springer } 162e6f69161SMatthias Springer }; 163e6f69161SMatthias Springer 16449e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 16549e37000SMatthias Springer struct DimOpInterface 16649e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 16749e37000SMatthias Springer tensor::DimOp> { 16849e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1699597b16aSMatthias Springer const AnalysisState &state) const { 17049e37000SMatthias Springer return true; 17149e37000SMatthias Springer } 17249e37000SMatthias Springer 17349e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1749597b16aSMatthias Springer const AnalysisState &state) const { 17549e37000SMatthias Springer return false; 17649e37000SMatthias Springer } 17749e37000SMatthias Springer 1789597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1799597b16aSMatthias Springer const AnalysisState &state) const { 180585a8a32SMatthias Springer return {}; 18149e37000SMatthias Springer } 18249e37000SMatthias Springer 18349e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1849597b16aSMatthias Springer BufferizationState &state) const { 18549e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 18649e37000SMatthias Springer Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); 18749e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); 18849e37000SMatthias Springer return success(); 18949e37000SMatthias Springer } 19049e37000SMatthias Springer }; 19149e37000SMatthias Springer 192e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape. 193e6f69161SMatthias Springer struct ExpandShapeOpInterface 194e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface, 195e6f69161SMatthias Springer tensor::ExpandShapeOp> { 196e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1979597b16aSMatthias Springer const AnalysisState &state) const { 198e6f69161SMatthias Springer return false; 199e6f69161SMatthias Springer } 200e6f69161SMatthias Springer 201e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2029597b16aSMatthias Springer const AnalysisState &state) const { 203e6f69161SMatthias Springer return false; 204e6f69161SMatthias Springer } 205e6f69161SMatthias Springer 2069597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2079597b16aSMatthias Springer const AnalysisState &state) const { 208e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 209e6f69161SMatthias Springer return {op->getOpResult(0)}; 210e6f69161SMatthias Springer return {}; 211e6f69161SMatthias Springer } 212e6f69161SMatthias Springer 213e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2149597b16aSMatthias Springer const AnalysisState &state) const { 215e6f69161SMatthias Springer return BufferRelation::Equivalent; 216e6f69161SMatthias Springer } 217e6f69161SMatthias Springer 218e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 2199597b16aSMatthias Springer BufferizationState &state) const { 220e6f69161SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 22151df6238SMatthias Springer auto tensorResultType = expandShapeOp.getResultType(); 222e6f69161SMatthias Springer Value buffer = 223e6f69161SMatthias Springer *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); 22451df6238SMatthias Springer 22551df6238SMatthias Springer // Memref result type is inferred by the builder based on reassociation 22651df6238SMatthias Springer // indices and result shape. 227e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>( 22851df6238SMatthias Springer rewriter, op, tensorResultType.getShape(), buffer, 22951df6238SMatthias Springer expandShapeOp.getReassociationIndices()); 230e6f69161SMatthias Springer return success(); 231e6f69161SMatthias Springer } 232e6f69161SMatthias Springer }; 233e6f69161SMatthias Springer 23449e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 23549e37000SMatthias Springer struct ExtractSliceOpInterface 23649e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 23749e37000SMatthias Springer tensor::ExtractSliceOp> { 23849e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2399597b16aSMatthias Springer const AnalysisState &state) const { 24049e37000SMatthias Springer return false; 24149e37000SMatthias Springer } 24249e37000SMatthias Springer 24349e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2449597b16aSMatthias Springer const AnalysisState &state) const { 24549e37000SMatthias Springer return false; 24649e37000SMatthias Springer } 24749e37000SMatthias Springer 2489597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2499597b16aSMatthias Springer const AnalysisState &state) const { 250585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*source*/) 251585a8a32SMatthias Springer return {op->getOpResult(0)}; 252585a8a32SMatthias Springer return {}; 25349e37000SMatthias Springer } 25449e37000SMatthias Springer 25549e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2569597b16aSMatthias Springer const AnalysisState &state) const { 25749e37000SMatthias Springer return BufferRelation::None; 25849e37000SMatthias Springer } 25949e37000SMatthias Springer 26049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 2619597b16aSMatthias Springer BufferizationState &state) const { 26249e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 26349e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 264d7a9bf91SMatthias Springer 265d7a9bf91SMatthias Springer // Even if this op was decided to bufferize out-of-place, do not insert the 266d7a9bf91SMatthias Springer // buffer copy yet. This is done later in this function. 26749e37000SMatthias Springer Value srcMemref = 26849e37000SMatthias Springer *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, 269d7a9bf91SMatthias Springer BufferizationState::ForceInPlacability::FORCE_INPLACE); 27049e37000SMatthias Springer auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); 27149e37000SMatthias Springer auto dstTensorType = 27249e37000SMatthias Springer extractSliceOp.result().getType().cast<RankedTensorType>(); 27349e37000SMatthias Springer 27449e37000SMatthias Springer // If not inplaceable, alloc. 2759597b16aSMatthias Springer bool inplace = 2769597b16aSMatthias Springer state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0)); 27749e37000SMatthias Springer Value alloc; 27849e37000SMatthias Springer if (!inplace) { 27949e37000SMatthias Springer FailureOr<Value> allocOrFailure = 28005e0495fSMatthias Springer state.createAlloc(rewriter, loc, extractSliceOp.result()); 28149e37000SMatthias Springer if (failed(allocOrFailure)) 28249e37000SMatthias Springer return failure(); 28349e37000SMatthias Springer alloc = *allocOrFailure; 28449e37000SMatthias Springer } 28549e37000SMatthias Springer 28649e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 28749e37000SMatthias Springer // rank-reducing case. 28849e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 28949e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 29049e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 29149e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 29249e37000SMatthias Springer srcMemref, mixedOffsets, mixedSizes, mixedStrides, 29349e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 29449e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 29549e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 29649e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 29749e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 29849e37000SMatthias Springer }); 29949e37000SMatthias Springer // Bufferize to subview. 30049e37000SMatthias Springer auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( 30149e37000SMatthias Springer dstTensorType.getRank(), srcMemrefType, 30249e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 30349e37000SMatthias Springer .cast<MemRefType>(); 30449e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 30549e37000SMatthias Springer loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, 30649e37000SMatthias Springer mixedStrides); 30749e37000SMatthias Springer 30849e37000SMatthias Springer // If not inplaceable, copy. 30949e37000SMatthias Springer if (!inplace) { 31049e37000SMatthias Springer // Do not copy if the copied data is never read. 3119597b16aSMatthias Springer if (state.getAnalysisState().isValueRead(extractSliceOp.result())) 31249e37000SMatthias Springer if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, 31349e37000SMatthias Springer alloc, state.getOptions()))) 31449e37000SMatthias Springer return failure(); 31549e37000SMatthias Springer subView = alloc; 31649e37000SMatthias Springer } 31749e37000SMatthias Springer 31849e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 31949e37000SMatthias Springer return success(); 32049e37000SMatthias Springer } 32149e37000SMatthias Springer }; 32249e37000SMatthias Springer 32349e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 32449e37000SMatthias Springer struct ExtractOpInterface 32549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 32649e37000SMatthias Springer tensor::ExtractOp> { 32749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 3289597b16aSMatthias Springer const AnalysisState &state) const { 32949e37000SMatthias Springer return true; 33049e37000SMatthias Springer } 33149e37000SMatthias Springer 33249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3339597b16aSMatthias Springer const AnalysisState &state) const { 33449e37000SMatthias Springer return false; 33549e37000SMatthias Springer } 33649e37000SMatthias Springer 3379597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 3389597b16aSMatthias Springer const AnalysisState &state) const { 339585a8a32SMatthias Springer return {}; 34049e37000SMatthias Springer } 34149e37000SMatthias Springer 34249e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 3439597b16aSMatthias Springer BufferizationState &state) const { 34449e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 34549e37000SMatthias Springer Value srcMemref = 34649e37000SMatthias Springer *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); 34749e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, 34849e37000SMatthias Springer extractOp.indices()); 34949e37000SMatthias Springer return success(); 35049e37000SMatthias Springer } 35149e37000SMatthias Springer }; 35249e37000SMatthias Springer 353d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 354d581c94dSMatthias Springer // iterating over op.elements(). 355d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim, 356d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape, 357d581c94dSMatthias Springer ArrayRef<Value> constants, 358d581c94dSMatthias Springer OperandRange::iterator &elementIt, 359d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) { 360d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 361d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) { 362d581c94dSMatthias Springer indices.back() = constants[i]; 363d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices); 364d581c94dSMatthias Springer ++elementIt; 365d581c94dSMatthias Springer } 366d581c94dSMatthias Springer return; 367d581c94dSMatthias Springer } 368d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 369d581c94dSMatthias Springer indices[dim] = constants[i]; 370d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, 371d581c94dSMatthias Springer indices); 372d581c94dSMatthias Springer } 373d581c94dSMatthias Springer } 374d581c94dSMatthias Springer 375d581c94dSMatthias Springer /// Bufferization of tensor.from_elements. 376d581c94dSMatthias Springer struct FromElementsOpInterface 377d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, 378d581c94dSMatthias Springer tensor::FromElementsOp> { 379d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 3809597b16aSMatthias Springer BufferizationState &state) const { 381d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op); 382d581c94dSMatthias Springer 383d581c94dSMatthias Springer // Allocate a buffer for the result. 384d581c94dSMatthias Springer Location loc = op->getLoc(); 385d581c94dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>(); 386d581c94dSMatthias Springer auto shape = tensorType.getShape(); 387d581c94dSMatthias Springer FailureOr<Value> maybeBuffer = 3889e24f0f4SMatthias Springer state.createAlloc(rewriter, loc, fromElementsOp.result()); 389d581c94dSMatthias Springer if (failed(maybeBuffer)) 390d581c94dSMatthias Springer return failure(); 391d581c94dSMatthias Springer Value buffer = *maybeBuffer; 392d581c94dSMatthias Springer 393d581c94dSMatthias Springer // Case: tensor<0xelem_type>. 394d581c94dSMatthias Springer if (fromElementsOp.elements().empty()) { 395d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 396d581c94dSMatthias Springer return success(); 397d581c94dSMatthias Springer } 398d581c94dSMatthias Springer 399d581c94dSMatthias Springer // Case: tensor<elem_type>. 400d581c94dSMatthias Springer if (shape.empty()) { 401d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(), 402d581c94dSMatthias Springer buffer); 403d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 404d581c94dSMatthias Springer return success(); 405d581c94dSMatthias Springer } 406d581c94dSMatthias Springer 407d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}). 408d581c94dSMatthias Springer auto maxDim = *std::max_element(shape.begin(), shape.end()); 409d581c94dSMatthias Springer SmallVector<Value, 2> constants; 410d581c94dSMatthias Springer constants.reserve(maxDim); 411d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i) 412d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 413d581c94dSMatthias Springer 414d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops. 415d581c94dSMatthias Springer auto elementIt = fromElementsOp.elements().begin(); 416d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 417d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, 418d581c94dSMatthias Springer indices); 419d581c94dSMatthias Springer 420d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 421d581c94dSMatthias Springer return success(); 422d581c94dSMatthias Springer } 423d581c94dSMatthias Springer }; 424d581c94dSMatthias Springer 42571bbb78bSMatthias Springer /// Bufferization of tensor.generate. 42671bbb78bSMatthias Springer struct GenerateOpInterface 42771bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, 42871bbb78bSMatthias Springer tensor::GenerateOp> { 42971bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 4309597b16aSMatthias Springer BufferizationState &state) const { 43171bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op); 43271bbb78bSMatthias Springer 43371bbb78bSMatthias Springer // Allocate memory. 43471bbb78bSMatthias Springer Location loc = op->getLoc(); 43571bbb78bSMatthias Springer MemRefType memrefType = 43671bbb78bSMatthias Springer getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>()); 4379e24f0f4SMatthias Springer FailureOr<Value> maybeResult = 4389e24f0f4SMatthias Springer state.createAlloc(rewriter, loc, generateOp.result()); 43971bbb78bSMatthias Springer if (failed(maybeResult)) 44071bbb78bSMatthias Springer return failure(); 44171bbb78bSMatthias Springer Value result = *maybeResult; 44271bbb78bSMatthias Springer 44371bbb78bSMatthias Springer // Collect loop bounds. 44471bbb78bSMatthias Springer int64_t rank = memrefType.getRank(); 44571bbb78bSMatthias Springer Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 44671bbb78bSMatthias Springer Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 44771bbb78bSMatthias Springer SmallVector<Value, 4> lowerBounds(rank, zero); 44871bbb78bSMatthias Springer SmallVector<Value, 4> steps(rank, one); 44971bbb78bSMatthias Springer SmallVector<Value, 4> upperBounds; 45071bbb78bSMatthias Springer int nextDynamicIndex = 0; 45171bbb78bSMatthias Springer for (int i = 0; i < rank; i++) { 45271bbb78bSMatthias Springer Value upperBound = memrefType.isDynamicDim(i) 45371bbb78bSMatthias Springer ? generateOp.dynamicExtents()[nextDynamicIndex++] 45471bbb78bSMatthias Springer : rewriter.create<arith::ConstantIndexOp>( 45571bbb78bSMatthias Springer loc, memrefType.getDimSize(i)); 45671bbb78bSMatthias Springer upperBounds.push_back(upperBound); 45771bbb78bSMatthias Springer } 45871bbb78bSMatthias Springer 45971bbb78bSMatthias Springer // Generate tensor elements with a parallel loop that stores into 46071bbb78bSMatthias Springer // each element of the resulting memref. We use mergeBlockBefore to "move" 46171bbb78bSMatthias Springer // this op's body into the scf.parallel's body. 46271bbb78bSMatthias Springer auto parallel = 46371bbb78bSMatthias Springer rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 46471bbb78bSMatthias Springer Block *parallelBody = parallel.getBody(); 46571bbb78bSMatthias Springer rewriter.mergeBlockBefore(generateOp.getBody(), 46671bbb78bSMatthias Springer parallelBody->getTerminator(), 46771bbb78bSMatthias Springer parallelBody->getArguments()); 46871bbb78bSMatthias Springer // Replace the inlined yield op with a store op. The scf.parallel's builder 46971bbb78bSMatthias Springer // already populated an scf.yield at the end, so we don't need to worry 47071bbb78bSMatthias Springer // about creating that. 47171bbb78bSMatthias Springer Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 47271bbb78bSMatthias Springer rewriter.setInsertionPointAfter(elementYield); 47371bbb78bSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>( 47471bbb78bSMatthias Springer elementYield, elementYield->getOperands()[0], result, 47571bbb78bSMatthias Springer parallelBody->getArguments()); 47671bbb78bSMatthias Springer 47771bbb78bSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, result); 47871bbb78bSMatthias Springer return success(); 47971bbb78bSMatthias Springer } 48071bbb78bSMatthias Springer }; 48171bbb78bSMatthias Springer 48249e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 48349e37000SMatthias Springer struct InsertOpInterface 48449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface, 48549e37000SMatthias Springer tensor::InsertOp> { 48649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 4879597b16aSMatthias Springer const AnalysisState &state) const { 48849e37000SMatthias Springer return true; 48949e37000SMatthias Springer } 49049e37000SMatthias Springer 49149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 4929597b16aSMatthias Springer const AnalysisState &state) const { 49349e37000SMatthias Springer return true; 49449e37000SMatthias Springer } 49549e37000SMatthias Springer 4969597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 4979597b16aSMatthias Springer const AnalysisState &state) const { 49849e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 49949e37000SMatthias Springer "expected dest OpOperand"); 500585a8a32SMatthias Springer return {op->getOpResult(0)}; 50149e37000SMatthias Springer } 50249e37000SMatthias Springer 50349e37000SMatthias Springer SmallVector<OpOperand *> 50449e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 5059597b16aSMatthias Springer const AnalysisState &state) const { 50649e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/}; 50749e37000SMatthias Springer } 50849e37000SMatthias Springer 50949e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 5109597b16aSMatthias Springer BufferizationState &state) const { 51149e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 51249e37000SMatthias Springer FailureOr<Value> destMemref = 51349e37000SMatthias Springer state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); 51449e37000SMatthias Springer if (failed(destMemref)) 51549e37000SMatthias Springer return failure(); 51649e37000SMatthias Springer rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), 51749e37000SMatthias Springer *destMemref, insertOp.indices()); 51849e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref); 51949e37000SMatthias Springer return success(); 52049e37000SMatthias Springer } 52149e37000SMatthias Springer 52249e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 5239597b16aSMatthias Springer const AnalysisState &state) const { 52449e37000SMatthias Springer return BufferRelation::Equivalent; 52549e37000SMatthias Springer } 52649e37000SMatthias Springer }; 52749e37000SMatthias Springer 52849e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. 52949e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification). 53049e37000SMatthias Springer /// 53149e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that 53249e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and 53349e37000SMatthias Springer /// exposed as interfaces on the proper types. 5349597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state, 53549e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 53649e37000SMatthias Springer if (!st || !sti) 53749e37000SMatthias Springer return false; 53849e37000SMatthias Springer if (sti != sti && 53949e37000SMatthias Springer !state.areEquivalentBufferizedValues(st.source(), sti.dest())) 54049e37000SMatthias Springer return false; 54149e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 54249e37000SMatthias Springer return false; 54349e37000SMatthias Springer return true; 54449e37000SMatthias Springer } 54549e37000SMatthias Springer 54649e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 54749e37000SMatthias Springer /// the given InsertSliceOp. 5489597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 5499597b16aSMatthias Springer InsertSliceOp insertOp) { 55049e37000SMatthias Springer auto condition = [&](Value val) { 55149e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 55249e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 55349e37000SMatthias Springer return true; 55449e37000SMatthias Springer return false; 55549e37000SMatthias Springer }; 55649e37000SMatthias Springer 55749e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 55849e37000SMatthias Springer condition); 55949e37000SMatthias Springer } 56049e37000SMatthias Springer 56149e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 56249e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 56349e37000SMatthias Springer struct InsertSliceOpInterface 56449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 56549e37000SMatthias Springer tensor::InsertSliceOp> { 56649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 5679597b16aSMatthias Springer const AnalysisState &state) const { 56849e37000SMatthias Springer return true; 56949e37000SMatthias Springer } 57049e37000SMatthias Springer 57149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 5729597b16aSMatthias Springer const AnalysisState &state) const { 57349e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 57449e37000SMatthias Springer } 57549e37000SMatthias Springer 5769597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 5779597b16aSMatthias Springer const AnalysisState &state) const { 578585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(1) /*dest*/) 579585a8a32SMatthias Springer return {op->getResult(0)}; 580585a8a32SMatthias Springer return {}; 58149e37000SMatthias Springer } 58249e37000SMatthias Springer 58349e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 5849597b16aSMatthias Springer const AnalysisState &state) const { 58549e37000SMatthias Springer return BufferRelation::Equivalent; 58649e37000SMatthias Springer } 58749e37000SMatthias Springer 58849e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 58949e37000SMatthias Springer OpOperand *uConflictingWrite, 5909597b16aSMatthias Springer const AnalysisState &state) const { 59149e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 59249e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 59349e37000SMatthias Springer 59449e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 59549e37000SMatthias Springer // uRead is an InsertSliceOp... 59649e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 59749e37000SMatthias Springer // As an example, consider the following IR. 59849e37000SMatthias Springer // 59949e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 60049e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 60149e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 60249e37000SMatthias Springer // {inplace= [true] } 60349e37000SMatthias Springer 60449e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 60549e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 60649e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 60749e37000SMatthias Springer insertSliceOp)) 60849e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 60949e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 61049e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 61149e37000SMatthias Springer // being read by uRead, this is not a conflict. 61249e37000SMatthias Springer // 61349e37000SMatthias Springer // In the above example: 61449e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 61549e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 61649e37000SMatthias Springer // 61749e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 61849e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 61949e37000SMatthias Springer // exactly the one that is *not* read via %t. 62049e37000SMatthias Springer return true; 62149e37000SMatthias Springer 62249e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 62349e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 62449e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 62549e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 62649e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 62749e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 62849e37000SMatthias Springer // InsertSliceOp is writing. 62949e37000SMatthias Springer // 63049e37000SMatthias Springer // In the above example: 63149e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 63249e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 63349e37000SMatthias Springer return true; 63449e37000SMatthias Springer } 63549e37000SMatthias Springer 63649e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 63749e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 63849e37000SMatthias Springer // As an example, consider the following IR. 63949e37000SMatthias Springer // 64049e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 64149e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 64249e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 64349e37000SMatthias Springer // {inplace= [true] } 64449e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 64549e37000SMatthias Springer // 64649e37000SMatthias Springer // In the above example: 64749e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 64849e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 64949e37000SMatthias Springer // lastWrite = %1 65049e37000SMatthias Springer // 65149e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 65249e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 65349e37000SMatthias Springer // is no memory write here.) 65449e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 65549e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 65649e37000SMatthias Springer insertSliceOp.source()) && 65749e37000SMatthias Springer hasMatchingExtractSliceOp(state, insertSliceOp.source(), 65849e37000SMatthias Springer insertSliceOp)) 65949e37000SMatthias Springer return true; 66049e37000SMatthias Springer 66149e37000SMatthias Springer return false; 66249e37000SMatthias Springer } 66349e37000SMatthias Springer 66449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 6659597b16aSMatthias Springer BufferizationState &state) const { 66649e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 66749e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 66849e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 66949e37000SMatthias Springer // catastrophically bad scheduling decision. 67049e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 67149e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 67249e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 67349e37000SMatthias Springer 67449e37000SMatthias Springer // When bufferizing out-of-place, `getResultBuffer` allocates. 67549e37000SMatthias Springer FailureOr<Value> dstMemref = 67649e37000SMatthias Springer state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); 67749e37000SMatthias Springer if (failed(dstMemref)) 67849e37000SMatthias Springer return failure(); 67949e37000SMatthias Springer 68049e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 68149e37000SMatthias Springer // rank-reducing case. 68249e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 68349e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 68449e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 68549e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 68649e37000SMatthias Springer *dstMemref, mixedOffsets, mixedSizes, mixedStrides, 68749e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 68849e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 68949e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 69049e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 69149e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 69249e37000SMatthias Springer }); 69349e37000SMatthias Springer // Take a subview of the dst. 69449e37000SMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); 69549e37000SMatthias Springer auto subviewMemRefType = 69649e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 69749e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 69849e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 69949e37000SMatthias Springer .cast<MemRefType>(); 70049e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 70149e37000SMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 70249e37000SMatthias Springer mixedStrides); 70349e37000SMatthias Springer 70449e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 70549e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 70649e37000SMatthias Springer Value srcMemref = 70749e37000SMatthias Springer *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); 70849e37000SMatthias Springer if (failed(createMemCpy(rewriter, loc, srcMemref, subView, 70949e37000SMatthias Springer state.getOptions()))) 71049e37000SMatthias Springer return failure(); 71149e37000SMatthias Springer 71249e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 71349e37000SMatthias Springer return success(); 71449e37000SMatthias Springer } 71549e37000SMatthias Springer }; 71649e37000SMatthias Springer 717fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 718fc08d1c2SMatthias Springer struct RankOpInterface 719fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 720fc08d1c2SMatthias Springer tensor::RankOp> { 721fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 7229597b16aSMatthias Springer const AnalysisState &state) const { 723fc08d1c2SMatthias Springer return true; 724fc08d1c2SMatthias Springer } 725fc08d1c2SMatthias Springer 726fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 7279597b16aSMatthias Springer const AnalysisState &state) const { 728fc08d1c2SMatthias Springer return false; 729fc08d1c2SMatthias Springer } 730fc08d1c2SMatthias Springer 7319597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 7329597b16aSMatthias Springer const AnalysisState &state) const { 733585a8a32SMatthias Springer return {}; 734fc08d1c2SMatthias Springer } 735fc08d1c2SMatthias Springer 736fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 7379597b16aSMatthias Springer BufferizationState &state) const { 738fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 739fc08d1c2SMatthias Springer Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); 740fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 741fc08d1c2SMatthias Springer v); 742fc08d1c2SMatthias Springer return success(); 743fc08d1c2SMatthias Springer } 744fc08d1c2SMatthias Springer }; 745fc08d1c2SMatthias Springer 746*e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape. 747*e287d647SAshay Rane struct ReshapeOpInterface 748*e287d647SAshay Rane : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface, 749*e287d647SAshay Rane tensor::ReshapeOp> { 750*e287d647SAshay Rane bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 751*e287d647SAshay Rane const AnalysisState &state) const { 752*e287d647SAshay Rane if (&opOperand == &op->getOpOperand(1) /* shape */) 753*e287d647SAshay Rane return true; 754*e287d647SAshay Rane return false; 755*e287d647SAshay Rane } 756*e287d647SAshay Rane 757*e287d647SAshay Rane bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 758*e287d647SAshay Rane const AnalysisState &state) const { 759*e287d647SAshay Rane return false; 760*e287d647SAshay Rane } 761*e287d647SAshay Rane 762*e287d647SAshay Rane SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 763*e287d647SAshay Rane const AnalysisState &state) const { 764*e287d647SAshay Rane return {op->getOpResult(0)}; 765*e287d647SAshay Rane } 766*e287d647SAshay Rane 767*e287d647SAshay Rane BufferRelation bufferRelation(Operation *op, OpResult opResult, 768*e287d647SAshay Rane const AnalysisState &state) const { 769*e287d647SAshay Rane return BufferRelation::Equivalent; 770*e287d647SAshay Rane } 771*e287d647SAshay Rane 772*e287d647SAshay Rane LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 773*e287d647SAshay Rane BufferizationState &state) const { 774*e287d647SAshay Rane auto reshapeOp = cast<tensor::ReshapeOp>(op); 775*e287d647SAshay Rane auto &srcOperand = reshapeOp->getOpOperand(0); 776*e287d647SAshay Rane auto srcBuffer = state.getBuffer(rewriter, srcOperand); 777*e287d647SAshay Rane if (failed(srcBuffer)) 778*e287d647SAshay Rane return failure(); 779*e287d647SAshay Rane 780*e287d647SAshay Rane auto &shapeOperand = reshapeOp->getOpOperand(1); 781*e287d647SAshay Rane auto shapeBuffer = state.getBuffer(rewriter, shapeOperand); 782*e287d647SAshay Rane if (failed(shapeBuffer)) 783*e287d647SAshay Rane return failure(); 784*e287d647SAshay Rane 785*e287d647SAshay Rane auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>(); 786*e287d647SAshay Rane auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions()); 787*e287d647SAshay Rane 788*e287d647SAshay Rane replaceOpWithNewBufferizedOp<memref::ReshapeOp>( 789*e287d647SAshay Rane rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); 790*e287d647SAshay Rane return success(); 791*e287d647SAshay Rane } 792*e287d647SAshay Rane }; 793*e287d647SAshay Rane 79449e37000SMatthias Springer } // namespace 79549e37000SMatthias Springer } // namespace tensor 79649e37000SMatthias Springer } // namespace mlir 79749e37000SMatthias Springer 79849e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 79949e37000SMatthias Springer DialectRegistry ®istry) { 80077eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 80177eee579SRiver Riddle CastOp::attachInterface<CastOpInterface>(*ctx); 80277eee579SRiver Riddle CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx); 80377eee579SRiver Riddle DimOp::attachInterface<DimOpInterface>(*ctx); 80477eee579SRiver Riddle ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); 80577eee579SRiver Riddle ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx); 80677eee579SRiver Riddle ExtractOp::attachInterface<ExtractOpInterface>(*ctx); 80777eee579SRiver Riddle FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx); 80877eee579SRiver Riddle GenerateOp::attachInterface<GenerateOpInterface>(*ctx); 80977eee579SRiver Riddle InsertOp::attachInterface<InsertOpInterface>(*ctx); 81077eee579SRiver Riddle InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx); 81177eee579SRiver Riddle RankOp::attachInterface<RankOpInterface>(*ctx); 812*e287d647SAshay Rane ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx); 81377eee579SRiver Riddle }); 81449e37000SMatthias Springer } 815