//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  // A cast only reinterprets the type; it does not read tensor contents.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  // A cast never writes to memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  // The single result aliases the single operand (metadata-only op).
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  // Operand and result are backed by the same buffer.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    // Carry over the source layout only when both source memref and result
    // tensor are ranked; otherwise leave the layout null (default).
    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CallOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  // Reshapes are metadata-only: no tensor contents are read...
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  // ...and none are written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  // The result aliases the "src" operand only.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        // 0-d result: empty stride list, only the offset survives.
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      // The fresh allocation uses an identity layout (default AffineMap), so
      // the subsequent collapse is guaranteed to succeed.
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
17849e37000SMatthias Springer struct DimOpInterface 17949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 18049e37000SMatthias Springer tensor::DimOp> { 18149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1829597b16aSMatthias Springer const AnalysisState &state) const { 18349e37000SMatthias Springer return true; 18449e37000SMatthias Springer } 18549e37000SMatthias Springer 18649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1879597b16aSMatthias Springer const AnalysisState &state) const { 18849e37000SMatthias Springer return false; 18949e37000SMatthias Springer } 19049e37000SMatthias Springer 1919597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1929597b16aSMatthias Springer const AnalysisState &state) const { 193585a8a32SMatthias Springer return {}; 19449e37000SMatthias Springer } 19549e37000SMatthias Springer 19649e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 197b55d55ecSMatthias Springer const BufferizationOptions &options) const { 19849e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 1995d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options); 2005d50f51cSMatthias Springer if (failed(v)) 2015d50f51cSMatthias Springer return failure(); 2025d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v, 2035d50f51cSMatthias Springer dimOp.index()); 20449e37000SMatthias Springer return success(); 20549e37000SMatthias Springer } 20649e37000SMatthias Springer }; 20749e37000SMatthias Springer 208e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape. 
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  // Reshapes are metadata-only: no tensor contents are read...
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  // ...and none are written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  // The result aliases the "src" operand only.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  // The slice itself does not read the tensor contents.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  // No memory is written by taking a subview.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  // The result aliases the "source" operand only.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  // The result is a (possibly rank-reduced) view into the source buffer, not
  // the full buffer, so the relation is not "Equivalent".
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          // Dynamic dims are materialized with memref.dim; static dims become
          // index attributes.
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
32149e37000SMatthias Springer struct ExtractOpInterface 32249e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 32349e37000SMatthias Springer tensor::ExtractOp> { 32449e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 3259597b16aSMatthias Springer const AnalysisState &state) const { 32649e37000SMatthias Springer return true; 32749e37000SMatthias Springer } 32849e37000SMatthias Springer 32949e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3309597b16aSMatthias Springer const AnalysisState &state) const { 33149e37000SMatthias Springer return false; 33249e37000SMatthias Springer } 33349e37000SMatthias Springer 3349597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 3359597b16aSMatthias Springer const AnalysisState &state) const { 336585a8a32SMatthias Springer return {}; 33749e37000SMatthias Springer } 33849e37000SMatthias Springer 33949e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 340b55d55ecSMatthias Springer const BufferizationOptions &options) const { 34149e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 3425d50f51cSMatthias Springer FailureOr<Value> srcMemref = 3435d50f51cSMatthias Springer getBuffer(rewriter, extractOp.getTensor(), options); 3445d50f51cSMatthias Springer if (failed(srcMemref)) 3455d50f51cSMatthias Springer return failure(); 3465d50f51cSMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref, 3475d50f51cSMatthias Springer extractOp.indices()); 34849e37000SMatthias Springer return success(); 34949e37000SMatthias Springer } 35049e37000SMatthias Springer }; 35149e37000SMatthias Springer 352d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 353d581c94dSMatthias Springer // iterating over op.elements(). 
// Recursively walks all index tuples of `shape` in row-major order, emitting
// one memref.store per element of the op. `constants` caches index constants
// [0, max(shape_i)), `elementIt` advances through the flat element list, and
// `indices` holds the index tuple currently being built (one slot per dim).
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    // Innermost dimension: emit the stores.
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  // Outer dimension: recurse into the next dimension for each index value.
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    // No copy: the buffer is fully overwritten by the stores below.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    // No copy: every element is written by the parallel loop below.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      // Dynamic dims consume the op's dynamic extents in order; static dims
      // become constant bounds.
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  // The insertion reads the destination tensor (all untouched elements are
  // preserved).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  // One element of the destination is overwritten.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  // Only the "dest" operand aliases the result.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  // Reverse mapping: the result aliases the "dest" operand.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
5559597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state, 55649e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 55749e37000SMatthias Springer if (!st || !sti) 55849e37000SMatthias Springer return false; 55949e37000SMatthias Springer if (sti != sti && 5608df54a6aSJacques Pienaar !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest())) 56149e37000SMatthias Springer return false; 56249e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 56349e37000SMatthias Springer return false; 56449e37000SMatthias Springer return true; 56549e37000SMatthias Springer } 56649e37000SMatthias Springer 56749e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 56849e37000SMatthias Springer /// the given InsertSliceOp. 5699597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 5709597b16aSMatthias Springer InsertSliceOp insertOp) { 57149e37000SMatthias Springer auto condition = [&](Value val) { 57249e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 57349e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 57449e37000SMatthias Springer return true; 57549e37000SMatthias Springer return false; 57649e37000SMatthias Springer }; 57749e37000SMatthias Springer 57849e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 57949e37000SMatthias Springer condition); 58049e37000SMatthias Springer } 58149e37000SMatthias Springer 58249e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 58349e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 
58449e37000SMatthias Springer struct InsertSliceOpInterface 58549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 58649e37000SMatthias Springer tensor::InsertSliceOp> { 58749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 5889597b16aSMatthias Springer const AnalysisState &state) const { 58949e37000SMatthias Springer return true; 59049e37000SMatthias Springer } 59149e37000SMatthias Springer 59249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 5939597b16aSMatthias Springer const AnalysisState &state) const { 59449e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 59549e37000SMatthias Springer } 59649e37000SMatthias Springer 5979597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 5989597b16aSMatthias Springer const AnalysisState &state) const { 599585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(1) /*dest*/) 600585a8a32SMatthias Springer return {op->getResult(0)}; 601585a8a32SMatthias Springer return {}; 60249e37000SMatthias Springer } 60349e37000SMatthias Springer 60449e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 6059597b16aSMatthias Springer const AnalysisState &state) const { 60649e37000SMatthias Springer return BufferRelation::Equivalent; 60749e37000SMatthias Springer } 60849e37000SMatthias Springer 60949e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 61049e37000SMatthias Springer OpOperand *uConflictingWrite, 6119597b16aSMatthias Springer const AnalysisState &state) const { 61249e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 61349e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 61449e37000SMatthias Springer 61549e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. 
If 61649e37000SMatthias Springer // uRead is an InsertSliceOp... 61749e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 61849e37000SMatthias Springer // As an example, consider the following IR. 61949e37000SMatthias Springer // 62049e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 62149e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 62249e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 62349e37000SMatthias Springer // {inplace= [true] } 62449e37000SMatthias Springer 62549e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 62649e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 62749e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 62849e37000SMatthias Springer insertSliceOp)) 62949e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 63049e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 63149e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 63249e37000SMatthias Springer // being read by uRead, this is not a conflict. 63349e37000SMatthias Springer // 63449e37000SMatthias Springer // In the above example: 63549e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 63649e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 63749e37000SMatthias Springer // 63849e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 63949e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 64049e37000SMatthias Springer // exactly the one that is *not* read via %t. 
64149e37000SMatthias Springer return true; 64249e37000SMatthias Springer 64349e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 64449e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 64549e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 64649e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 64749e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 64849e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 64949e37000SMatthias Springer // InsertSliceOp is writing. 65049e37000SMatthias Springer // 65149e37000SMatthias Springer // In the above example: 65249e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 65349e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 65449e37000SMatthias Springer return true; 65549e37000SMatthias Springer } 65649e37000SMatthias Springer 65749e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 65849e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 65949e37000SMatthias Springer // As an example, consider the following IR. 
66049e37000SMatthias Springer // 66149e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 66249e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 66349e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 66449e37000SMatthias Springer // {inplace= [true] } 66549e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 66649e37000SMatthias Springer // 66749e37000SMatthias Springer // In the above example: 66849e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 66949e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 67049e37000SMatthias Springer // lastWrite = %1 67149e37000SMatthias Springer // 67249e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 67349e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 67449e37000SMatthias Springer // is no memory write here.) 67549e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 67649e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 6778df54a6aSJacques Pienaar insertSliceOp.getSource()) && 6788df54a6aSJacques Pienaar hasMatchingExtractSliceOp(state, insertSliceOp.getSource(), 67949e37000SMatthias Springer insertSliceOp)) 68049e37000SMatthias Springer return true; 68149e37000SMatthias Springer 68249e37000SMatthias Springer return false; 68349e37000SMatthias Springer } 68449e37000SMatthias Springer 68549e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 686b55d55ecSMatthias Springer const BufferizationOptions &options) const { 68749e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 68849e37000SMatthias Springer // generally a deal breaker. 
When used with loops, this ends up cloning the 68949e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 69049e37000SMatthias Springer // catastrophically bad scheduling decision. 69149e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 69249e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 69349e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 6945d50f51cSMatthias Springer FailureOr<Value> dstMemref = 6955d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getDest(), options); 6965d50f51cSMatthias Springer if (failed(dstMemref)) 6975d50f51cSMatthias Springer return failure(); 69849e37000SMatthias Springer 69949e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 70049e37000SMatthias Springer // rank-reducing case. 70149e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 70249e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 70349e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 70449e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 7055d50f51cSMatthias Springer *dstMemref, mixedOffsets, mixedSizes, mixedStrides, 70649e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 70749e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 70849e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 70949e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 71049e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 71149e37000SMatthias Springer }); 71249e37000SMatthias Springer // Take a subview of the dst. 
7135d50f51cSMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); 71449e37000SMatthias Springer auto subviewMemRefType = 71549e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 71649e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 71749e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 71849e37000SMatthias Springer .cast<MemRefType>(); 71949e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 7205d50f51cSMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 72149e37000SMatthias Springer mixedStrides); 72249e37000SMatthias Springer 72349e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 72449e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 7255d50f51cSMatthias Springer FailureOr<Value> srcMemref = 7265d50f51cSMatthias Springer getBuffer(rewriter, insertSliceOp.getSource(), options); 7275d50f51cSMatthias Springer if (failed(srcMemref)) 7285d50f51cSMatthias Springer return failure(); 7295d50f51cSMatthias Springer if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView))) 73049e37000SMatthias Springer return failure(); 73149e37000SMatthias Springer 7325d50f51cSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 73349e37000SMatthias Springer return success(); 73449e37000SMatthias Springer } 73549e37000SMatthias Springer }; 73649e37000SMatthias Springer 737fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 
738fc08d1c2SMatthias Springer struct RankOpInterface 739fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 740fc08d1c2SMatthias Springer tensor::RankOp> { 741fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 7429597b16aSMatthias Springer const AnalysisState &state) const { 743fc08d1c2SMatthias Springer return true; 744fc08d1c2SMatthias Springer } 745fc08d1c2SMatthias Springer 746fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 7479597b16aSMatthias Springer const AnalysisState &state) const { 748fc08d1c2SMatthias Springer return false; 749fc08d1c2SMatthias Springer } 750fc08d1c2SMatthias Springer 7519597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 7529597b16aSMatthias Springer const AnalysisState &state) const { 753585a8a32SMatthias Springer return {}; 754fc08d1c2SMatthias Springer } 755fc08d1c2SMatthias Springer 756fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 757b55d55ecSMatthias Springer const BufferizationOptions &options) const { 758fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 7595d50f51cSMatthias Springer FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options); 7605d50f51cSMatthias Springer if (failed(v)) 7615d50f51cSMatthias Springer return failure(); 762fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 7635d50f51cSMatthias Springer *v); 764fc08d1c2SMatthias Springer return success(); 765fc08d1c2SMatthias Springer } 766fc08d1c2SMatthias Springer }; 767fc08d1c2SMatthias Springer 768e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape. 
769e287d647SAshay Rane struct ReshapeOpInterface 770e287d647SAshay Rane : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface, 771e287d647SAshay Rane tensor::ReshapeOp> { 772e287d647SAshay Rane bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 773e287d647SAshay Rane const AnalysisState &state) const { 774e287d647SAshay Rane if (&opOperand == &op->getOpOperand(1) /* shape */) 775e287d647SAshay Rane return true; 776e287d647SAshay Rane return false; 777e287d647SAshay Rane } 778e287d647SAshay Rane 779e287d647SAshay Rane bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 780e287d647SAshay Rane const AnalysisState &state) const { 781e287d647SAshay Rane return false; 782e287d647SAshay Rane } 783e287d647SAshay Rane 784e287d647SAshay Rane SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 785e287d647SAshay Rane const AnalysisState &state) const { 786e287d647SAshay Rane return {op->getOpResult(0)}; 787e287d647SAshay Rane } 788e287d647SAshay Rane 789e287d647SAshay Rane BufferRelation bufferRelation(Operation *op, OpResult opResult, 790e287d647SAshay Rane const AnalysisState &state) const { 791e287d647SAshay Rane return BufferRelation::Equivalent; 792e287d647SAshay Rane } 793e287d647SAshay Rane 794e287d647SAshay Rane LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 795b55d55ecSMatthias Springer const BufferizationOptions &options) const { 796e287d647SAshay Rane auto reshapeOp = cast<tensor::ReshapeOp>(op); 7975d50f51cSMatthias Springer FailureOr<Value> srcBuffer = 7985d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getSource(), options); 7995d50f51cSMatthias Springer FailureOr<Value> shapeBuffer = 8005d50f51cSMatthias Springer getBuffer(rewriter, reshapeOp.getShape(), options); 8015d50f51cSMatthias Springer if (failed(srcBuffer) || failed(shapeBuffer)) 8025d50f51cSMatthias Springer return failure(); 803e287d647SAshay Rane auto resultTensorType = 
reshapeOp.getResult().getType().cast<TensorType>(); 804c0b0b6a0SMatthias Springer auto resultMemRefType = getMemRefType( 805c0b0b6a0SMatthias Springer resultTensorType, options, /*layout=*/{}, 806c0b0b6a0SMatthias Springer srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt()); 807e287d647SAshay Rane replaceOpWithNewBufferizedOp<memref::ReshapeOp>( 8085d50f51cSMatthias Springer rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); 809e287d647SAshay Rane return success(); 810e287d647SAshay Rane } 811e287d647SAshay Rane }; 812e287d647SAshay Rane 813*7fbf55c9SNicolas Vasilache /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e. 814*7fbf55c9SNicolas Vasilache /// equivalent operand / result and same offset/sizes/strides specification). 815*7fbf55c9SNicolas Vasilache static bool areEquivalentExtractSliceOps(const AnalysisState &state, 816*7fbf55c9SNicolas Vasilache ExtractSliceOp st, 817*7fbf55c9SNicolas Vasilache ParallelInsertSliceOp sti) { 818*7fbf55c9SNicolas Vasilache if (!st || !sti) 819*7fbf55c9SNicolas Vasilache return false; 820*7fbf55c9SNicolas Vasilache if (st != sti && 821*7fbf55c9SNicolas Vasilache !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest())) 822*7fbf55c9SNicolas Vasilache return false; 823*7fbf55c9SNicolas Vasilache if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 824*7fbf55c9SNicolas Vasilache return false; 825*7fbf55c9SNicolas Vasilache return true; 826*7fbf55c9SNicolas Vasilache } 827*7fbf55c9SNicolas Vasilache 828*7fbf55c9SNicolas Vasilache /// Return true if `value` is originating from an ExtractSliceOp that matches 829*7fbf55c9SNicolas Vasilache /// the given InsertSliceOp. 
830*7fbf55c9SNicolas Vasilache static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 831*7fbf55c9SNicolas Vasilache ParallelInsertSliceOp insertOp) { 832*7fbf55c9SNicolas Vasilache auto condition = [&](Value val) { 833*7fbf55c9SNicolas Vasilache if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 834*7fbf55c9SNicolas Vasilache if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 835*7fbf55c9SNicolas Vasilache return true; 836*7fbf55c9SNicolas Vasilache return false; 837*7fbf55c9SNicolas Vasilache }; 838*7fbf55c9SNicolas Vasilache 839*7fbf55c9SNicolas Vasilache return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 840*7fbf55c9SNicolas Vasilache condition); 841*7fbf55c9SNicolas Vasilache } 842*7fbf55c9SNicolas Vasilache 843*7fbf55c9SNicolas Vasilache /// Analysis of ParallelInsertSliceOp. 844*7fbf55c9SNicolas Vasilache struct ParallelInsertSliceOpInterface 845*7fbf55c9SNicolas Vasilache : public BufferizableOpInterface::ExternalModel< 846*7fbf55c9SNicolas Vasilache ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { 847*7fbf55c9SNicolas Vasilache SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 848*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 849*7fbf55c9SNicolas Vasilache if (&opOperand != &op->getOpOperand(1) /*dest*/) 850*7fbf55c9SNicolas Vasilache return {}; 851*7fbf55c9SNicolas Vasilache 852*7fbf55c9SNicolas Vasilache // ParallelInsertSliceOp itself has no results, query its tied op results. 
853*7fbf55c9SNicolas Vasilache auto insertOp = cast<ParallelInsertSliceOp>(op); 854*7fbf55c9SNicolas Vasilache return {insertOp.getTiedOpResult()}; 855*7fbf55c9SNicolas Vasilache } 856*7fbf55c9SNicolas Vasilache 857*7fbf55c9SNicolas Vasilache bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 858*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 859*7fbf55c9SNicolas Vasilache return true; 860*7fbf55c9SNicolas Vasilache } 861*7fbf55c9SNicolas Vasilache 862*7fbf55c9SNicolas Vasilache bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 863*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 864*7fbf55c9SNicolas Vasilache return &opOperand == &op->getOpOperand(1) /*dest*/; 865*7fbf55c9SNicolas Vasilache } 866*7fbf55c9SNicolas Vasilache 867*7fbf55c9SNicolas Vasilache BufferRelation bufferRelation(Operation *op, OpResult opResult, 868*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 869*7fbf55c9SNicolas Vasilache return BufferRelation::Equivalent; 870*7fbf55c9SNicolas Vasilache } 871*7fbf55c9SNicolas Vasilache 872*7fbf55c9SNicolas Vasilache LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 873*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 874*7fbf55c9SNicolas Vasilache // This interface method is overridden because we want to set a custom 875*7fbf55c9SNicolas Vasilache // insertion point for tensor copies. They should be inserted right before 876*7fbf55c9SNicolas Vasilache // the ForeachThreadOp. E.g.: 877*7fbf55c9SNicolas Vasilache // 878*7fbf55c9SNicolas Vasilache // %r0, %r1 = foreach_thead ... { 879*7fbf55c9SNicolas Vasilache // ... 880*7fbf55c9SNicolas Vasilache // perform_concurrently { 881*7fbf55c9SNicolas Vasilache // parallel_insert_slice %a into %b ... {inplace = ["true", "true"]} 882*7fbf55c9SNicolas Vasilache // parallel_insert_slice %c into %d ... 
{inplace = ["true", "false"]} 883*7fbf55c9SNicolas Vasilache // } 884*7fbf55c9SNicolas Vasilache // } 885*7fbf55c9SNicolas Vasilache // 886*7fbf55c9SNicolas Vasilache // After TensorCopyInsertion: 887*7fbf55c9SNicolas Vasilache // 888*7fbf55c9SNicolas Vasilache // %copy = bufferization.alloc_tensor() copy(%d) 889*7fbf55c9SNicolas Vasilache // %r0, %r1 = foreach_thead ... { 890*7fbf55c9SNicolas Vasilache // ... 891*7fbf55c9SNicolas Vasilache // perform_concurrently { 892*7fbf55c9SNicolas Vasilache // parallel_insert_slice %a into %b ... 893*7fbf55c9SNicolas Vasilache // parallel_insert_slice %c into %copy ... 894*7fbf55c9SNicolas Vasilache // } 895*7fbf55c9SNicolas Vasilache // } 896*7fbf55c9SNicolas Vasilache 897*7fbf55c9SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 898*7fbf55c9SNicolas Vasilache auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op); 899*7fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent = 900*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getParallelCombiningParent(); 901*7fbf55c9SNicolas Vasilache Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); 902*7fbf55c9SNicolas Vasilache 903*7fbf55c9SNicolas Vasilache // Nothing to do if the destination tensor is inplace. 904*7fbf55c9SNicolas Vasilache assert(state.isInPlace(op->getOpOperand(0) /*src*/) && 905*7fbf55c9SNicolas Vasilache "source is always in-place"); 906*7fbf55c9SNicolas Vasilache if (state.isInPlace(op->getOpOperand(1) /*dest*/)) 907*7fbf55c9SNicolas Vasilache return success(); 908*7fbf55c9SNicolas Vasilache 909*7fbf55c9SNicolas Vasilache // Find corresponding OpResult. 910*7fbf55c9SNicolas Vasilache OpResult opResult = parallelInsertSliceOp.getTiedOpResult(); 911*7fbf55c9SNicolas Vasilache 912*7fbf55c9SNicolas Vasilache // Insert tensor allocation right before the ForeachThreadOp. 
913*7fbf55c9SNicolas Vasilache rewriter.setInsertionPoint(parallelIteratingOp); 914*7fbf55c9SNicolas Vasilache bool isYielded = state.isTensorYielded(opResult); 915*7fbf55c9SNicolas Vasilache FailureOr<Value> alloc = allocateTensorForShapedValue( 916*7fbf55c9SNicolas Vasilache rewriter, op->getLoc(), parallelInsertSliceOp.getDest(), 917*7fbf55c9SNicolas Vasilache /*escape=*/isYielded, state.getOptions()); 918*7fbf55c9SNicolas Vasilache if (failed(alloc)) 919*7fbf55c9SNicolas Vasilache return failure(); 920*7fbf55c9SNicolas Vasilache 921*7fbf55c9SNicolas Vasilache // Update destination operand. 922*7fbf55c9SNicolas Vasilache rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() { 923*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getDestMutable().assign(*alloc); 924*7fbf55c9SNicolas Vasilache }); 925*7fbf55c9SNicolas Vasilache 926*7fbf55c9SNicolas Vasilache return success(); 927*7fbf55c9SNicolas Vasilache } 928*7fbf55c9SNicolas Vasilache 929*7fbf55c9SNicolas Vasilache LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 930*7fbf55c9SNicolas Vasilache const BufferizationOptions &options) const { 931*7fbf55c9SNicolas Vasilache OpBuilder::InsertionGuard g(rewriter); 932*7fbf55c9SNicolas Vasilache auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op); 933*7fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent = 934*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getParallelCombiningParent(); 935*7fbf55c9SNicolas Vasilache Operation *parallelIteratingOp = parallelCombiningParent->getParentOp(); 936*7fbf55c9SNicolas Vasilache 937*7fbf55c9SNicolas Vasilache // Get destination buffer. 
938*7fbf55c9SNicolas Vasilache FailureOr<Value> destBuffer = 939*7fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); 940*7fbf55c9SNicolas Vasilache if (failed(destBuffer)) 941*7fbf55c9SNicolas Vasilache return failure(); 942*7fbf55c9SNicolas Vasilache 943*7fbf55c9SNicolas Vasilache // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`. 944*7fbf55c9SNicolas Vasilache rewriter.setInsertionPoint(parallelCombiningParent); 945*7fbf55c9SNicolas Vasilache FailureOr<Value> srcBuffer = 946*7fbf55c9SNicolas Vasilache getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); 947*7fbf55c9SNicolas Vasilache if (failed(srcBuffer)) 948*7fbf55c9SNicolas Vasilache return failure(); 949*7fbf55c9SNicolas Vasilache Value subview = rewriter.create<memref::SubViewOp>( 950*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getLoc(), *destBuffer, 951*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedOffsets(), 952*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedSizes(), 953*7fbf55c9SNicolas Vasilache parallelInsertSliceOp.getMixedStrides()); 954*7fbf55c9SNicolas Vasilache // This memcpy will fold away if everything bufferizes in-place. 955*7fbf55c9SNicolas Vasilache if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), 956*7fbf55c9SNicolas Vasilache *srcBuffer, subview))) 957*7fbf55c9SNicolas Vasilache return failure(); 958*7fbf55c9SNicolas Vasilache 959*7fbf55c9SNicolas Vasilache // Replace all uses of parallelIteratingOp (just the corresponding result). 960*7fbf55c9SNicolas Vasilache rewriter.setInsertionPointAfter(parallelIteratingOp); 961*7fbf55c9SNicolas Vasilache Value toTensorOp = 962*7fbf55c9SNicolas Vasilache rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer); 963*7fbf55c9SNicolas Vasilache // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps. 
964*7fbf55c9SNicolas Vasilache SmallVector<OpOperand *> resultUses = llvm::to_vector( 965*7fbf55c9SNicolas Vasilache llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(), 966*7fbf55c9SNicolas Vasilache [](OpOperand &use) { return &use; })); 967*7fbf55c9SNicolas Vasilache for (OpOperand *use : resultUses) { 968*7fbf55c9SNicolas Vasilache rewriter.updateRootInPlace(use->getOwner(), 969*7fbf55c9SNicolas Vasilache [&]() { use->set(toTensorOp); }); 970*7fbf55c9SNicolas Vasilache } 971*7fbf55c9SNicolas Vasilache rewriter.eraseOp(op); 972*7fbf55c9SNicolas Vasilache return success(); 973*7fbf55c9SNicolas Vasilache } 974*7fbf55c9SNicolas Vasilache 975*7fbf55c9SNicolas Vasilache // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share 976*7fbf55c9SNicolas Vasilache // the code. 977*7fbf55c9SNicolas Vasilache bool isNotConflicting(Operation *op, OpOperand *uRead, 978*7fbf55c9SNicolas Vasilache OpOperand *uConflictingWrite, 979*7fbf55c9SNicolas Vasilache const AnalysisState &state) const { 980*7fbf55c9SNicolas Vasilache Operation *readingOp = uRead->getOwner(); 981*7fbf55c9SNicolas Vasilache Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 982*7fbf55c9SNicolas Vasilache 983*7fbf55c9SNicolas Vasilache // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 984*7fbf55c9SNicolas Vasilache // uRead is an InsertSliceOp... 985*7fbf55c9SNicolas Vasilache if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) { 986*7fbf55c9SNicolas Vasilache // As an example, consider the following IR. 
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //        {inplace= [true] }

      // NOTE(review): `insertSliceOp` is bound by an enclosing dyn_cast that
      // begins above this excerpt — presumably the insert_slice-like op whose
      // operands are inspected here; confirm against the full function.
      // Operand indices are hard-coded: 0 = source, 1 = dest.
      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is the dest of a ParallelInsertSliceOp, the write
    // may effectively be a no-op (overwriting the read data with identical
    // data); this is checked by the condition that follows.
    // (The old comment said "InsertSliceOp", but the dyn_cast below is to
    // ParallelInsertSliceOp.)
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //        {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      // The write is a no-op if (a) it targets the dest operand, (b) the read
      // value and the inserted source are equivalent after bufferization, and
      // (c) the source stems from a matching tensor.extract_slice of the same
      // dest (see the example above).
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    // None of the special cases applied: report a genuine conflict.
    return false;
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

/// Register external models implementing BufferizableOpInterface for the
/// tensor dialect ops listed below. The registration is installed as a
/// dialect extension on \p registry and attaches the interfaces once the
/// tensor dialect is available in the context.
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  // The unary '+' converts the capture-less lambda to a plain function
  // pointer, as expected by addExtension.
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}