149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 10eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1149e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12b3ebe3beSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1349e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1471bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1549e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1649e37000SMatthias Springer #include "mlir/IR/Dialect.h" 1749e37000SMatthias Springer #include "mlir/IR/Operation.h" 1849e37000SMatthias Springer 1949e37000SMatthias Springer using namespace mlir; 2049e37000SMatthias Springer using namespace mlir::bufferization; 2149e37000SMatthias Springer using namespace mlir::tensor; 2249e37000SMatthias Springer 2349e37000SMatthias Springer namespace mlir { 2449e37000SMatthias Springer namespace tensor { 2549e37000SMatthias Springer namespace { 2649e37000SMatthias Springer 2749e37000SMatthias Springer struct CastOpInterface 2849e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 2949e37000SMatthias Springer tensor::CastOp> { 3049e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 319597b16aSMatthias Springer const AnalysisState &state) const { 3249e37000SMatthias Springer return false; 3349e37000SMatthias Springer } 3449e37000SMatthias Springer 3549e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 369597b16aSMatthias Springer const AnalysisState &state) const { 3749e37000SMatthias Springer return false; 3849e37000SMatthias Springer } 3949e37000SMatthias Springer 409597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 419597b16aSMatthias Springer const AnalysisState &state) const { 42585a8a32SMatthias Springer return {op->getResult(0)}; 4349e37000SMatthias Springer } 4449e37000SMatthias Springer 4549e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 469597b16aSMatthias Springer const AnalysisState &state) const { 4749e37000SMatthias Springer return BufferRelation::Equivalent; 4849e37000SMatthias Springer } 4949e37000SMatthias Springer 5049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 51b55d55ecSMatthias Springer const BufferizationOptions &options) const { 5249e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 5349e37000SMatthias Springer 5449e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 55b55d55ecSMatthias Springer Value resultBuffer = getBuffer(rewriter, castOp.source(), options); 56b3ebe3beSMatthias Springer auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>(); 5749e37000SMatthias Springer Attribute memorySpace = sourceMemRefType.getMemorySpace(); 5849e37000SMatthias Springer TensorType resultTensorType = 5949e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>(); 6049e37000SMatthias Springer MemRefLayoutAttrInterface layout; 6149e37000SMatthias Springer 6249e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) 6349e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) 6449e37000SMatthias Springer layout = rankedMemRefType.getLayout(); 6549e37000SMatthias Springer 6649e37000SMatthias Springer // Compute the new memref type. 67b55d55ecSMatthias Springer Type resultMemRefType = 68b55d55ecSMatthias Springer getMemRefType(resultTensorType, options, layout, memorySpace); 6949e37000SMatthias Springer 7049e37000SMatthias Springer // Replace the op with a memref.cast. 71b3ebe3beSMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer.getType(), 7249e37000SMatthias Springer resultMemRefType) && 7349e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 7449e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, 75b3ebe3beSMatthias Springer resultBuffer); 7649e37000SMatthias Springer 7749e37000SMatthias Springer return success(); 7849e37000SMatthias Springer } 7949e37000SMatthias Springer }; 8049e37000SMatthias Springer 81e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape. 82e6f69161SMatthias Springer struct CollapseShapeOpInterface 83e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface, 84e6f69161SMatthias Springer tensor::CollapseShapeOp> { 85e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 869597b16aSMatthias Springer const AnalysisState &state) const { 87e6f69161SMatthias Springer return false; 88e6f69161SMatthias Springer } 89e6f69161SMatthias Springer 90e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 919597b16aSMatthias Springer const AnalysisState &state) const { 92e6f69161SMatthias Springer return false; 93e6f69161SMatthias Springer } 94e6f69161SMatthias Springer 959597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 969597b16aSMatthias Springer const AnalysisState &state) const { 97e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 98e6f69161SMatthias Springer return {op->getOpResult(0)}; 99e6f69161SMatthias Springer return {}; 100e6f69161SMatthias Springer } 101e6f69161SMatthias Springer 102e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1039597b16aSMatthias Springer const AnalysisState &state) const { 104e6f69161SMatthias Springer return BufferRelation::Equivalent; 105e6f69161SMatthias Springer } 106e6f69161SMatthias Springer 107e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 108b55d55ecSMatthias Springer const BufferizationOptions &options) const { 109e6f69161SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); 11051df6238SMatthias Springer RankedTensorType tensorResultType = collapseShapeOp.getResultType(); 111b55d55ecSMatthias Springer Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options); 112b3ebe3beSMatthias Springer auto bufferType = buffer.getType().cast<MemRefType>(); 11351df6238SMatthias Springer 11451df6238SMatthias Springer if (tensorResultType.getRank() == 0) { 11551df6238SMatthias Springer // 0-d collapses must go through a different op builder. 11673c0333dSMatthias Springer MemRefType resultType; 11773c0333dSMatthias Springer 11873c0333dSMatthias Springer if (bufferType.getLayout().isIdentity()) { 11973c0333dSMatthias Springer // Standard layout: result type has no offset. 12051df6238SMatthias Springer MemRefLayoutAttrInterface layout; 12173c0333dSMatthias Springer resultType = MemRefType::get({}, tensorResultType.getElementType(), 12251df6238SMatthias Springer layout, bufferType.getMemorySpace()); 12373c0333dSMatthias Springer } else { 12473c0333dSMatthias Springer // Source memref has a layout map: result type has the same offset as 12573c0333dSMatthias Springer // the source type. 12673c0333dSMatthias Springer SmallVector<int64_t> strides; 12773c0333dSMatthias Springer int64_t offset; 12873c0333dSMatthias Springer if (failed(getStridesAndOffset(bufferType, strides, offset))) 12973c0333dSMatthias Springer return failure(); 13073c0333dSMatthias Springer AffineMap resultLayout = 13173c0333dSMatthias Springer makeStridedLinearLayoutMap({}, offset, op->getContext()); 13273c0333dSMatthias Springer resultType = 13373c0333dSMatthias Springer MemRefType::get({}, tensorResultType.getElementType(), resultLayout, 13473c0333dSMatthias Springer bufferType.getMemorySpaceAsInt()); 13573c0333dSMatthias Springer } 13673c0333dSMatthias Springer 137e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 138b3ebe3beSMatthias Springer rewriter, op, resultType, buffer, collapseShapeOp.reassociation()); 139e6f69161SMatthias Springer return success(); 140e6f69161SMatthias Springer } 14151df6238SMatthias Springer 142d7a9bf91SMatthias Springer // If the dims are not collapsible (due to an incompatible source layout 143d7a9bf91SMatthias Springer // map), force an out-of-place bufferization, i.e., a buffer copy. This 144d7a9bf91SMatthias Springer // newly allocated buffer will have no layout map and thus be collapsible. 145a74e5a89SAdrian Kuegel bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible( 146d7a9bf91SMatthias Springer bufferType, collapseShapeOp.getReassociationIndices()); 147b3ebe3beSMatthias Springer if (!canBeCollapsed) { 148b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion. 149b55d55ecSMatthias Springer AnalysisState analysisState(options); 150b3ebe3beSMatthias Springer Value tensorAlloc = allocateTensorForShapedValue( 151b3ebe3beSMatthias Springer rewriter, op->getLoc(), collapseShapeOp.src(), 152b3ebe3beSMatthias Springer analysisState.isTensorYielded(collapseShapeOp.result())); 153b3ebe3beSMatthias Springer auto memrefType = 154b3ebe3beSMatthias Springer MemRefType::get(collapseShapeOp.getSrcType().getShape(), 155b3ebe3beSMatthias Springer collapseShapeOp.getSrcType().getElementType(), 156b3ebe3beSMatthias Springer AffineMap(), bufferType.getMemorySpaceAsInt()); 157b3ebe3beSMatthias Springer buffer = rewriter.create<bufferization::ToMemrefOp>( 158b3ebe3beSMatthias Springer op->getLoc(), memrefType, tensorAlloc); 159b3ebe3beSMatthias Springer } 160d7a9bf91SMatthias Springer 16151df6238SMatthias Springer // Result type is inferred by the builder. 16251df6238SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 163b3ebe3beSMatthias Springer rewriter, op, buffer, collapseShapeOp.getReassociationIndices()); 16451df6238SMatthias Springer return success(); 16551df6238SMatthias Springer } 166e6f69161SMatthias Springer }; 167e6f69161SMatthias Springer 16849e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 16949e37000SMatthias Springer struct DimOpInterface 17049e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 17149e37000SMatthias Springer tensor::DimOp> { 17249e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1739597b16aSMatthias Springer const AnalysisState &state) const { 17449e37000SMatthias Springer return true; 17549e37000SMatthias Springer } 17649e37000SMatthias Springer 17749e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1789597b16aSMatthias Springer const AnalysisState &state) const { 17949e37000SMatthias Springer return false; 18049e37000SMatthias Springer } 18149e37000SMatthias Springer 1829597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1839597b16aSMatthias Springer const AnalysisState &state) const { 184585a8a32SMatthias Springer return {}; 18549e37000SMatthias Springer } 18649e37000SMatthias Springer 18749e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 188b55d55ecSMatthias Springer const BufferizationOptions &options) const { 18949e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 190b55d55ecSMatthias Springer auto v = getBuffer(rewriter, dimOp.source(), options); 191b3ebe3beSMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); 19249e37000SMatthias Springer return success(); 19349e37000SMatthias Springer } 19449e37000SMatthias Springer }; 19549e37000SMatthias Springer 196e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape. 197e6f69161SMatthias Springer struct ExpandShapeOpInterface 198e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface, 199e6f69161SMatthias Springer tensor::ExpandShapeOp> { 200e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2019597b16aSMatthias Springer const AnalysisState &state) const { 202e6f69161SMatthias Springer return false; 203e6f69161SMatthias Springer } 204e6f69161SMatthias Springer 205e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2069597b16aSMatthias Springer const AnalysisState &state) const { 207e6f69161SMatthias Springer return false; 208e6f69161SMatthias Springer } 209e6f69161SMatthias Springer 2109597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2119597b16aSMatthias Springer const AnalysisState &state) const { 212e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 213e6f69161SMatthias Springer return {op->getOpResult(0)}; 214e6f69161SMatthias Springer return {}; 215e6f69161SMatthias Springer } 216e6f69161SMatthias Springer 217e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2189597b16aSMatthias Springer const AnalysisState &state) const { 219e6f69161SMatthias Springer return BufferRelation::Equivalent; 220e6f69161SMatthias Springer } 221e6f69161SMatthias Springer 222e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 223b55d55ecSMatthias Springer const BufferizationOptions &options) const { 224e6f69161SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 22551df6238SMatthias Springer auto tensorResultType = expandShapeOp.getResultType(); 226b55d55ecSMatthias Springer auto buffer = getBuffer(rewriter, expandShapeOp.src(), options); 22751df6238SMatthias Springer 22851df6238SMatthias Springer // Memref result type is inferred by the builder based on reassociation 22951df6238SMatthias Springer // indices and result shape. 230e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>( 231b3ebe3beSMatthias Springer rewriter, op, tensorResultType.getShape(), buffer, 23251df6238SMatthias Springer expandShapeOp.getReassociationIndices()); 233e6f69161SMatthias Springer return success(); 234e6f69161SMatthias Springer } 235e6f69161SMatthias Springer }; 236e6f69161SMatthias Springer 23749e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 23849e37000SMatthias Springer struct ExtractSliceOpInterface 23949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 24049e37000SMatthias Springer tensor::ExtractSliceOp> { 24149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2429597b16aSMatthias Springer const AnalysisState &state) const { 24349e37000SMatthias Springer return false; 24449e37000SMatthias Springer } 24549e37000SMatthias Springer 24649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2479597b16aSMatthias Springer const AnalysisState &state) const { 24849e37000SMatthias Springer return false; 24949e37000SMatthias Springer } 25049e37000SMatthias Springer 2519597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2529597b16aSMatthias Springer const AnalysisState &state) const { 253585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*source*/) 254585a8a32SMatthias Springer return {op->getOpResult(0)}; 255585a8a32SMatthias Springer return {}; 25649e37000SMatthias Springer } 25749e37000SMatthias Springer 25849e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2599597b16aSMatthias Springer const AnalysisState &state) const { 26049e37000SMatthias Springer return BufferRelation::None; 26149e37000SMatthias Springer } 26249e37000SMatthias Springer 26349e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 264b55d55ecSMatthias Springer const BufferizationOptions &options) const { 26549e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 26649e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 267d7a9bf91SMatthias Springer 268d7a9bf91SMatthias Springer // Even if this op was decided to bufferize out-of-place, do not insert the 269d7a9bf91SMatthias Springer // buffer copy yet. This is done later in this function. 270b55d55ecSMatthias Springer auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options); 271b3ebe3beSMatthias Springer auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); 27249e37000SMatthias Springer auto dstTensorType = 27349e37000SMatthias Springer extractSliceOp.result().getType().cast<RankedTensorType>(); 27449e37000SMatthias Springer 27549e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 27649e37000SMatthias Springer // rank-reducing case. 27749e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 27849e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 27949e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 28049e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 281b3ebe3beSMatthias Springer srcMemref, mixedOffsets, mixedSizes, mixedStrides, 28249e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 28349e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 28449e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 28549e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 28649e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 28749e37000SMatthias Springer }); 28849e37000SMatthias Springer // Bufferize to subview. 28949e37000SMatthias Springer auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( 29049e37000SMatthias Springer dstTensorType.getRank(), srcMemrefType, 29149e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 29249e37000SMatthias Springer .cast<MemRefType>(); 29349e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 294b3ebe3beSMatthias Springer loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, 29549e37000SMatthias Springer mixedStrides); 29649e37000SMatthias Springer 29749e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 29849e37000SMatthias Springer return success(); 29949e37000SMatthias Springer } 30049e37000SMatthias Springer }; 30149e37000SMatthias Springer 30249e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 30349e37000SMatthias Springer struct ExtractOpInterface 30449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 30549e37000SMatthias Springer tensor::ExtractOp> { 30649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 3079597b16aSMatthias Springer const AnalysisState &state) const { 30849e37000SMatthias Springer return true; 30949e37000SMatthias Springer } 31049e37000SMatthias Springer 31149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3129597b16aSMatthias Springer const AnalysisState &state) const { 31349e37000SMatthias Springer return false; 31449e37000SMatthias Springer } 31549e37000SMatthias Springer 3169597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 3179597b16aSMatthias Springer const AnalysisState &state) const { 318585a8a32SMatthias Springer return {}; 31949e37000SMatthias Springer } 32049e37000SMatthias Springer 32149e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 322b55d55ecSMatthias Springer const BufferizationOptions &options) const { 32349e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 324b55d55ecSMatthias Springer Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options); 325b3ebe3beSMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, 32649e37000SMatthias Springer extractOp.indices()); 32749e37000SMatthias Springer return success(); 32849e37000SMatthias Springer } 32949e37000SMatthias Springer }; 33049e37000SMatthias Springer 331d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 332d581c94dSMatthias Springer // iterating over op.elements(). 333d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim, 334d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape, 335d581c94dSMatthias Springer ArrayRef<Value> constants, 336d581c94dSMatthias Springer OperandRange::iterator &elementIt, 337d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) { 338d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 339d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) { 340d581c94dSMatthias Springer indices.back() = constants[i]; 341d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices); 342d581c94dSMatthias Springer ++elementIt; 343d581c94dSMatthias Springer } 344d581c94dSMatthias Springer return; 345d581c94dSMatthias Springer } 346d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 347d581c94dSMatthias Springer indices[dim] = constants[i]; 348d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, 349d581c94dSMatthias Springer indices); 350d581c94dSMatthias Springer } 351d581c94dSMatthias Springer } 352d581c94dSMatthias Springer 353d581c94dSMatthias Springer /// Bufferization of tensor.from_elements. 354d581c94dSMatthias Springer struct FromElementsOpInterface 355d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, 356d581c94dSMatthias Springer tensor::FromElementsOp> { 357d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 358b55d55ecSMatthias Springer const BufferizationOptions &options) const { 359d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op); 360d581c94dSMatthias Springer 361d581c94dSMatthias Springer // Allocate a buffer for the result. 362d581c94dSMatthias Springer Location loc = op->getLoc(); 363d581c94dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>(); 364d581c94dSMatthias Springer auto shape = tensorType.getShape(); 365b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion. 366b55d55ecSMatthias Springer AnalysisState analysisState(options); 367b3ebe3beSMatthias Springer Value tensorAlloc = allocateTensorForShapedValue( 368b3ebe3beSMatthias Springer rewriter, loc, fromElementsOp.result(), 369b3ebe3beSMatthias Springer analysisState.isTensorYielded(fromElementsOp.result()), 370b3ebe3beSMatthias Springer /*copy=*/false); 371b3ebe3beSMatthias Springer auto memrefType = 372b3ebe3beSMatthias Springer MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 373b3ebe3beSMatthias Springer Value buffer = rewriter.create<bufferization::ToMemrefOp>( 374b3ebe3beSMatthias Springer op->getLoc(), memrefType, tensorAlloc); 375d581c94dSMatthias Springer 376d581c94dSMatthias Springer // Case: tensor<0xelem_type>. 377d581c94dSMatthias Springer if (fromElementsOp.elements().empty()) { 378d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 379d581c94dSMatthias Springer return success(); 380d581c94dSMatthias Springer } 381d581c94dSMatthias Springer 382d581c94dSMatthias Springer // Case: tensor<elem_type>. 383d581c94dSMatthias Springer if (shape.empty()) { 384d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(), 385d581c94dSMatthias Springer buffer); 386d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 387d581c94dSMatthias Springer return success(); 388d581c94dSMatthias Springer } 389d581c94dSMatthias Springer 390d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}). 391d581c94dSMatthias Springer auto maxDim = *std::max_element(shape.begin(), shape.end()); 392d581c94dSMatthias Springer SmallVector<Value, 2> constants; 393d581c94dSMatthias Springer constants.reserve(maxDim); 394d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i) 395d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 396d581c94dSMatthias Springer 397d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops. 398d581c94dSMatthias Springer auto elementIt = fromElementsOp.elements().begin(); 399d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 400d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, 401d581c94dSMatthias Springer indices); 402d581c94dSMatthias Springer 403d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 404d581c94dSMatthias Springer return success(); 405d581c94dSMatthias Springer } 406d581c94dSMatthias Springer }; 407d581c94dSMatthias Springer 40871bbb78bSMatthias Springer /// Bufferization of tensor.generate. 40971bbb78bSMatthias Springer struct GenerateOpInterface 41071bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, 41171bbb78bSMatthias Springer tensor::GenerateOp> { 41271bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 413b55d55ecSMatthias Springer const BufferizationOptions &options) const { 41471bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op); 415b3ebe3beSMatthias Springer auto tensorType = generateOp.getType().cast<RankedTensorType>(); 41671bbb78bSMatthias Springer // Allocate memory. 41771bbb78bSMatthias Springer Location loc = op->getLoc(); 418b3ebe3beSMatthias Springer // TODO: Create alloc_tensor ops during TensorCopyInsertion. 419b55d55ecSMatthias Springer AnalysisState analysisState(options); 420b3ebe3beSMatthias Springer Value tensorAlloc = allocateTensorForShapedValue( 421b3ebe3beSMatthias Springer rewriter, loc, generateOp.result(), 422b3ebe3beSMatthias Springer analysisState.isTensorYielded(generateOp.result()), 423b3ebe3beSMatthias Springer /*copy=*/false); 424b3ebe3beSMatthias Springer auto memrefType = 425b3ebe3beSMatthias Springer MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 426b3ebe3beSMatthias Springer Value buffer = rewriter.create<bufferization::ToMemrefOp>( 427b3ebe3beSMatthias Springer op->getLoc(), memrefType, tensorAlloc); 42871bbb78bSMatthias Springer 42971bbb78bSMatthias Springer // Collect loop bounds. 43071bbb78bSMatthias Springer int64_t rank = memrefType.getRank(); 43171bbb78bSMatthias Springer Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 43271bbb78bSMatthias Springer Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 43371bbb78bSMatthias Springer SmallVector<Value, 4> lowerBounds(rank, zero); 43471bbb78bSMatthias Springer SmallVector<Value, 4> steps(rank, one); 43571bbb78bSMatthias Springer SmallVector<Value, 4> upperBounds; 43671bbb78bSMatthias Springer int nextDynamicIndex = 0; 43771bbb78bSMatthias Springer for (int i = 0; i < rank; i++) { 43871bbb78bSMatthias Springer Value upperBound = memrefType.isDynamicDim(i) 43971bbb78bSMatthias Springer ? generateOp.dynamicExtents()[nextDynamicIndex++] 44071bbb78bSMatthias Springer : rewriter.create<arith::ConstantIndexOp>( 44171bbb78bSMatthias Springer loc, memrefType.getDimSize(i)); 44271bbb78bSMatthias Springer upperBounds.push_back(upperBound); 44371bbb78bSMatthias Springer } 44471bbb78bSMatthias Springer 44571bbb78bSMatthias Springer // Generate tensor elements with a parallel loop that stores into 44671bbb78bSMatthias Springer // each element of the resulting memref. We use mergeBlockBefore to "move" 44771bbb78bSMatthias Springer // this op's body into the scf.parallel's body. 44871bbb78bSMatthias Springer auto parallel = 44971bbb78bSMatthias Springer rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 45071bbb78bSMatthias Springer Block *parallelBody = parallel.getBody(); 451*eca86cb2SJacques Pienaar rewriter.mergeBlockBefore(&generateOp.getBody().front(), 45271bbb78bSMatthias Springer parallelBody->getTerminator(), 45371bbb78bSMatthias Springer parallelBody->getArguments()); 45471bbb78bSMatthias Springer // Replace the inlined yield op with a store op. The scf.parallel's builder 45571bbb78bSMatthias Springer // already populated an scf.yield at the end, so we don't need to worry 45671bbb78bSMatthias Springer // about creating that. 45771bbb78bSMatthias Springer Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 45871bbb78bSMatthias Springer rewriter.setInsertionPointAfter(elementYield); 45971bbb78bSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>( 460b3ebe3beSMatthias Springer elementYield, elementYield->getOperands()[0], buffer, 46171bbb78bSMatthias Springer parallelBody->getArguments()); 46271bbb78bSMatthias Springer 463b3ebe3beSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 46471bbb78bSMatthias Springer return success(); 46571bbb78bSMatthias Springer } 46671bbb78bSMatthias Springer }; 46771bbb78bSMatthias Springer 46849e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 46949e37000SMatthias Springer struct InsertOpInterface 47049e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface, 47149e37000SMatthias Springer tensor::InsertOp> { 47249e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 4739597b16aSMatthias Springer const AnalysisState &state) const { 47449e37000SMatthias Springer return true; 47549e37000SMatthias Springer } 47649e37000SMatthias Springer 47749e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 4789597b16aSMatthias Springer const AnalysisState &state) const { 47949e37000SMatthias Springer return true; 48049e37000SMatthias Springer } 48149e37000SMatthias Springer 4829597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 4839597b16aSMatthias Springer const AnalysisState &state) const { 48449e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 48549e37000SMatthias Springer "expected dest OpOperand"); 486585a8a32SMatthias Springer return {op->getOpResult(0)}; 48749e37000SMatthias Springer } 48849e37000SMatthias Springer 48949e37000SMatthias Springer SmallVector<OpOperand *> 49049e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 4919597b16aSMatthias Springer const AnalysisState &state) const { 49249e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/}; 49349e37000SMatthias Springer } 49449e37000SMatthias Springer 49549e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 496b55d55ecSMatthias Springer const BufferizationOptions &options) const { 49749e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 498b55d55ecSMatthias Springer Value destMemref = getBuffer(rewriter, insertOp.dest(), options); 49949e37000SMatthias Springer rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), 500b3ebe3beSMatthias Springer destMemref, insertOp.indices()); 501b3ebe3beSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, destMemref); 50249e37000SMatthias Springer return success(); 50349e37000SMatthias Springer } 50449e37000SMatthias Springer 50549e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 5069597b16aSMatthias Springer const AnalysisState &state) const { 50749e37000SMatthias Springer return BufferRelation::Equivalent; 50849e37000SMatthias Springer } 50949e37000SMatthias Springer }; 51049e37000SMatthias Springer 51149e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. 51249e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification). 51349e37000SMatthias Springer /// 51449e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that 51549e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and 51649e37000SMatthias Springer /// exposed as interfaces on the proper types. 5179597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state, 51849e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 51949e37000SMatthias Springer if (!st || !sti) 52049e37000SMatthias Springer return false; 52149e37000SMatthias Springer if (sti != sti && 52249e37000SMatthias Springer !state.areEquivalentBufferizedValues(st.source(), sti.dest())) 52349e37000SMatthias Springer return false; 52449e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 52549e37000SMatthias Springer return false; 52649e37000SMatthias Springer return true; 52749e37000SMatthias Springer } 52849e37000SMatthias Springer 52949e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 53049e37000SMatthias Springer /// the given InsertSliceOp. 5319597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 5329597b16aSMatthias Springer InsertSliceOp insertOp) { 53349e37000SMatthias Springer auto condition = [&](Value val) { 53449e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 53549e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 53649e37000SMatthias Springer return true; 53749e37000SMatthias Springer return false; 53849e37000SMatthias Springer }; 53949e37000SMatthias Springer 54049e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 54149e37000SMatthias Springer condition); 54249e37000SMatthias Springer } 54349e37000SMatthias Springer 54449e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 54549e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 54649e37000SMatthias Springer struct InsertSliceOpInterface 54749e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 54849e37000SMatthias Springer tensor::InsertSliceOp> { 54949e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 5509597b16aSMatthias Springer const AnalysisState &state) const { 55149e37000SMatthias Springer return true; 55249e37000SMatthias Springer } 55349e37000SMatthias Springer 55449e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 5559597b16aSMatthias Springer const AnalysisState &state) const { 55649e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 55749e37000SMatthias Springer } 55849e37000SMatthias Springer 5599597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 5609597b16aSMatthias Springer const AnalysisState &state) const { 561585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(1) /*dest*/) 562585a8a32SMatthias Springer return {op->getResult(0)}; 563585a8a32SMatthias Springer return {}; 56449e37000SMatthias Springer } 56549e37000SMatthias Springer 56649e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 5679597b16aSMatthias Springer const AnalysisState &state) const { 56849e37000SMatthias Springer return BufferRelation::Equivalent; 56949e37000SMatthias Springer } 57049e37000SMatthias Springer 57149e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 57249e37000SMatthias Springer OpOperand *uConflictingWrite, 5739597b16aSMatthias Springer const AnalysisState &state) const { 57449e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 57549e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 57649e37000SMatthias Springer 57749e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 57849e37000SMatthias Springer // uRead is an InsertSliceOp... 57949e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 58049e37000SMatthias Springer // As an example, consider the following IR. 58149e37000SMatthias Springer // 58249e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 58349e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 58449e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 58549e37000SMatthias Springer // {inplace= [true] } 58649e37000SMatthias Springer 58749e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 58849e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 58949e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 59049e37000SMatthias Springer insertSliceOp)) 59149e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 59249e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 59349e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 59449e37000SMatthias Springer // being read by uRead, this is not a conflict. 59549e37000SMatthias Springer // 59649e37000SMatthias Springer // In the above example: 59749e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 59849e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 59949e37000SMatthias Springer // 60049e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 60149e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 60249e37000SMatthias Springer // exactly the one that is *not* read via %t. 60349e37000SMatthias Springer return true; 60449e37000SMatthias Springer 60549e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 60649e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 60749e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 60849e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 60949e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 61049e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 61149e37000SMatthias Springer // InsertSliceOp is writing. 61249e37000SMatthias Springer // 61349e37000SMatthias Springer // In the above example: 61449e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 61549e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 61649e37000SMatthias Springer return true; 61749e37000SMatthias Springer } 61849e37000SMatthias Springer 61949e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 62049e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 62149e37000SMatthias Springer // As an example, consider the following IR. 62249e37000SMatthias Springer // 62349e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 62449e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 62549e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 62649e37000SMatthias Springer // {inplace= [true] } 62749e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 62849e37000SMatthias Springer // 62949e37000SMatthias Springer // In the above example: 63049e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 63149e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 63249e37000SMatthias Springer // lastWrite = %1 63349e37000SMatthias Springer // 63449e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 63549e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 63649e37000SMatthias Springer // is no memory write here.) 63749e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 63849e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 63949e37000SMatthias Springer insertSliceOp.source()) && 64049e37000SMatthias Springer hasMatchingExtractSliceOp(state, insertSliceOp.source(), 64149e37000SMatthias Springer insertSliceOp)) 64249e37000SMatthias Springer return true; 64349e37000SMatthias Springer 64449e37000SMatthias Springer return false; 64549e37000SMatthias Springer } 64649e37000SMatthias Springer 64749e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 648b55d55ecSMatthias Springer const BufferizationOptions &options) const { 64949e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 65049e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 65149e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 65249e37000SMatthias Springer // catastrophically bad scheduling decision. 65349e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 65449e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 65549e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 656b55d55ecSMatthias Springer Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options); 65749e37000SMatthias Springer 65849e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 65949e37000SMatthias Springer // rank-reducing case. 66049e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 66149e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 66249e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 66349e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 664b3ebe3beSMatthias Springer dstMemref, mixedOffsets, mixedSizes, mixedStrides, 66549e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 66649e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 66749e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 66849e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 66949e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 67049e37000SMatthias Springer }); 67149e37000SMatthias Springer // Take a subview of the dst. 672b3ebe3beSMatthias Springer auto dstMemrefType = dstMemref.getType().cast<MemRefType>(); 67349e37000SMatthias Springer auto subviewMemRefType = 67449e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 67549e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 67649e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 67749e37000SMatthias Springer .cast<MemRefType>(); 67849e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 679b3ebe3beSMatthias Springer loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes, 68049e37000SMatthias Springer mixedStrides); 68149e37000SMatthias Springer 68249e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 68349e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 684b55d55ecSMatthias Springer auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options); 685b55d55ecSMatthias Springer if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView))) 68649e37000SMatthias Springer return failure(); 68749e37000SMatthias Springer 688b3ebe3beSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, dstMemref); 68949e37000SMatthias Springer return success(); 69049e37000SMatthias Springer } 69149e37000SMatthias Springer }; 69249e37000SMatthias Springer 693fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 694fc08d1c2SMatthias Springer struct RankOpInterface 695fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 696fc08d1c2SMatthias Springer tensor::RankOp> { 697fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 6989597b16aSMatthias Springer const AnalysisState &state) const { 699fc08d1c2SMatthias Springer return true; 700fc08d1c2SMatthias Springer } 701fc08d1c2SMatthias Springer 702fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 7039597b16aSMatthias Springer const AnalysisState &state) const { 704fc08d1c2SMatthias Springer return false; 705fc08d1c2SMatthias Springer } 706fc08d1c2SMatthias Springer 7079597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 7089597b16aSMatthias Springer const AnalysisState &state) const { 709585a8a32SMatthias Springer return {}; 710fc08d1c2SMatthias Springer } 711fc08d1c2SMatthias Springer 712fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 713b55d55ecSMatthias Springer const BufferizationOptions &options) const { 714fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 715b55d55ecSMatthias Springer auto v = getBuffer(rewriter, rankOp.tensor(), options); 716fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 717b3ebe3beSMatthias Springer v); 718fc08d1c2SMatthias Springer return success(); 719fc08d1c2SMatthias Springer } 720fc08d1c2SMatthias Springer }; 721fc08d1c2SMatthias Springer 722e287d647SAshay Rane /// Bufferization of tensor.reshape. Replace with memref.reshape. 723e287d647SAshay Rane struct ReshapeOpInterface 724e287d647SAshay Rane : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface, 725e287d647SAshay Rane tensor::ReshapeOp> { 726e287d647SAshay Rane bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 727e287d647SAshay Rane const AnalysisState &state) const { 728e287d647SAshay Rane if (&opOperand == &op->getOpOperand(1) /* shape */) 729e287d647SAshay Rane return true; 730e287d647SAshay Rane return false; 731e287d647SAshay Rane } 732e287d647SAshay Rane 733e287d647SAshay Rane bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 734e287d647SAshay Rane const AnalysisState &state) const { 735e287d647SAshay Rane return false; 736e287d647SAshay Rane } 737e287d647SAshay Rane 738e287d647SAshay Rane SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 739e287d647SAshay Rane const AnalysisState &state) const { 740e287d647SAshay Rane return {op->getOpResult(0)}; 741e287d647SAshay Rane } 742e287d647SAshay Rane 743e287d647SAshay Rane BufferRelation bufferRelation(Operation *op, OpResult opResult, 744e287d647SAshay Rane const AnalysisState &state) const { 745e287d647SAshay Rane return BufferRelation::Equivalent; 746e287d647SAshay Rane } 747e287d647SAshay Rane 748e287d647SAshay Rane LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 749b55d55ecSMatthias Springer const BufferizationOptions &options) const { 750e287d647SAshay Rane auto reshapeOp = cast<tensor::ReshapeOp>(op); 751b55d55ecSMatthias Springer Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options); 752b55d55ecSMatthias Springer Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options); 753e287d647SAshay Rane auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>(); 754b55d55ecSMatthias Springer auto resultMemRefType = getMemRefType(resultTensorType, options); 755e287d647SAshay Rane replaceOpWithNewBufferizedOp<memref::ReshapeOp>( 756b3ebe3beSMatthias Springer rewriter, op, resultMemRefType, srcBuffer, shapeBuffer); 757e287d647SAshay Rane return success(); 758e287d647SAshay Rane } 759e287d647SAshay Rane }; 760e287d647SAshay Rane 76149e37000SMatthias Springer } // namespace 76249e37000SMatthias Springer } // namespace tensor 76349e37000SMatthias Springer } // namespace mlir 76449e37000SMatthias Springer 76549e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 76649e37000SMatthias Springer DialectRegistry ®istry) { 76777eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 76877eee579SRiver Riddle CastOp::attachInterface<CastOpInterface>(*ctx); 76977eee579SRiver Riddle CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx); 77077eee579SRiver Riddle DimOp::attachInterface<DimOpInterface>(*ctx); 77177eee579SRiver Riddle ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); 77277eee579SRiver Riddle ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx); 77377eee579SRiver Riddle ExtractOp::attachInterface<ExtractOpInterface>(*ctx); 77477eee579SRiver Riddle FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx); 77577eee579SRiver Riddle GenerateOp::attachInterface<GenerateOpInterface>(*ctx); 77677eee579SRiver Riddle InsertOp::attachInterface<InsertOpInterface>(*ctx); 77777eee579SRiver Riddle InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx); 77877eee579SRiver Riddle RankOp::attachInterface<RankOpInterface>(*ctx); 779e287d647SAshay Rane ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx); 78077eee579SRiver Riddle }); 78149e37000SMatthias Springer } 782