149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1249e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1349e37000SMatthias Springer #include "mlir/IR/Dialect.h" 1449e37000SMatthias Springer #include "mlir/IR/Operation.h" 1549e37000SMatthias Springer 1649e37000SMatthias Springer using namespace mlir; 1749e37000SMatthias Springer using namespace mlir::bufferization; 1849e37000SMatthias Springer using namespace mlir::tensor; 1949e37000SMatthias Springer 2049e37000SMatthias Springer namespace mlir { 2149e37000SMatthias Springer namespace tensor { 2249e37000SMatthias Springer namespace { 2349e37000SMatthias Springer 2449e37000SMatthias Springer struct CastOpInterface 2549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 2649e37000SMatthias Springer tensor::CastOp> { 2749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2849e37000SMatthias Springer const BufferizationState &state) const { 2949e37000SMatthias Springer return false; 3049e37000SMatthias Springer } 3149e37000SMatthias Springer 3249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3349e37000SMatthias Springer const BufferizationState &state) const { 3449e37000SMatthias Springer return false; 3549e37000SMatthias Springer } 3649e37000SMatthias Springer 3749e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 3849e37000SMatthias Springer const BufferizationState &state) const { 3949e37000SMatthias Springer return op->getResult(0); 4049e37000SMatthias Springer } 4149e37000SMatthias Springer 4249e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 4349e37000SMatthias Springer const BufferizationState &state) const { 4449e37000SMatthias Springer return BufferRelation::Equivalent; 4549e37000SMatthias Springer } 4649e37000SMatthias Springer 4749e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 4849e37000SMatthias Springer const BufferizationState &state) const { 4949e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 5049e37000SMatthias Springer 5149e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 5249e37000SMatthias Springer FailureOr<Value> resultBuffer = 5349e37000SMatthias Springer state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); 5449e37000SMatthias Springer if (failed(resultBuffer)) 5549e37000SMatthias Springer return failure(); 5649e37000SMatthias Springer auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); 5749e37000SMatthias Springer Attribute memorySpace = sourceMemRefType.getMemorySpace(); 5849e37000SMatthias Springer TensorType resultTensorType = 5949e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>(); 6049e37000SMatthias Springer MemRefLayoutAttrInterface layout; 6149e37000SMatthias Springer 6249e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) 6349e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) 6449e37000SMatthias Springer layout = rankedMemRefType.getLayout(); 6549e37000SMatthias Springer 6649e37000SMatthias Springer // Compute the new memref type. 6749e37000SMatthias Springer Type resultMemRefType; 6849e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) { 6949e37000SMatthias Springer resultMemRefType = 7049e37000SMatthias Springer getContiguousMemRefType(resultTensorType, layout, memorySpace); 7149e37000SMatthias Springer } else { 7249e37000SMatthias Springer resultMemRefType = 7349e37000SMatthias Springer getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace); 7449e37000SMatthias Springer } 7549e37000SMatthias Springer 7649e37000SMatthias Springer // Replace the op with a memref.cast. 7749e37000SMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), 7849e37000SMatthias Springer resultMemRefType) && 7949e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 8049e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, 8149e37000SMatthias Springer *resultBuffer); 8249e37000SMatthias Springer 8349e37000SMatthias Springer return success(); 8449e37000SMatthias Springer } 8549e37000SMatthias Springer }; 8649e37000SMatthias Springer 8749e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 8849e37000SMatthias Springer struct DimOpInterface 8949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 9049e37000SMatthias Springer tensor::DimOp> { 9149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 9249e37000SMatthias Springer const BufferizationState &state) const { 9349e37000SMatthias Springer return true; 9449e37000SMatthias Springer } 9549e37000SMatthias Springer 9649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 9749e37000SMatthias Springer const BufferizationState &state) const { 9849e37000SMatthias Springer return false; 9949e37000SMatthias Springer } 10049e37000SMatthias Springer 10149e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 10249e37000SMatthias Springer const BufferizationState &state) const { 10349e37000SMatthias Springer return OpResult(); 10449e37000SMatthias Springer } 10549e37000SMatthias Springer 10649e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 10749e37000SMatthias Springer const BufferizationState &state) const { 10849e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 10949e37000SMatthias Springer Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); 11049e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); 11149e37000SMatthias Springer return success(); 11249e37000SMatthias Springer } 11349e37000SMatthias Springer }; 11449e37000SMatthias Springer 11549e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 11649e37000SMatthias Springer struct ExtractSliceOpInterface 11749e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 11849e37000SMatthias Springer tensor::ExtractSliceOp> { 11949e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 12049e37000SMatthias Springer const BufferizationState &state) const { 12149e37000SMatthias Springer return false; 12249e37000SMatthias Springer } 12349e37000SMatthias Springer 12449e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 12549e37000SMatthias Springer const BufferizationState &state) const { 12649e37000SMatthias Springer return false; 12749e37000SMatthias Springer } 12849e37000SMatthias Springer 12949e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 13049e37000SMatthias Springer const BufferizationState &state) const { 13149e37000SMatthias Springer return &opOperand == &op->getOpOperand(0) /*source*/ 13249e37000SMatthias Springer ? op->getResult(0) 13349e37000SMatthias Springer : OpResult(); 13449e37000SMatthias Springer } 13549e37000SMatthias Springer 13649e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 13749e37000SMatthias Springer const BufferizationState &state) const { 13849e37000SMatthias Springer return BufferRelation::None; 13949e37000SMatthias Springer } 14049e37000SMatthias Springer 14149e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 14249e37000SMatthias Springer const BufferizationState &state) const { 14349e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 14449e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 14549e37000SMatthias Springer Value srcMemref = 14649e37000SMatthias Springer *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, 14749e37000SMatthias Springer /*forceInPlace=*/true); 14849e37000SMatthias Springer auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); 14949e37000SMatthias Springer auto dstTensorType = 15049e37000SMatthias Springer extractSliceOp.result().getType().cast<RankedTensorType>(); 15149e37000SMatthias Springer 15249e37000SMatthias Springer // If not inplaceable, alloc. 15349e37000SMatthias Springer bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0)); 15449e37000SMatthias Springer Value alloc; 15549e37000SMatthias Springer if (!inplace) { 15649e37000SMatthias Springer FailureOr<Value> allocOrFailure = 15749e37000SMatthias Springer createAlloc(rewriter, loc, extractSliceOp.result(), 15849e37000SMatthias Springer state.getOptions().createDeallocs, state.getOptions()); 15949e37000SMatthias Springer if (failed(allocOrFailure)) 16049e37000SMatthias Springer return failure(); 16149e37000SMatthias Springer alloc = *allocOrFailure; 16249e37000SMatthias Springer } 16349e37000SMatthias Springer 16449e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 16549e37000SMatthias Springer // rank-reducing case. 16649e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 16749e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 16849e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 16949e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 17049e37000SMatthias Springer srcMemref, mixedOffsets, mixedSizes, mixedStrides, 17149e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 17249e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 17349e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 17449e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 17549e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 17649e37000SMatthias Springer }); 17749e37000SMatthias Springer // Bufferize to subview. 17849e37000SMatthias Springer auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( 17949e37000SMatthias Springer dstTensorType.getRank(), srcMemrefType, 18049e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 18149e37000SMatthias Springer .cast<MemRefType>(); 18249e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 18349e37000SMatthias Springer loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, 18449e37000SMatthias Springer mixedStrides); 18549e37000SMatthias Springer 18649e37000SMatthias Springer // If not inplaceable, copy. 18749e37000SMatthias Springer if (!inplace) { 18849e37000SMatthias Springer // Do not copy if the copied data is never read. 18949e37000SMatthias Springer if (state.isValueRead(extractSliceOp.result())) 19049e37000SMatthias Springer if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, 19149e37000SMatthias Springer alloc, state.getOptions()))) 19249e37000SMatthias Springer return failure(); 19349e37000SMatthias Springer subView = alloc; 19449e37000SMatthias Springer } 19549e37000SMatthias Springer 19649e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 19749e37000SMatthias Springer return success(); 19849e37000SMatthias Springer } 19949e37000SMatthias Springer }; 20049e37000SMatthias Springer 20149e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 20249e37000SMatthias Springer struct ExtractOpInterface 20349e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 20449e37000SMatthias Springer tensor::ExtractOp> { 20549e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 20649e37000SMatthias Springer const BufferizationState &state) const { 20749e37000SMatthias Springer return true; 20849e37000SMatthias Springer } 20949e37000SMatthias Springer 21049e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 21149e37000SMatthias Springer const BufferizationState &state) const { 21249e37000SMatthias Springer return false; 21349e37000SMatthias Springer } 21449e37000SMatthias Springer 21549e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 21649e37000SMatthias Springer const BufferizationState &state) const { 21749e37000SMatthias Springer return OpResult(); 21849e37000SMatthias Springer } 21949e37000SMatthias Springer 22049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 22149e37000SMatthias Springer const BufferizationState &state) const { 22249e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 22349e37000SMatthias Springer Value srcMemref = 22449e37000SMatthias Springer *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); 22549e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, 22649e37000SMatthias Springer extractOp.indices()); 22749e37000SMatthias Springer return success(); 22849e37000SMatthias Springer } 22949e37000SMatthias Springer }; 23049e37000SMatthias Springer 23149e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 23249e37000SMatthias Springer struct InsertOpInterface 23349e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface, 23449e37000SMatthias Springer tensor::InsertOp> { 23549e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 23649e37000SMatthias Springer const BufferizationState &state) const { 23749e37000SMatthias Springer return true; 23849e37000SMatthias Springer } 23949e37000SMatthias Springer 24049e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 24149e37000SMatthias Springer const BufferizationState &state) const { 24249e37000SMatthias Springer return true; 24349e37000SMatthias Springer } 24449e37000SMatthias Springer 24549e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 24649e37000SMatthias Springer const BufferizationState &state) const { 24749e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 24849e37000SMatthias Springer "expected dest OpOperand"); 24949e37000SMatthias Springer return op->getOpResult(0); 25049e37000SMatthias Springer } 25149e37000SMatthias Springer 25249e37000SMatthias Springer SmallVector<OpOperand *> 25349e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 25449e37000SMatthias Springer const BufferizationState &state) const { 25549e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/}; 25649e37000SMatthias Springer } 25749e37000SMatthias Springer 25849e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 25949e37000SMatthias Springer const BufferizationState &state) const { 26049e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 26149e37000SMatthias Springer FailureOr<Value> destMemref = 26249e37000SMatthias Springer state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); 26349e37000SMatthias Springer if (failed(destMemref)) 26449e37000SMatthias Springer return failure(); 26549e37000SMatthias Springer rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), 26649e37000SMatthias Springer *destMemref, insertOp.indices()); 26749e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref); 26849e37000SMatthias Springer return success(); 26949e37000SMatthias Springer } 27049e37000SMatthias Springer 27149e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 27249e37000SMatthias Springer const BufferizationState &state) const { 27349e37000SMatthias Springer return BufferRelation::Equivalent; 27449e37000SMatthias Springer } 27549e37000SMatthias Springer }; 27649e37000SMatthias Springer 27749e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. 27849e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification). 27949e37000SMatthias Springer /// 28049e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that 28149e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and 28249e37000SMatthias Springer /// exposed as interfaces on the proper types. 28349e37000SMatthias Springer static bool areEquivalentExtractSliceOps(const BufferizationState &state, 28449e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 28549e37000SMatthias Springer if (!st || !sti) 28649e37000SMatthias Springer return false; 28749e37000SMatthias Springer if (sti != sti && 28849e37000SMatthias Springer !state.areEquivalentBufferizedValues(st.source(), sti.dest())) 28949e37000SMatthias Springer return false; 29049e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 29149e37000SMatthias Springer return false; 29249e37000SMatthias Springer return true; 29349e37000SMatthias Springer } 29449e37000SMatthias Springer 29549e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 29649e37000SMatthias Springer /// the given InsertSliceOp. 29749e37000SMatthias Springer static bool hasMatchingExtractSliceOp(const BufferizationState &state, 29849e37000SMatthias Springer Value value, InsertSliceOp insertOp) { 29949e37000SMatthias Springer auto condition = [&](Value val) { 30049e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 30149e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 30249e37000SMatthias Springer return true; 30349e37000SMatthias Springer return false; 30449e37000SMatthias Springer }; 30549e37000SMatthias Springer 30649e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 30749e37000SMatthias Springer condition); 30849e37000SMatthias Springer } 30949e37000SMatthias Springer 31049e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 31149e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 31249e37000SMatthias Springer struct InsertSliceOpInterface 31349e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 31449e37000SMatthias Springer tensor::InsertSliceOp> { 31549e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 31649e37000SMatthias Springer const BufferizationState &state) const { 31749e37000SMatthias Springer return true; 31849e37000SMatthias Springer } 31949e37000SMatthias Springer 32049e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 32149e37000SMatthias Springer const BufferizationState &state) const { 32249e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 32349e37000SMatthias Springer } 32449e37000SMatthias Springer 32549e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 32649e37000SMatthias Springer const BufferizationState &state) const { 32749e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/ 32849e37000SMatthias Springer ? op->getResult(0) 32949e37000SMatthias Springer : OpResult(); 33049e37000SMatthias Springer } 33149e37000SMatthias Springer 33249e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 33349e37000SMatthias Springer const BufferizationState &state) const { 33449e37000SMatthias Springer return BufferRelation::Equivalent; 33549e37000SMatthias Springer } 33649e37000SMatthias Springer 33749e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 33849e37000SMatthias Springer OpOperand *uConflictingWrite, 33949e37000SMatthias Springer const BufferizationState &state) const { 34049e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 34149e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 34249e37000SMatthias Springer 34349e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 34449e37000SMatthias Springer // uRead is an InsertSliceOp... 34549e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 34649e37000SMatthias Springer // As an example, consider the following IR. 34749e37000SMatthias Springer // 34849e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 34949e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 35049e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 35149e37000SMatthias Springer // {inplace= [true] } 35249e37000SMatthias Springer 35349e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 35449e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 35549e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 35649e37000SMatthias Springer insertSliceOp)) 35749e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 35849e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 35949e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 36049e37000SMatthias Springer // being read by uRead, this is not a conflict. 36149e37000SMatthias Springer // 36249e37000SMatthias Springer // In the above example: 36349e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 36449e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 36549e37000SMatthias Springer // 36649e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 36749e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 36849e37000SMatthias Springer // exactly the one that is *not* read via %t. 36949e37000SMatthias Springer return true; 37049e37000SMatthias Springer 37149e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 37249e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 37349e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 37449e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 37549e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 37649e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 37749e37000SMatthias Springer // InsertSliceOp is writing. 37849e37000SMatthias Springer // 37949e37000SMatthias Springer // In the above example: 38049e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 38149e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 38249e37000SMatthias Springer return true; 38349e37000SMatthias Springer } 38449e37000SMatthias Springer 38549e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 38649e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 38749e37000SMatthias Springer // As an example, consider the following IR. 38849e37000SMatthias Springer // 38949e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 39049e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 39149e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 39249e37000SMatthias Springer // {inplace= [true] } 39349e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 39449e37000SMatthias Springer // 39549e37000SMatthias Springer // In the above example: 39649e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 39749e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 39849e37000SMatthias Springer // lastWrite = %1 39949e37000SMatthias Springer // 40049e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 40149e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 40249e37000SMatthias Springer // is no memory write here.) 40349e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 40449e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 40549e37000SMatthias Springer insertSliceOp.source()) && 40649e37000SMatthias Springer hasMatchingExtractSliceOp(state, insertSliceOp.source(), 40749e37000SMatthias Springer insertSliceOp)) 40849e37000SMatthias Springer return true; 40949e37000SMatthias Springer 41049e37000SMatthias Springer return false; 41149e37000SMatthias Springer } 41249e37000SMatthias Springer 41349e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 41449e37000SMatthias Springer const BufferizationState &state) const { 41549e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 41649e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 41749e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 41849e37000SMatthias Springer // catastrophically bad scheduling decision. 41949e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 42049e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 42149e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 42249e37000SMatthias Springer 42349e37000SMatthias Springer // When bufferizing out-of-place, `getResultBuffer` allocates. 42449e37000SMatthias Springer FailureOr<Value> dstMemref = 42549e37000SMatthias Springer state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); 42649e37000SMatthias Springer if (failed(dstMemref)) 42749e37000SMatthias Springer return failure(); 42849e37000SMatthias Springer 42949e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 43049e37000SMatthias Springer // rank-reducing case. 43149e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 43249e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 43349e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 43449e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 43549e37000SMatthias Springer *dstMemref, mixedOffsets, mixedSizes, mixedStrides, 43649e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 43749e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 43849e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 43949e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 44049e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 44149e37000SMatthias Springer }); 44249e37000SMatthias Springer // Take a subview of the dst. 44349e37000SMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); 44449e37000SMatthias Springer auto subviewMemRefType = 44549e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 44649e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 44749e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 44849e37000SMatthias Springer .cast<MemRefType>(); 44949e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 45049e37000SMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 45149e37000SMatthias Springer mixedStrides); 45249e37000SMatthias Springer 45349e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 45449e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 45549e37000SMatthias Springer Value srcMemref = 45649e37000SMatthias Springer *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); 45749e37000SMatthias Springer if (failed(createMemCpy(rewriter, loc, srcMemref, subView, 45849e37000SMatthias Springer state.getOptions()))) 45949e37000SMatthias Springer return failure(); 46049e37000SMatthias Springer 46149e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 46249e37000SMatthias Springer return success(); 46349e37000SMatthias Springer } 46449e37000SMatthias Springer }; 46549e37000SMatthias Springer 466*fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 467*fc08d1c2SMatthias Springer struct RankOpInterface 468*fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 469*fc08d1c2SMatthias Springer tensor::RankOp> { 470*fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 471*fc08d1c2SMatthias Springer const BufferizationState &state) const { 472*fc08d1c2SMatthias Springer return true; 473*fc08d1c2SMatthias Springer } 474*fc08d1c2SMatthias Springer 475*fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 476*fc08d1c2SMatthias Springer const BufferizationState &state) const { 477*fc08d1c2SMatthias Springer return false; 478*fc08d1c2SMatthias Springer } 479*fc08d1c2SMatthias Springer 480*fc08d1c2SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 481*fc08d1c2SMatthias Springer const BufferizationState &state) const { 482*fc08d1c2SMatthias Springer return OpResult(); 483*fc08d1c2SMatthias Springer } 484*fc08d1c2SMatthias Springer 485*fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 486*fc08d1c2SMatthias Springer const BufferizationState &state) const { 487*fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 488*fc08d1c2SMatthias Springer Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); 489*fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 490*fc08d1c2SMatthias Springer v); 491*fc08d1c2SMatthias Springer return success(); 492*fc08d1c2SMatthias Springer } 493*fc08d1c2SMatthias Springer }; 494*fc08d1c2SMatthias Springer 49549e37000SMatthias Springer } // namespace 49649e37000SMatthias Springer } // namespace tensor 49749e37000SMatthias Springer } // namespace mlir 49849e37000SMatthias Springer 49949e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 50049e37000SMatthias Springer DialectRegistry ®istry) { 50149e37000SMatthias Springer registry.addOpInterface<CastOp, CastOpInterface>(); 50249e37000SMatthias Springer registry.addOpInterface<DimOp, DimOpInterface>(); 50349e37000SMatthias Springer registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>(); 50449e37000SMatthias Springer registry.addOpInterface<ExtractOp, ExtractOpInterface>(); 50549e37000SMatthias Springer registry.addOpInterface<InsertOp, InsertOpInterface>(); 50649e37000SMatthias Springer registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>(); 507*fc08d1c2SMatthias Springer registry.addOpInterface<RankOp, RankOpInterface>(); 50849e37000SMatthias Springer } 509