149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1271bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1349e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1449e37000SMatthias Springer #include "mlir/IR/Dialect.h" 1549e37000SMatthias Springer #include "mlir/IR/Operation.h" 1649e37000SMatthias Springer 1749e37000SMatthias Springer using namespace mlir; 1849e37000SMatthias Springer using namespace mlir::bufferization; 1949e37000SMatthias Springer using namespace mlir::tensor; 2049e37000SMatthias Springer 2149e37000SMatthias Springer namespace mlir { 2249e37000SMatthias Springer namespace tensor { 2349e37000SMatthias Springer namespace { 2449e37000SMatthias Springer 2549e37000SMatthias Springer struct CastOpInterface 2649e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 2749e37000SMatthias Springer tensor::CastOp> { 2849e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2949e37000SMatthias Springer const BufferizationState &state) const { 3049e37000SMatthias Springer return false; 3149e37000SMatthias Springer } 3249e37000SMatthias Springer 3349e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 3449e37000SMatthias Springer const BufferizationState &state) const { 3549e37000SMatthias Springer return false; 3649e37000SMatthias Springer } 3749e37000SMatthias Springer 3849e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 3949e37000SMatthias Springer const BufferizationState &state) const { 4049e37000SMatthias Springer return op->getResult(0); 4149e37000SMatthias Springer } 4249e37000SMatthias Springer 4349e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 4449e37000SMatthias Springer const BufferizationState &state) const { 4549e37000SMatthias Springer return BufferRelation::Equivalent; 4649e37000SMatthias Springer } 4749e37000SMatthias Springer 4849e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 4949e37000SMatthias Springer const BufferizationState &state) const { 5049e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 5149e37000SMatthias Springer 5249e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 5349e37000SMatthias Springer FailureOr<Value> resultBuffer = 5449e37000SMatthias Springer state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); 5549e37000SMatthias Springer if (failed(resultBuffer)) 5649e37000SMatthias Springer return failure(); 5749e37000SMatthias Springer auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); 5849e37000SMatthias Springer Attribute memorySpace = sourceMemRefType.getMemorySpace(); 5949e37000SMatthias Springer TensorType resultTensorType = 6049e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>(); 6149e37000SMatthias Springer MemRefLayoutAttrInterface layout; 6249e37000SMatthias Springer 6349e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) 6449e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) 6549e37000SMatthias Springer layout = rankedMemRefType.getLayout(); 6649e37000SMatthias Springer 6749e37000SMatthias Springer // Compute the new memref type. 6849e37000SMatthias Springer Type resultMemRefType; 6949e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) { 7049e37000SMatthias Springer resultMemRefType = 7149e37000SMatthias Springer getContiguousMemRefType(resultTensorType, layout, memorySpace); 7249e37000SMatthias Springer } else { 7349e37000SMatthias Springer resultMemRefType = 7449e37000SMatthias Springer getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace); 7549e37000SMatthias Springer } 7649e37000SMatthias Springer 7749e37000SMatthias Springer // Replace the op with a memref.cast. 7849e37000SMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), 7949e37000SMatthias Springer resultMemRefType) && 8049e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 8149e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, 8249e37000SMatthias Springer *resultBuffer); 8349e37000SMatthias Springer 8449e37000SMatthias Springer return success(); 8549e37000SMatthias Springer } 8649e37000SMatthias Springer }; 8749e37000SMatthias Springer 8849e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 8949e37000SMatthias Springer struct DimOpInterface 9049e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 9149e37000SMatthias Springer tensor::DimOp> { 9249e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 9349e37000SMatthias Springer const BufferizationState &state) const { 9449e37000SMatthias Springer return true; 9549e37000SMatthias Springer } 9649e37000SMatthias Springer 9749e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 9849e37000SMatthias Springer const BufferizationState &state) const { 9949e37000SMatthias Springer return false; 10049e37000SMatthias Springer } 10149e37000SMatthias Springer 10249e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 10349e37000SMatthias Springer const BufferizationState &state) const { 10449e37000SMatthias Springer return OpResult(); 10549e37000SMatthias Springer } 10649e37000SMatthias Springer 10749e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 10849e37000SMatthias Springer const BufferizationState &state) const { 10949e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 11049e37000SMatthias Springer Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); 11149e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); 11249e37000SMatthias Springer return success(); 11349e37000SMatthias Springer } 11449e37000SMatthias Springer }; 11549e37000SMatthias Springer 11649e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 11749e37000SMatthias Springer struct ExtractSliceOpInterface 11849e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 11949e37000SMatthias Springer tensor::ExtractSliceOp> { 12049e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 12149e37000SMatthias Springer const BufferizationState &state) const { 12249e37000SMatthias Springer return false; 12349e37000SMatthias Springer } 12449e37000SMatthias Springer 12549e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 12649e37000SMatthias Springer const BufferizationState &state) const { 12749e37000SMatthias Springer return false; 12849e37000SMatthias Springer } 12949e37000SMatthias Springer 13049e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 13149e37000SMatthias Springer const BufferizationState &state) const { 13249e37000SMatthias Springer return &opOperand == &op->getOpOperand(0) /*source*/ 13349e37000SMatthias Springer ? op->getResult(0) 13449e37000SMatthias Springer : OpResult(); 13549e37000SMatthias Springer } 13649e37000SMatthias Springer 13749e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 13849e37000SMatthias Springer const BufferizationState &state) const { 13949e37000SMatthias Springer return BufferRelation::None; 14049e37000SMatthias Springer } 14149e37000SMatthias Springer 14249e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 14349e37000SMatthias Springer const BufferizationState &state) const { 14449e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 14549e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 14649e37000SMatthias Springer Value srcMemref = 14749e37000SMatthias Springer *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, 14849e37000SMatthias Springer /*forceInPlace=*/true); 14949e37000SMatthias Springer auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); 15049e37000SMatthias Springer auto dstTensorType = 15149e37000SMatthias Springer extractSliceOp.result().getType().cast<RankedTensorType>(); 15249e37000SMatthias Springer 15349e37000SMatthias Springer // If not inplaceable, alloc. 15449e37000SMatthias Springer bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0)); 15549e37000SMatthias Springer Value alloc; 15649e37000SMatthias Springer if (!inplace) { 15749e37000SMatthias Springer FailureOr<Value> allocOrFailure = 15849e37000SMatthias Springer createAlloc(rewriter, loc, extractSliceOp.result(), 15949e37000SMatthias Springer state.getOptions().createDeallocs, state.getOptions()); 16049e37000SMatthias Springer if (failed(allocOrFailure)) 16149e37000SMatthias Springer return failure(); 16249e37000SMatthias Springer alloc = *allocOrFailure; 16349e37000SMatthias Springer } 16449e37000SMatthias Springer 16549e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 16649e37000SMatthias Springer // rank-reducing case. 16749e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 16849e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 16949e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 17049e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 17149e37000SMatthias Springer srcMemref, mixedOffsets, mixedSizes, mixedStrides, 17249e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 17349e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 17449e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 17549e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 17649e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 17749e37000SMatthias Springer }); 17849e37000SMatthias Springer // Bufferize to subview. 17949e37000SMatthias Springer auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( 18049e37000SMatthias Springer dstTensorType.getRank(), srcMemrefType, 18149e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 18249e37000SMatthias Springer .cast<MemRefType>(); 18349e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 18449e37000SMatthias Springer loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, 18549e37000SMatthias Springer mixedStrides); 18649e37000SMatthias Springer 18749e37000SMatthias Springer // If not inplaceable, copy. 18849e37000SMatthias Springer if (!inplace) { 18949e37000SMatthias Springer // Do not copy if the copied data is never read. 19049e37000SMatthias Springer if (state.isValueRead(extractSliceOp.result())) 19149e37000SMatthias Springer if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, 19249e37000SMatthias Springer alloc, state.getOptions()))) 19349e37000SMatthias Springer return failure(); 19449e37000SMatthias Springer subView = alloc; 19549e37000SMatthias Springer } 19649e37000SMatthias Springer 19749e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 19849e37000SMatthias Springer return success(); 19949e37000SMatthias Springer } 20049e37000SMatthias Springer }; 20149e37000SMatthias Springer 20249e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 20349e37000SMatthias Springer struct ExtractOpInterface 20449e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 20549e37000SMatthias Springer tensor::ExtractOp> { 20649e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 20749e37000SMatthias Springer const BufferizationState &state) const { 20849e37000SMatthias Springer return true; 20949e37000SMatthias Springer } 21049e37000SMatthias Springer 21149e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 21249e37000SMatthias Springer const BufferizationState &state) const { 21349e37000SMatthias Springer return false; 21449e37000SMatthias Springer } 21549e37000SMatthias Springer 21649e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 21749e37000SMatthias Springer const BufferizationState &state) const { 21849e37000SMatthias Springer return OpResult(); 21949e37000SMatthias Springer } 22049e37000SMatthias Springer 22149e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 22249e37000SMatthias Springer const BufferizationState &state) const { 22349e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 22449e37000SMatthias Springer Value srcMemref = 22549e37000SMatthias Springer *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); 22649e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, 22749e37000SMatthias Springer extractOp.indices()); 22849e37000SMatthias Springer return success(); 22949e37000SMatthias Springer } 23049e37000SMatthias Springer }; 23149e37000SMatthias Springer 232*d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 233*d581c94dSMatthias Springer // iterating over op.elements(). 234*d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim, 235*d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape, 236*d581c94dSMatthias Springer ArrayRef<Value> constants, 237*d581c94dSMatthias Springer OperandRange::iterator &elementIt, 238*d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) { 239*d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 240*d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) { 241*d581c94dSMatthias Springer indices.back() = constants[i]; 242*d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices); 243*d581c94dSMatthias Springer ++elementIt; 244*d581c94dSMatthias Springer } 245*d581c94dSMatthias Springer return; 246*d581c94dSMatthias Springer } 247*d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 248*d581c94dSMatthias Springer indices[dim] = constants[i]; 249*d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, 250*d581c94dSMatthias Springer indices); 251*d581c94dSMatthias Springer } 252*d581c94dSMatthias Springer } 253*d581c94dSMatthias Springer 254*d581c94dSMatthias Springer /// Bufferization of tensor.from_elements. 255*d581c94dSMatthias Springer struct FromElementsOpInterface 256*d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, 257*d581c94dSMatthias Springer tensor::FromElementsOp> { 258*d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 259*d581c94dSMatthias Springer const BufferizationState &state) const { 260*d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op); 261*d581c94dSMatthias Springer 262*d581c94dSMatthias Springer // Allocate a buffer for the result. 263*d581c94dSMatthias Springer Location loc = op->getLoc(); 264*d581c94dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>(); 265*d581c94dSMatthias Springer auto shape = tensorType.getShape(); 266*d581c94dSMatthias Springer MemRefType resultType = 267*d581c94dSMatthias Springer MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 268*d581c94dSMatthias Springer FailureOr<Value> maybeBuffer = 269*d581c94dSMatthias Springer createAlloc(rewriter, loc, resultType, {}, 270*d581c94dSMatthias Springer /*deallocMemref=*/state.getOptions().createDeallocs, 271*d581c94dSMatthias Springer state.getOptions()); 272*d581c94dSMatthias Springer if (failed(maybeBuffer)) 273*d581c94dSMatthias Springer return failure(); 274*d581c94dSMatthias Springer Value buffer = *maybeBuffer; 275*d581c94dSMatthias Springer 276*d581c94dSMatthias Springer // Case: tensor<0xelem_type>. 277*d581c94dSMatthias Springer if (fromElementsOp.elements().empty()) { 278*d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 279*d581c94dSMatthias Springer return success(); 280*d581c94dSMatthias Springer } 281*d581c94dSMatthias Springer 282*d581c94dSMatthias Springer // Case: tensor<elem_type>. 283*d581c94dSMatthias Springer if (shape.empty()) { 284*d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(), 285*d581c94dSMatthias Springer buffer); 286*d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 287*d581c94dSMatthias Springer return success(); 288*d581c94dSMatthias Springer } 289*d581c94dSMatthias Springer 290*d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}). 291*d581c94dSMatthias Springer auto maxDim = *std::max_element(shape.begin(), shape.end()); 292*d581c94dSMatthias Springer SmallVector<Value, 2> constants; 293*d581c94dSMatthias Springer constants.reserve(maxDim); 294*d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i) 295*d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 296*d581c94dSMatthias Springer 297*d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops. 298*d581c94dSMatthias Springer auto elementIt = fromElementsOp.elements().begin(); 299*d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 300*d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, 301*d581c94dSMatthias Springer indices); 302*d581c94dSMatthias Springer 303*d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 304*d581c94dSMatthias Springer return success(); 305*d581c94dSMatthias Springer } 306*d581c94dSMatthias Springer }; 307*d581c94dSMatthias Springer 30871bbb78bSMatthias Springer /// Bufferization of tensor.generate. 30971bbb78bSMatthias Springer struct GenerateOpInterface 31071bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, 31171bbb78bSMatthias Springer tensor::GenerateOp> { 31271bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 31371bbb78bSMatthias Springer const BufferizationState &state) const { 31471bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op); 31571bbb78bSMatthias Springer 31671bbb78bSMatthias Springer // Allocate memory. 31771bbb78bSMatthias Springer Location loc = op->getLoc(); 31871bbb78bSMatthias Springer MemRefType memrefType = 31971bbb78bSMatthias Springer getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>()); 32071bbb78bSMatthias Springer FailureOr<Value> maybeResult = 32171bbb78bSMatthias Springer createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(), 32271bbb78bSMatthias Springer /*deallocMemref=*/state.getOptions().createDeallocs, 32371bbb78bSMatthias Springer state.getOptions()); 32471bbb78bSMatthias Springer if (failed(maybeResult)) 32571bbb78bSMatthias Springer return failure(); 32671bbb78bSMatthias Springer Value result = *maybeResult; 32771bbb78bSMatthias Springer 32871bbb78bSMatthias Springer // Collect loop bounds. 32971bbb78bSMatthias Springer int64_t rank = memrefType.getRank(); 33071bbb78bSMatthias Springer Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 33171bbb78bSMatthias Springer Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 33271bbb78bSMatthias Springer SmallVector<Value, 4> lowerBounds(rank, zero); 33371bbb78bSMatthias Springer SmallVector<Value, 4> steps(rank, one); 33471bbb78bSMatthias Springer SmallVector<Value, 4> upperBounds; 33571bbb78bSMatthias Springer int nextDynamicIndex = 0; 33671bbb78bSMatthias Springer for (int i = 0; i < rank; i++) { 33771bbb78bSMatthias Springer Value upperBound = memrefType.isDynamicDim(i) 33871bbb78bSMatthias Springer ? generateOp.dynamicExtents()[nextDynamicIndex++] 33971bbb78bSMatthias Springer : rewriter.create<arith::ConstantIndexOp>( 34071bbb78bSMatthias Springer loc, memrefType.getDimSize(i)); 34171bbb78bSMatthias Springer upperBounds.push_back(upperBound); 34271bbb78bSMatthias Springer } 34371bbb78bSMatthias Springer 34471bbb78bSMatthias Springer // Generate tensor elements with a parallel loop that stores into 34571bbb78bSMatthias Springer // each element of the resulting memref. We use mergeBlockBefore to "move" 34671bbb78bSMatthias Springer // this op's body into the scf.parallel's body. 34771bbb78bSMatthias Springer auto parallel = 34871bbb78bSMatthias Springer rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 34971bbb78bSMatthias Springer Block *parallelBody = parallel.getBody(); 35071bbb78bSMatthias Springer rewriter.mergeBlockBefore(generateOp.getBody(), 35171bbb78bSMatthias Springer parallelBody->getTerminator(), 35271bbb78bSMatthias Springer parallelBody->getArguments()); 35371bbb78bSMatthias Springer // Replace the inlined yield op with a store op. The scf.parallel's builder 35471bbb78bSMatthias Springer // already populated an scf.yield at the end, so we don't need to worry 35571bbb78bSMatthias Springer // about creating that. 35671bbb78bSMatthias Springer Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 35771bbb78bSMatthias Springer rewriter.setInsertionPointAfter(elementYield); 35871bbb78bSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>( 35971bbb78bSMatthias Springer elementYield, elementYield->getOperands()[0], result, 36071bbb78bSMatthias Springer parallelBody->getArguments()); 36171bbb78bSMatthias Springer 36271bbb78bSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, result); 36371bbb78bSMatthias Springer return success(); 36471bbb78bSMatthias Springer } 36571bbb78bSMatthias Springer }; 36671bbb78bSMatthias Springer 36749e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 36849e37000SMatthias Springer struct InsertOpInterface 36949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface, 37049e37000SMatthias Springer tensor::InsertOp> { 37149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 37249e37000SMatthias Springer const BufferizationState &state) const { 37349e37000SMatthias Springer return true; 37449e37000SMatthias Springer } 37549e37000SMatthias Springer 37649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 37749e37000SMatthias Springer const BufferizationState &state) const { 37849e37000SMatthias Springer return true; 37949e37000SMatthias Springer } 38049e37000SMatthias Springer 38149e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 38249e37000SMatthias Springer const BufferizationState &state) const { 38349e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 38449e37000SMatthias Springer "expected dest OpOperand"); 38549e37000SMatthias Springer return op->getOpResult(0); 38649e37000SMatthias Springer } 38749e37000SMatthias Springer 38849e37000SMatthias Springer SmallVector<OpOperand *> 38949e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 39049e37000SMatthias Springer const BufferizationState &state) const { 39149e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/}; 39249e37000SMatthias Springer } 39349e37000SMatthias Springer 39449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 39549e37000SMatthias Springer const BufferizationState &state) const { 39649e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 39749e37000SMatthias Springer FailureOr<Value> destMemref = 39849e37000SMatthias Springer state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); 39949e37000SMatthias Springer if (failed(destMemref)) 40049e37000SMatthias Springer return failure(); 40149e37000SMatthias Springer rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), 40249e37000SMatthias Springer *destMemref, insertOp.indices()); 40349e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref); 40449e37000SMatthias Springer return success(); 40549e37000SMatthias Springer } 40649e37000SMatthias Springer 40749e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 40849e37000SMatthias Springer const BufferizationState &state) const { 40949e37000SMatthias Springer return BufferRelation::Equivalent; 41049e37000SMatthias Springer } 41149e37000SMatthias Springer }; 41249e37000SMatthias Springer 41349e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. 41449e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification). 41549e37000SMatthias Springer /// 41649e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that 41749e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and 41849e37000SMatthias Springer /// exposed as interfaces on the proper types. 41949e37000SMatthias Springer static bool areEquivalentExtractSliceOps(const BufferizationState &state, 42049e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 42149e37000SMatthias Springer if (!st || !sti) 42249e37000SMatthias Springer return false; 42349e37000SMatthias Springer if (sti != sti && 42449e37000SMatthias Springer !state.areEquivalentBufferizedValues(st.source(), sti.dest())) 42549e37000SMatthias Springer return false; 42649e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 42749e37000SMatthias Springer return false; 42849e37000SMatthias Springer return true; 42949e37000SMatthias Springer } 43049e37000SMatthias Springer 43149e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 43249e37000SMatthias Springer /// the given InsertSliceOp. 43349e37000SMatthias Springer static bool hasMatchingExtractSliceOp(const BufferizationState &state, 43449e37000SMatthias Springer Value value, InsertSliceOp insertOp) { 43549e37000SMatthias Springer auto condition = [&](Value val) { 43649e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 43749e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 43849e37000SMatthias Springer return true; 43949e37000SMatthias Springer return false; 44049e37000SMatthias Springer }; 44149e37000SMatthias Springer 44249e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 44349e37000SMatthias Springer condition); 44449e37000SMatthias Springer } 44549e37000SMatthias Springer 44649e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 44749e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 44849e37000SMatthias Springer struct InsertSliceOpInterface 44949e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 45049e37000SMatthias Springer tensor::InsertSliceOp> { 45149e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 45249e37000SMatthias Springer const BufferizationState &state) const { 45349e37000SMatthias Springer return true; 45449e37000SMatthias Springer } 45549e37000SMatthias Springer 45649e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 45749e37000SMatthias Springer const BufferizationState &state) const { 45849e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 45949e37000SMatthias Springer } 46049e37000SMatthias Springer 46149e37000SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 46249e37000SMatthias Springer const BufferizationState &state) const { 46349e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/ 46449e37000SMatthias Springer ? op->getResult(0) 46549e37000SMatthias Springer : OpResult(); 46649e37000SMatthias Springer } 46749e37000SMatthias Springer 46849e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 46949e37000SMatthias Springer const BufferizationState &state) const { 47049e37000SMatthias Springer return BufferRelation::Equivalent; 47149e37000SMatthias Springer } 47249e37000SMatthias Springer 47349e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 47449e37000SMatthias Springer OpOperand *uConflictingWrite, 47549e37000SMatthias Springer const BufferizationState &state) const { 47649e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 47749e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 47849e37000SMatthias Springer 47949e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 48049e37000SMatthias Springer // uRead is an InsertSliceOp... 48149e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 48249e37000SMatthias Springer // As an example, consider the following IR. 48349e37000SMatthias Springer // 48449e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 48549e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 48649e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 48749e37000SMatthias Springer // {inplace= [true] } 48849e37000SMatthias Springer 48949e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 49049e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 49149e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 49249e37000SMatthias Springer insertSliceOp)) 49349e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 49449e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 49549e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 49649e37000SMatthias Springer // being read by uRead, this is not a conflict. 49749e37000SMatthias Springer // 49849e37000SMatthias Springer // In the above example: 49949e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 50049e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 50149e37000SMatthias Springer // 50249e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 50349e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 50449e37000SMatthias Springer // exactly the one that is *not* read via %t. 50549e37000SMatthias Springer return true; 50649e37000SMatthias Springer 50749e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 50849e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 50949e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 51049e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 51149e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 51249e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 51349e37000SMatthias Springer // InsertSliceOp is writing. 51449e37000SMatthias Springer // 51549e37000SMatthias Springer // In the above example: 51649e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 51749e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 51849e37000SMatthias Springer return true; 51949e37000SMatthias Springer } 52049e37000SMatthias Springer 52149e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 52249e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 52349e37000SMatthias Springer // As an example, consider the following IR. 52449e37000SMatthias Springer // 52549e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 52649e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 52749e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 52849e37000SMatthias Springer // {inplace= [true] } 52949e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 53049e37000SMatthias Springer // 53149e37000SMatthias Springer // In the above example: 53249e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 53349e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 53449e37000SMatthias Springer // lastWrite = %1 53549e37000SMatthias Springer // 53649e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 53749e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 53849e37000SMatthias Springer // is no memory write here.) 53949e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 54049e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 54149e37000SMatthias Springer insertSliceOp.source()) && 54249e37000SMatthias Springer hasMatchingExtractSliceOp(state, insertSliceOp.source(), 54349e37000SMatthias Springer insertSliceOp)) 54449e37000SMatthias Springer return true; 54549e37000SMatthias Springer 54649e37000SMatthias Springer return false; 54749e37000SMatthias Springer } 54849e37000SMatthias Springer 54949e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 55049e37000SMatthias Springer const BufferizationState &state) const { 55149e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 55249e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 55349e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 55449e37000SMatthias Springer // catastrophically bad scheduling decision. 55549e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 55649e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 55749e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 55849e37000SMatthias Springer 55949e37000SMatthias Springer // When bufferizing out-of-place, `getResultBuffer` allocates. 56049e37000SMatthias Springer FailureOr<Value> dstMemref = 56149e37000SMatthias Springer state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); 56249e37000SMatthias Springer if (failed(dstMemref)) 56349e37000SMatthias Springer return failure(); 56449e37000SMatthias Springer 56549e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 56649e37000SMatthias Springer // rank-reducing case. 56749e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 56849e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 56949e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 57049e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 57149e37000SMatthias Springer *dstMemref, mixedOffsets, mixedSizes, mixedStrides, 57249e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 57349e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 57449e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 57549e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 57649e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 57749e37000SMatthias Springer }); 57849e37000SMatthias Springer // Take a subview of the dst. 57949e37000SMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); 58049e37000SMatthias Springer auto subviewMemRefType = 58149e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 58249e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 58349e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 58449e37000SMatthias Springer .cast<MemRefType>(); 58549e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 58649e37000SMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 58749e37000SMatthias Springer mixedStrides); 58849e37000SMatthias Springer 58949e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 59049e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 59149e37000SMatthias Springer Value srcMemref = 59249e37000SMatthias Springer *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); 59349e37000SMatthias Springer if (failed(createMemCpy(rewriter, loc, srcMemref, subView, 59449e37000SMatthias Springer state.getOptions()))) 59549e37000SMatthias Springer return failure(); 59649e37000SMatthias Springer 59749e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 59849e37000SMatthias Springer return success(); 59949e37000SMatthias Springer } 60049e37000SMatthias Springer }; 60149e37000SMatthias Springer 602fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 603fc08d1c2SMatthias Springer struct RankOpInterface 604fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 605fc08d1c2SMatthias Springer tensor::RankOp> { 606fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 607fc08d1c2SMatthias Springer const BufferizationState &state) const { 608fc08d1c2SMatthias Springer return true; 609fc08d1c2SMatthias Springer } 610fc08d1c2SMatthias Springer 611fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 612fc08d1c2SMatthias Springer const BufferizationState &state) const { 613fc08d1c2SMatthias Springer return false; 614fc08d1c2SMatthias Springer } 615fc08d1c2SMatthias Springer 616fc08d1c2SMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 617fc08d1c2SMatthias Springer const BufferizationState &state) const { 618fc08d1c2SMatthias Springer return OpResult(); 619fc08d1c2SMatthias Springer } 620fc08d1c2SMatthias Springer 621fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 622fc08d1c2SMatthias Springer const BufferizationState &state) const { 623fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 624fc08d1c2SMatthias Springer Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); 625fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 626fc08d1c2SMatthias Springer v); 627fc08d1c2SMatthias Springer return success(); 628fc08d1c2SMatthias Springer } 629fc08d1c2SMatthias Springer }; 630fc08d1c2SMatthias Springer 63149e37000SMatthias Springer } // namespace 63249e37000SMatthias Springer } // namespace tensor 63349e37000SMatthias Springer } // namespace mlir 63449e37000SMatthias Springer 63549e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 63649e37000SMatthias Springer DialectRegistry ®istry) { 63749e37000SMatthias Springer registry.addOpInterface<CastOp, CastOpInterface>(); 63849e37000SMatthias Springer registry.addOpInterface<DimOp, DimOpInterface>(); 63949e37000SMatthias Springer registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>(); 64049e37000SMatthias Springer registry.addOpInterface<ExtractOp, ExtractOpInterface>(); 641*d581c94dSMatthias Springer registry.addOpInterface<FromElementsOp, FromElementsOpInterface>(); 64271bbb78bSMatthias Springer registry.addOpInterface<GenerateOp, GenerateOpInterface>(); 64349e37000SMatthias Springer registry.addOpInterface<InsertOp, InsertOpInterface>(); 64449e37000SMatthias Springer registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>(); 645fc08d1c2SMatthias Springer registry.addOpInterface<RankOp, RankOpInterface>(); 64649e37000SMatthias Springer } 647