1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements bufferization of `tensor` dialect ops 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/Bufferize.h" 14 #include "PassDetail.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tensor/Transforms/Passes.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 namespace { 23 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { 24 public: 25 using OpConversionPattern::OpConversionPattern; 26 LogicalResult 27 matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands, 28 ConversionPatternRewriter &rewriter) const override { 29 tensor::ExtractOp::Adaptor adaptor(operands); 30 rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.tensor(), 31 adaptor.indices()); 32 return success(); 33 } 34 }; 35 } // namespace 36 37 void mlir::populateTensorBufferizePatterns( 38 MLIRContext *context, BufferizeTypeConverter &typeConverter, 39 OwningRewritePatternList &patterns) { 40 patterns.insert<BufferizeExtractOp>(typeConverter, context); 41 } 42 43 namespace { 44 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { 45 void runOnFunction() override { 46 auto *context = &getContext(); 47 BufferizeTypeConverter typeConverter; 48 OwningRewritePatternList patterns; 49 ConversionTarget target(*context); 50 51 populateTensorBufferizePatterns(context, typeConverter, patterns); 52 target.addIllegalOp<tensor::ExtractOp>(); 53 target.addLegalDialect<StandardOpsDialect>(); 54 55 if (failed( 56 applyPartialConversion(getFunction(), target, std::move(patterns)))) 57 signalPassFailure(); 58 } 59 }; 60 } // namespace 61 62 std::unique_ptr<Pass> mlir::createTensorBufferizePass() { 63 return std::make_unique<TensorBufferizePass>(); 64 } 65