149e37000SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 249e37000SMatthias Springer // 349e37000SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 449e37000SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 549e37000SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 649e37000SMatthias Springer // 749e37000SMatthias Springer //===----------------------------------------------------------------------===// 849e37000SMatthias Springer 949e37000SMatthias Springer #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" 1049e37000SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1149e37000SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 1271bbb78bSMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 1349e37000SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 1449e37000SMatthias Springer #include "mlir/IR/Dialect.h" 1549e37000SMatthias Springer #include "mlir/IR/Operation.h" 1649e37000SMatthias Springer 1749e37000SMatthias Springer using namespace mlir; 1849e37000SMatthias Springer using namespace mlir::bufferization; 1949e37000SMatthias Springer using namespace mlir::tensor; 2049e37000SMatthias Springer 2149e37000SMatthias Springer namespace mlir { 2249e37000SMatthias Springer namespace tensor { 2349e37000SMatthias Springer namespace { 2449e37000SMatthias Springer 2549e37000SMatthias Springer struct CastOpInterface 2649e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<CastOpInterface, 2749e37000SMatthias Springer tensor::CastOp> { 2849e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 299597b16aSMatthias Springer const AnalysisState &state) const { 3049e37000SMatthias Springer return false; 3149e37000SMatthias Springer } 3249e37000SMatthias Springer 3349e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 349597b16aSMatthias Springer const AnalysisState &state) const { 3549e37000SMatthias Springer return false; 3649e37000SMatthias Springer } 3749e37000SMatthias Springer 389597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 399597b16aSMatthias Springer const AnalysisState &state) const { 40585a8a32SMatthias Springer return {op->getResult(0)}; 4149e37000SMatthias Springer } 4249e37000SMatthias Springer 4349e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 449597b16aSMatthias Springer const AnalysisState &state) const { 4549e37000SMatthias Springer return BufferRelation::Equivalent; 4649e37000SMatthias Springer } 4749e37000SMatthias Springer 4849e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 499597b16aSMatthias Springer BufferizationState &state) const { 5049e37000SMatthias Springer auto castOp = cast<tensor::CastOp>(op); 5149e37000SMatthias Springer 5249e37000SMatthias Springer // The result buffer still has the old (pre-cast) type. 5349e37000SMatthias Springer FailureOr<Value> resultBuffer = 5449e37000SMatthias Springer state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); 5549e37000SMatthias Springer if (failed(resultBuffer)) 5649e37000SMatthias Springer return failure(); 5749e37000SMatthias Springer auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>(); 5849e37000SMatthias Springer Attribute memorySpace = sourceMemRefType.getMemorySpace(); 5949e37000SMatthias Springer TensorType resultTensorType = 6049e37000SMatthias Springer castOp.getResult().getType().cast<TensorType>(); 6149e37000SMatthias Springer MemRefLayoutAttrInterface layout; 6249e37000SMatthias Springer 6349e37000SMatthias Springer if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>()) 6449e37000SMatthias Springer if (resultTensorType.isa<RankedTensorType>()) 6549e37000SMatthias Springer layout = rankedMemRefType.getLayout(); 6649e37000SMatthias Springer 6749e37000SMatthias Springer // Compute the new memref type. 6826852423SMatthias Springer Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(), 6926852423SMatthias Springer layout, memorySpace); 7049e37000SMatthias Springer 7149e37000SMatthias Springer // Replace the op with a memref.cast. 7249e37000SMatthias Springer assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), 7349e37000SMatthias Springer resultMemRefType) && 7449e37000SMatthias Springer "CallOp::bufferize: cast incompatible"); 7549e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, 7649e37000SMatthias Springer *resultBuffer); 7749e37000SMatthias Springer 7849e37000SMatthias Springer return success(); 7949e37000SMatthias Springer } 8049e37000SMatthias Springer }; 8149e37000SMatthias Springer 82e6f69161SMatthias Springer /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape. 83e6f69161SMatthias Springer struct CollapseShapeOpInterface 84e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface, 85e6f69161SMatthias Springer tensor::CollapseShapeOp> { 86e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 879597b16aSMatthias Springer const AnalysisState &state) const { 88e6f69161SMatthias Springer return false; 89e6f69161SMatthias Springer } 90e6f69161SMatthias Springer 91e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 929597b16aSMatthias Springer const AnalysisState &state) const { 93e6f69161SMatthias Springer return false; 94e6f69161SMatthias Springer } 95e6f69161SMatthias Springer 969597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 979597b16aSMatthias Springer const AnalysisState &state) const { 98e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 99e6f69161SMatthias Springer return {op->getOpResult(0)}; 100e6f69161SMatthias Springer return {}; 101e6f69161SMatthias Springer } 102e6f69161SMatthias Springer 103e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1049597b16aSMatthias Springer const AnalysisState &state) const { 105e6f69161SMatthias Springer return BufferRelation::Equivalent; 106e6f69161SMatthias Springer } 107e6f69161SMatthias Springer 108e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1099597b16aSMatthias Springer BufferizationState &state) const { 110e6f69161SMatthias Springer auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op); 111e6f69161SMatthias Springer Value buffer = 112e6f69161SMatthias Springer *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/); 113e6f69161SMatthias Springer Type resultType = 114e6f69161SMatthias Springer getMemRefType(collapseShapeOp.getResultType(), state.getOptions()); 115e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>( 116e6f69161SMatthias Springer rewriter, op, resultType, buffer, collapseShapeOp.reassociation()); 117e6f69161SMatthias Springer return success(); 118e6f69161SMatthias Springer } 119e6f69161SMatthias Springer }; 120e6f69161SMatthias Springer 12149e37000SMatthias Springer /// Bufferization of tensor.dim. Replace with memref.dim. 12249e37000SMatthias Springer struct DimOpInterface 12349e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<DimOpInterface, 12449e37000SMatthias Springer tensor::DimOp> { 12549e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1269597b16aSMatthias Springer const AnalysisState &state) const { 12749e37000SMatthias Springer return true; 12849e37000SMatthias Springer } 12949e37000SMatthias Springer 13049e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1319597b16aSMatthias Springer const AnalysisState &state) const { 13249e37000SMatthias Springer return false; 13349e37000SMatthias Springer } 13449e37000SMatthias Springer 1359597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1369597b16aSMatthias Springer const AnalysisState &state) const { 137585a8a32SMatthias Springer return {}; 13849e37000SMatthias Springer } 13949e37000SMatthias Springer 14049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1419597b16aSMatthias Springer BufferizationState &state) const { 14249e37000SMatthias Springer auto dimOp = cast<tensor::DimOp>(op); 14349e37000SMatthias Springer Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); 14449e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index()); 14549e37000SMatthias Springer return success(); 14649e37000SMatthias Springer } 14749e37000SMatthias Springer }; 14849e37000SMatthias Springer 149e6f69161SMatthias Springer /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape. 150e6f69161SMatthias Springer struct ExpandShapeOpInterface 151e6f69161SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface, 152e6f69161SMatthias Springer tensor::ExpandShapeOp> { 153e6f69161SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1549597b16aSMatthias Springer const AnalysisState &state) const { 155e6f69161SMatthias Springer return false; 156e6f69161SMatthias Springer } 157e6f69161SMatthias Springer 158e6f69161SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1599597b16aSMatthias Springer const AnalysisState &state) const { 160e6f69161SMatthias Springer return false; 161e6f69161SMatthias Springer } 162e6f69161SMatthias Springer 1639597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1649597b16aSMatthias Springer const AnalysisState &state) const { 165e6f69161SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*src*/) 166e6f69161SMatthias Springer return {op->getOpResult(0)}; 167e6f69161SMatthias Springer return {}; 168e6f69161SMatthias Springer } 169e6f69161SMatthias Springer 170e6f69161SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 1719597b16aSMatthias Springer const AnalysisState &state) const { 172e6f69161SMatthias Springer return BufferRelation::Equivalent; 173e6f69161SMatthias Springer } 174e6f69161SMatthias Springer 175e6f69161SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1769597b16aSMatthias Springer BufferizationState &state) const { 177e6f69161SMatthias Springer auto expandShapeOp = cast<tensor::ExpandShapeOp>(op); 178e6f69161SMatthias Springer Value buffer = 179e6f69161SMatthias Springer *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); 180e6f69161SMatthias Springer Type resultType = 181e6f69161SMatthias Springer getMemRefType(expandShapeOp.getResultType(), state.getOptions()); 182e6f69161SMatthias Springer replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>( 183e6f69161SMatthias Springer rewriter, op, resultType, buffer, expandShapeOp.reassociation()); 184e6f69161SMatthias Springer return success(); 185e6f69161SMatthias Springer } 186e6f69161SMatthias Springer }; 187e6f69161SMatthias Springer 18849e37000SMatthias Springer /// Bufferization of tensor.extract_slice. Replace with memref.subview. 18949e37000SMatthias Springer struct ExtractSliceOpInterface 19049e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface, 19149e37000SMatthias Springer tensor::ExtractSliceOp> { 19249e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1939597b16aSMatthias Springer const AnalysisState &state) const { 19449e37000SMatthias Springer return false; 19549e37000SMatthias Springer } 19649e37000SMatthias Springer 19749e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1989597b16aSMatthias Springer const AnalysisState &state) const { 19949e37000SMatthias Springer return false; 20049e37000SMatthias Springer } 20149e37000SMatthias Springer 2029597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2039597b16aSMatthias Springer const AnalysisState &state) const { 204585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(0) /*source*/) 205585a8a32SMatthias Springer return {op->getOpResult(0)}; 206585a8a32SMatthias Springer return {}; 20749e37000SMatthias Springer } 20849e37000SMatthias Springer 20949e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 2109597b16aSMatthias Springer const AnalysisState &state) const { 21149e37000SMatthias Springer return BufferRelation::None; 21249e37000SMatthias Springer } 21349e37000SMatthias Springer 21449e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 2159597b16aSMatthias Springer BufferizationState &state) const { 21649e37000SMatthias Springer auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 21749e37000SMatthias Springer Location loc = extractSliceOp.getLoc(); 21849e37000SMatthias Springer Value srcMemref = 21949e37000SMatthias Springer *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, 22049e37000SMatthias Springer /*forceInPlace=*/true); 22149e37000SMatthias Springer auto srcMemrefType = srcMemref.getType().cast<MemRefType>(); 22249e37000SMatthias Springer auto dstTensorType = 22349e37000SMatthias Springer extractSliceOp.result().getType().cast<RankedTensorType>(); 22449e37000SMatthias Springer 22549e37000SMatthias Springer // If not inplaceable, alloc. 2269597b16aSMatthias Springer bool inplace = 2279597b16aSMatthias Springer state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0)); 22849e37000SMatthias Springer Value alloc; 22949e37000SMatthias Springer if (!inplace) { 23049e37000SMatthias Springer FailureOr<Value> allocOrFailure = 23105e0495fSMatthias Springer state.createAlloc(rewriter, loc, extractSliceOp.result()); 23249e37000SMatthias Springer if (failed(allocOrFailure)) 23349e37000SMatthias Springer return failure(); 23449e37000SMatthias Springer alloc = *allocOrFailure; 23549e37000SMatthias Springer } 23649e37000SMatthias Springer 23749e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 23849e37000SMatthias Springer // rank-reducing case. 23949e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); 24049e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); 24149e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); 24249e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 24349e37000SMatthias Springer srcMemref, mixedOffsets, mixedSizes, mixedStrides, 24449e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 24549e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 24649e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 24749e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 24849e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 24949e37000SMatthias Springer }); 25049e37000SMatthias Springer // Bufferize to subview. 25149e37000SMatthias Springer auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( 25249e37000SMatthias Springer dstTensorType.getRank(), srcMemrefType, 25349e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 25449e37000SMatthias Springer .cast<MemRefType>(); 25549e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 25649e37000SMatthias Springer loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, 25749e37000SMatthias Springer mixedStrides); 25849e37000SMatthias Springer 25949e37000SMatthias Springer // If not inplaceable, copy. 26049e37000SMatthias Springer if (!inplace) { 26149e37000SMatthias Springer // Do not copy if the copied data is never read. 2629597b16aSMatthias Springer if (state.getAnalysisState().isValueRead(extractSliceOp.result())) 26349e37000SMatthias Springer if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, 26449e37000SMatthias Springer alloc, state.getOptions()))) 26549e37000SMatthias Springer return failure(); 26649e37000SMatthias Springer subView = alloc; 26749e37000SMatthias Springer } 26849e37000SMatthias Springer 26949e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, subView); 27049e37000SMatthias Springer return success(); 27149e37000SMatthias Springer } 27249e37000SMatthias Springer }; 27349e37000SMatthias Springer 27449e37000SMatthias Springer /// Bufferization of tensor.extract. Replace with memref.load. 27549e37000SMatthias Springer struct ExtractOpInterface 27649e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<ExtractOpInterface, 27749e37000SMatthias Springer tensor::ExtractOp> { 27849e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 2799597b16aSMatthias Springer const AnalysisState &state) const { 28049e37000SMatthias Springer return true; 28149e37000SMatthias Springer } 28249e37000SMatthias Springer 28349e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 2849597b16aSMatthias Springer const AnalysisState &state) const { 28549e37000SMatthias Springer return false; 28649e37000SMatthias Springer } 28749e37000SMatthias Springer 2889597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 2899597b16aSMatthias Springer const AnalysisState &state) const { 290585a8a32SMatthias Springer return {}; 29149e37000SMatthias Springer } 29249e37000SMatthias Springer 29349e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 2949597b16aSMatthias Springer BufferizationState &state) const { 29549e37000SMatthias Springer auto extractOp = cast<tensor::ExtractOp>(op); 29649e37000SMatthias Springer Value srcMemref = 29749e37000SMatthias Springer *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); 29849e37000SMatthias Springer replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref, 29949e37000SMatthias Springer extractOp.indices()); 30049e37000SMatthias Springer return success(); 30149e37000SMatthias Springer } 30249e37000SMatthias Springer }; 30349e37000SMatthias Springer 304d581c94dSMatthias Springer // Implements backtracking to traverse indices of the output buffer while 305d581c94dSMatthias Springer // iterating over op.elements(). 306d581c94dSMatthias Springer static void createStores(RewriterBase &rewriter, Location loc, int dim, 307d581c94dSMatthias Springer Value buffer, ArrayRef<int64_t> shape, 308d581c94dSMatthias Springer ArrayRef<Value> constants, 309d581c94dSMatthias Springer OperandRange::iterator &elementIt, 310d581c94dSMatthias Springer SmallVectorImpl<Value> &indices) { 311d581c94dSMatthias Springer if (dim == static_cast<int>(shape.size()) - 1) { 312d581c94dSMatthias Springer for (int i = 0; i < shape.back(); ++i) { 313d581c94dSMatthias Springer indices.back() = constants[i]; 314d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices); 315d581c94dSMatthias Springer ++elementIt; 316d581c94dSMatthias Springer } 317d581c94dSMatthias Springer return; 318d581c94dSMatthias Springer } 319d581c94dSMatthias Springer for (int i = 0; i < shape[dim]; ++i) { 320d581c94dSMatthias Springer indices[dim] = constants[i]; 321d581c94dSMatthias Springer createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt, 322d581c94dSMatthias Springer indices); 323d581c94dSMatthias Springer } 324d581c94dSMatthias Springer } 325d581c94dSMatthias Springer 326d581c94dSMatthias Springer /// Bufferization of tensor.from_elements. 327d581c94dSMatthias Springer struct FromElementsOpInterface 328d581c94dSMatthias Springer : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface, 329d581c94dSMatthias Springer tensor::FromElementsOp> { 330d581c94dSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 3319597b16aSMatthias Springer BufferizationState &state) const { 332d581c94dSMatthias Springer auto fromElementsOp = cast<tensor::FromElementsOp>(op); 333d581c94dSMatthias Springer 334d581c94dSMatthias Springer // Allocate a buffer for the result. 335d581c94dSMatthias Springer Location loc = op->getLoc(); 336d581c94dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>(); 337d581c94dSMatthias Springer auto shape = tensorType.getShape(); 338d581c94dSMatthias Springer FailureOr<Value> maybeBuffer = 339*9e24f0f4SMatthias Springer state.createAlloc(rewriter, loc, fromElementsOp.result()); 340d581c94dSMatthias Springer if (failed(maybeBuffer)) 341d581c94dSMatthias Springer return failure(); 342d581c94dSMatthias Springer Value buffer = *maybeBuffer; 343d581c94dSMatthias Springer 344d581c94dSMatthias Springer // Case: tensor<0xelem_type>. 345d581c94dSMatthias Springer if (fromElementsOp.elements().empty()) { 346d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 347d581c94dSMatthias Springer return success(); 348d581c94dSMatthias Springer } 349d581c94dSMatthias Springer 350d581c94dSMatthias Springer // Case: tensor<elem_type>. 351d581c94dSMatthias Springer if (shape.empty()) { 352d581c94dSMatthias Springer rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(), 353d581c94dSMatthias Springer buffer); 354d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 355d581c94dSMatthias Springer return success(); 356d581c94dSMatthias Springer } 357d581c94dSMatthias Springer 358d581c94dSMatthias Springer // Create constants for the range of possible indices [0, max{shape_i}). 359d581c94dSMatthias Springer auto maxDim = *std::max_element(shape.begin(), shape.end()); 360d581c94dSMatthias Springer SmallVector<Value, 2> constants; 361d581c94dSMatthias Springer constants.reserve(maxDim); 362d581c94dSMatthias Springer for (int i = 0; i < maxDim; ++i) 363d581c94dSMatthias Springer constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 364d581c94dSMatthias Springer 365d581c94dSMatthias Springer // Traverse all `elements` and create `memref.store` ops. 366d581c94dSMatthias Springer auto elementIt = fromElementsOp.elements().begin(); 367d581c94dSMatthias Springer SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 368d581c94dSMatthias Springer createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt, 369d581c94dSMatthias Springer indices); 370d581c94dSMatthias Springer 371d581c94dSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, buffer); 372d581c94dSMatthias Springer return success(); 373d581c94dSMatthias Springer } 374d581c94dSMatthias Springer }; 375d581c94dSMatthias Springer 37671bbb78bSMatthias Springer /// Bufferization of tensor.generate. 37771bbb78bSMatthias Springer struct GenerateOpInterface 37871bbb78bSMatthias Springer : public BufferizableOpInterface::ExternalModel<GenerateOpInterface, 37971bbb78bSMatthias Springer tensor::GenerateOp> { 38071bbb78bSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 3819597b16aSMatthias Springer BufferizationState &state) const { 38271bbb78bSMatthias Springer auto generateOp = cast<tensor::GenerateOp>(op); 38371bbb78bSMatthias Springer 38471bbb78bSMatthias Springer // Allocate memory. 38571bbb78bSMatthias Springer Location loc = op->getLoc(); 38671bbb78bSMatthias Springer MemRefType memrefType = 38771bbb78bSMatthias Springer getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>()); 388*9e24f0f4SMatthias Springer FailureOr<Value> maybeResult = 389*9e24f0f4SMatthias Springer state.createAlloc(rewriter, loc, generateOp.result()); 39071bbb78bSMatthias Springer if (failed(maybeResult)) 39171bbb78bSMatthias Springer return failure(); 39271bbb78bSMatthias Springer Value result = *maybeResult; 39371bbb78bSMatthias Springer 39471bbb78bSMatthias Springer // Collect loop bounds. 39571bbb78bSMatthias Springer int64_t rank = memrefType.getRank(); 39671bbb78bSMatthias Springer Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 39771bbb78bSMatthias Springer Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 39871bbb78bSMatthias Springer SmallVector<Value, 4> lowerBounds(rank, zero); 39971bbb78bSMatthias Springer SmallVector<Value, 4> steps(rank, one); 40071bbb78bSMatthias Springer SmallVector<Value, 4> upperBounds; 40171bbb78bSMatthias Springer int nextDynamicIndex = 0; 40271bbb78bSMatthias Springer for (int i = 0; i < rank; i++) { 40371bbb78bSMatthias Springer Value upperBound = memrefType.isDynamicDim(i) 40471bbb78bSMatthias Springer ? generateOp.dynamicExtents()[nextDynamicIndex++] 40571bbb78bSMatthias Springer : rewriter.create<arith::ConstantIndexOp>( 40671bbb78bSMatthias Springer loc, memrefType.getDimSize(i)); 40771bbb78bSMatthias Springer upperBounds.push_back(upperBound); 40871bbb78bSMatthias Springer } 40971bbb78bSMatthias Springer 41071bbb78bSMatthias Springer // Generate tensor elements with a parallel loop that stores into 41171bbb78bSMatthias Springer // each element of the resulting memref. We use mergeBlockBefore to "move" 41271bbb78bSMatthias Springer // this op's body into the scf.parallel's body. 41371bbb78bSMatthias Springer auto parallel = 41471bbb78bSMatthias Springer rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 41571bbb78bSMatthias Springer Block *parallelBody = parallel.getBody(); 41671bbb78bSMatthias Springer rewriter.mergeBlockBefore(generateOp.getBody(), 41771bbb78bSMatthias Springer parallelBody->getTerminator(), 41871bbb78bSMatthias Springer parallelBody->getArguments()); 41971bbb78bSMatthias Springer // Replace the inlined yield op with a store op. The scf.parallel's builder 42071bbb78bSMatthias Springer // already populated an scf.yield at the end, so we don't need to worry 42171bbb78bSMatthias Springer // about creating that. 42271bbb78bSMatthias Springer Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 42371bbb78bSMatthias Springer rewriter.setInsertionPointAfter(elementYield); 42471bbb78bSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>( 42571bbb78bSMatthias Springer elementYield, elementYield->getOperands()[0], result, 42671bbb78bSMatthias Springer parallelBody->getArguments()); 42771bbb78bSMatthias Springer 42871bbb78bSMatthias Springer replaceOpWithBufferizedValues(rewriter, op, result); 42971bbb78bSMatthias Springer return success(); 43071bbb78bSMatthias Springer } 43171bbb78bSMatthias Springer }; 43271bbb78bSMatthias Springer 43349e37000SMatthias Springer /// Bufferization of tensor.insert. Replace with memref.store. 43449e37000SMatthias Springer struct InsertOpInterface 43549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertOpInterface, 43649e37000SMatthias Springer tensor::InsertOp> { 43749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 4389597b16aSMatthias Springer const AnalysisState &state) const { 43949e37000SMatthias Springer return true; 44049e37000SMatthias Springer } 44149e37000SMatthias Springer 44249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 4439597b16aSMatthias Springer const AnalysisState &state) const { 44449e37000SMatthias Springer return true; 44549e37000SMatthias Springer } 44649e37000SMatthias Springer 4479597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 4489597b16aSMatthias Springer const AnalysisState &state) const { 44949e37000SMatthias Springer assert(&opOperand == &op->getOpOperand(1) /*dest*/ && 45049e37000SMatthias Springer "expected dest OpOperand"); 451585a8a32SMatthias Springer return {op->getOpResult(0)}; 45249e37000SMatthias Springer } 45349e37000SMatthias Springer 45449e37000SMatthias Springer SmallVector<OpOperand *> 45549e37000SMatthias Springer getAliasingOpOperand(Operation *op, OpResult opResult, 4569597b16aSMatthias Springer const AnalysisState &state) const { 45749e37000SMatthias Springer return {&op->getOpOperand(1) /*dest*/}; 45849e37000SMatthias Springer } 45949e37000SMatthias Springer 46049e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 4619597b16aSMatthias Springer BufferizationState &state) const { 46249e37000SMatthias Springer auto insertOp = cast<tensor::InsertOp>(op); 46349e37000SMatthias Springer FailureOr<Value> destMemref = 46449e37000SMatthias Springer state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); 46549e37000SMatthias Springer if (failed(destMemref)) 46649e37000SMatthias Springer return failure(); 46749e37000SMatthias Springer rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(), 46849e37000SMatthias Springer *destMemref, insertOp.indices()); 46949e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *destMemref); 47049e37000SMatthias Springer return success(); 47149e37000SMatthias Springer } 47249e37000SMatthias Springer 47349e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 4749597b16aSMatthias Springer const AnalysisState &state) const { 47549e37000SMatthias Springer return BufferRelation::Equivalent; 47649e37000SMatthias Springer } 47749e37000SMatthias Springer }; 47849e37000SMatthias Springer 47949e37000SMatthias Springer /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. 48049e37000SMatthias Springer /// equivalent operand / result and same offset/sizes/strides specification). 48149e37000SMatthias Springer /// 48249e37000SMatthias Springer /// This is one particular type of relationship between ops on tensors that 48349e37000SMatthias Springer /// reduce to an equivalence on buffers. This should be generalized and 48449e37000SMatthias Springer /// exposed as interfaces on the proper types. 4859597b16aSMatthias Springer static bool areEquivalentExtractSliceOps(const AnalysisState &state, 48649e37000SMatthias Springer ExtractSliceOp st, InsertSliceOp sti) { 48749e37000SMatthias Springer if (!st || !sti) 48849e37000SMatthias Springer return false; 48949e37000SMatthias Springer if (sti != sti && 49049e37000SMatthias Springer !state.areEquivalentBufferizedValues(st.source(), sti.dest())) 49149e37000SMatthias Springer return false; 49249e37000SMatthias Springer if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 49349e37000SMatthias Springer return false; 49449e37000SMatthias Springer return true; 49549e37000SMatthias Springer } 49649e37000SMatthias Springer 49749e37000SMatthias Springer /// Return true if `value` is originating from an ExtractSliceOp that matches 49849e37000SMatthias Springer /// the given InsertSliceOp. 4999597b16aSMatthias Springer static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 5009597b16aSMatthias Springer InsertSliceOp insertOp) { 50149e37000SMatthias Springer auto condition = [&](Value val) { 50249e37000SMatthias Springer if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 50349e37000SMatthias Springer if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 50449e37000SMatthias Springer return true; 50549e37000SMatthias Springer return false; 50649e37000SMatthias Springer }; 50749e37000SMatthias Springer 50849e37000SMatthias Springer return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 50949e37000SMatthias Springer condition); 51049e37000SMatthias Springer } 51149e37000SMatthias Springer 51249e37000SMatthias Springer /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under 51349e37000SMatthias Springer /// certain circumstances, this op can also be a no-op. 51449e37000SMatthias Springer struct InsertSliceOpInterface 51549e37000SMatthias Springer : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, 51649e37000SMatthias Springer tensor::InsertSliceOp> { 51749e37000SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 5189597b16aSMatthias Springer const AnalysisState &state) const { 51949e37000SMatthias Springer return true; 52049e37000SMatthias Springer } 52149e37000SMatthias Springer 52249e37000SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 5239597b16aSMatthias Springer const AnalysisState &state) const { 52449e37000SMatthias Springer return &opOperand == &op->getOpOperand(1) /*dest*/; 52549e37000SMatthias Springer } 52649e37000SMatthias Springer 5279597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 5289597b16aSMatthias Springer const AnalysisState &state) const { 529585a8a32SMatthias Springer if (&opOperand == &op->getOpOperand(1) /*dest*/) 530585a8a32SMatthias Springer return {op->getResult(0)}; 531585a8a32SMatthias Springer return {}; 53249e37000SMatthias Springer } 53349e37000SMatthias Springer 53449e37000SMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 5359597b16aSMatthias Springer const AnalysisState &state) const { 53649e37000SMatthias Springer return BufferRelation::Equivalent; 53749e37000SMatthias Springer } 53849e37000SMatthias Springer 53949e37000SMatthias Springer bool isNotConflicting(Operation *op, OpOperand *uRead, 54049e37000SMatthias Springer OpOperand *uConflictingWrite, 5419597b16aSMatthias Springer const AnalysisState &state) const { 54249e37000SMatthias Springer Operation *readingOp = uRead->getOwner(); 54349e37000SMatthias Springer Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 54449e37000SMatthias Springer 54549e37000SMatthias Springer // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 54649e37000SMatthias Springer // uRead is an InsertSliceOp... 54749e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { 54849e37000SMatthias Springer // As an example, consider the following IR. 54949e37000SMatthias Springer // 55049e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 55149e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 55249e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 55349e37000SMatthias Springer // {inplace= [true] } 55449e37000SMatthias Springer 55549e37000SMatthias Springer // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 55649e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 55749e37000SMatthias Springer hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 55849e37000SMatthias Springer insertSliceOp)) 55949e37000SMatthias Springer // Case 1: The main insight is that InsertSliceOp reads only part of 56049e37000SMatthias Springer // the destination tensor. The overwritten area is not read. If 56149e37000SMatthias Springer // uConflictingWrite writes into exactly the memory location that is 56249e37000SMatthias Springer // being read by uRead, this is not a conflict. 56349e37000SMatthias Springer // 56449e37000SMatthias Springer // In the above example: 56549e37000SMatthias Springer // uRead = OpOperand 1 (%t) of tensor.insert_slice 56649e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 56749e37000SMatthias Springer // 56849e37000SMatthias Springer // The read of %t does not conflict with the write of the FillOp 56949e37000SMatthias Springer // (same aliases!) because the area that the FillOp operates on is 57049e37000SMatthias Springer // exactly the one that is *not* read via %t. 57149e37000SMatthias Springer return true; 57249e37000SMatthias Springer 57349e37000SMatthias Springer if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 57449e37000SMatthias Springer uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 57549e37000SMatthias Springer hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 57649e37000SMatthias Springer // Case 2: The read of the source tensor and the write to the dest 57749e37000SMatthias Springer // tensor via an InsertSliceOp is not a conflict if the read is 57849e37000SMatthias Springer // reading exactly that part of an equivalent tensor that the 57949e37000SMatthias Springer // InsertSliceOp is writing. 58049e37000SMatthias Springer // 58149e37000SMatthias Springer // In the above example: 58249e37000SMatthias Springer // uRead = OpOperand 0 (%1) of tensor.insert_slice 58349e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 58449e37000SMatthias Springer return true; 58549e37000SMatthias Springer } 58649e37000SMatthias Springer 58749e37000SMatthias Springer // If uConflictingWrite is an InsertSliceOp... 58849e37000SMatthias Springer if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) 58949e37000SMatthias Springer // As an example, consider the following IR. 59049e37000SMatthias Springer // 59149e37000SMatthias Springer // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 59249e37000SMatthias Springer // %1 = linalg.fill %cst, %0 {inplace= [true] } 59349e37000SMatthias Springer // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 59449e37000SMatthias Springer // {inplace= [true] } 59549e37000SMatthias Springer // %3 = vector.transfer_read %1, %cst 59649e37000SMatthias Springer // 59749e37000SMatthias Springer // In the above example: 59849e37000SMatthias Springer // uRead = OpOperand 0 (%1) of vector.transfer_read 59949e37000SMatthias Springer // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 60049e37000SMatthias Springer // lastWrite = %1 60149e37000SMatthias Springer // 60249e37000SMatthias Springer // This is not a conflict because the InsertSliceOp overwrites the 60349e37000SMatthias Springer // memory segment of %1 with the exact same data. (Effectively, there 60449e37000SMatthias Springer // is no memory write here.) 60549e37000SMatthias Springer if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 60649e37000SMatthias Springer state.areEquivalentBufferizedValues(uRead->get(), 60749e37000SMatthias Springer insertSliceOp.source()) && 60849e37000SMatthias Springer hasMatchingExtractSliceOp(state, insertSliceOp.source(), 60949e37000SMatthias Springer insertSliceOp)) 61049e37000SMatthias Springer return true; 61149e37000SMatthias Springer 61249e37000SMatthias Springer return false; 61349e37000SMatthias Springer } 61449e37000SMatthias Springer 61549e37000SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 6169597b16aSMatthias Springer BufferizationState &state) const { 61749e37000SMatthias Springer // insert_slice ops arise from tiling and bufferizing them out-of-place is 61849e37000SMatthias Springer // generally a deal breaker. When used with loops, this ends up cloning the 61949e37000SMatthias Springer // whole tensor on every single iteration and is a symptom of a 62049e37000SMatthias Springer // catastrophically bad scheduling decision. 62149e37000SMatthias Springer // TODO: be very loud about it or even consider failing the pass. 62249e37000SMatthias Springer auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 62349e37000SMatthias Springer Location loc = insertSliceOp.getLoc(); 62449e37000SMatthias Springer 62549e37000SMatthias Springer // When bufferizing out-of-place, `getResultBuffer` allocates. 62649e37000SMatthias Springer FailureOr<Value> dstMemref = 62749e37000SMatthias Springer state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); 62849e37000SMatthias Springer if (failed(dstMemref)) 62949e37000SMatthias Springer return failure(); 63049e37000SMatthias Springer 63149e37000SMatthias Springer // Expand offsets, sizes and strides to the full rank to handle the 63249e37000SMatthias Springer // rank-reducing case. 63349e37000SMatthias Springer SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); 63449e37000SMatthias Springer SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); 63549e37000SMatthias Springer SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); 63649e37000SMatthias Springer OffsetSizeAndStrideOpInterface::expandToRank( 63749e37000SMatthias Springer *dstMemref, mixedOffsets, mixedSizes, mixedStrides, 63849e37000SMatthias Springer [&](Value target, int64_t dim) -> OpFoldResult { 63949e37000SMatthias Springer auto shapedType = target.getType().cast<ShapedType>(); 64049e37000SMatthias Springer if (shapedType.isDynamicDim(dim)) 64149e37000SMatthias Springer return rewriter.create<memref::DimOp>(loc, target, dim).result(); 64249e37000SMatthias Springer return rewriter.getIndexAttr(shapedType.getDimSize(dim)); 64349e37000SMatthias Springer }); 64449e37000SMatthias Springer // Take a subview of the dst. 64549e37000SMatthias Springer auto dstMemrefType = dstMemref->getType().cast<MemRefType>(); 64649e37000SMatthias Springer auto subviewMemRefType = 64749e37000SMatthias Springer memref::SubViewOp::inferRankReducedResultType( 64849e37000SMatthias Springer insertSliceOp.getSourceType().getRank(), dstMemrefType, 64949e37000SMatthias Springer mixedOffsets, mixedSizes, mixedStrides) 65049e37000SMatthias Springer .cast<MemRefType>(); 65149e37000SMatthias Springer Value subView = rewriter.create<memref::SubViewOp>( 65249e37000SMatthias Springer loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, 65349e37000SMatthias Springer mixedStrides); 65449e37000SMatthias Springer 65549e37000SMatthias Springer // Copy tensor. If this tensor.insert_slice has a matching 65649e37000SMatthias Springer // tensor.extract_slice, the copy operation will eventually fold away. 65749e37000SMatthias Springer Value srcMemref = 65849e37000SMatthias Springer *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); 65949e37000SMatthias Springer if (failed(createMemCpy(rewriter, loc, srcMemref, subView, 66049e37000SMatthias Springer state.getOptions()))) 66149e37000SMatthias Springer return failure(); 66249e37000SMatthias Springer 66349e37000SMatthias Springer replaceOpWithBufferizedValues(rewriter, op, *dstMemref); 66449e37000SMatthias Springer return success(); 66549e37000SMatthias Springer } 66649e37000SMatthias Springer }; 66749e37000SMatthias Springer 668fc08d1c2SMatthias Springer /// Bufferization of tensor.rank. Replace with memref.rank. 669fc08d1c2SMatthias Springer struct RankOpInterface 670fc08d1c2SMatthias Springer : public BufferizableOpInterface::ExternalModel<RankOpInterface, 671fc08d1c2SMatthias Springer tensor::RankOp> { 672fc08d1c2SMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 6739597b16aSMatthias Springer const AnalysisState &state) const { 674fc08d1c2SMatthias Springer return true; 675fc08d1c2SMatthias Springer } 676fc08d1c2SMatthias Springer 677fc08d1c2SMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 6789597b16aSMatthias Springer const AnalysisState &state) const { 679fc08d1c2SMatthias Springer return false; 680fc08d1c2SMatthias Springer } 681fc08d1c2SMatthias Springer 6829597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 6839597b16aSMatthias Springer const AnalysisState &state) const { 684585a8a32SMatthias Springer return {}; 685fc08d1c2SMatthias Springer } 686fc08d1c2SMatthias Springer 687fc08d1c2SMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 6889597b16aSMatthias Springer BufferizationState &state) const { 689fc08d1c2SMatthias Springer auto rankOp = cast<tensor::RankOp>(op); 690fc08d1c2SMatthias Springer Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); 691fc08d1c2SMatthias Springer replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(), 692fc08d1c2SMatthias Springer v); 693fc08d1c2SMatthias Springer return success(); 694fc08d1c2SMatthias Springer } 695fc08d1c2SMatthias Springer }; 696fc08d1c2SMatthias Springer 69749e37000SMatthias Springer } // namespace 69849e37000SMatthias Springer } // namespace tensor 69949e37000SMatthias Springer } // namespace mlir 70049e37000SMatthias Springer 70149e37000SMatthias Springer void mlir::tensor::registerBufferizableOpInterfaceExternalModels( 70249e37000SMatthias Springer DialectRegistry ®istry) { 70349e37000SMatthias Springer registry.addOpInterface<CastOp, CastOpInterface>(); 704e6f69161SMatthias Springer registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>(); 70549e37000SMatthias Springer registry.addOpInterface<DimOp, DimOpInterface>(); 706e6f69161SMatthias Springer registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>(); 70749e37000SMatthias Springer registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>(); 70849e37000SMatthias Springer registry.addOpInterface<ExtractOp, ExtractOpInterface>(); 709d581c94dSMatthias Springer registry.addOpInterface<FromElementsOp, FromElementsOpInterface>(); 71071bbb78bSMatthias Springer registry.addOpInterface<GenerateOp, GenerateOpInterface>(); 71149e37000SMatthias Springer registry.addOpInterface<InsertOp, InsertOpInterface>(); 71249e37000SMatthias Springer registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>(); 713fc08d1c2SMatthias Springer registry.addOpInterface<RankOp, RankOpInterface>(); 71449e37000SMatthias Springer } 715