1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 16 using namespace mlir; 17 18 namespace { 19 20 /// Bufferize arith.index_cast. 21 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> { 22 using OpConversionPattern::OpConversionPattern; 23 24 LogicalResult 25 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 26 ConversionPatternRewriter &rewriter) const override { 27 auto tensorType = op.getType().cast<RankedTensorType>(); 28 rewriter.replaceOpWithNewOp<arith::IndexCastOp>( 29 op, adaptor.getIn(), 30 MemRefType::get(tensorType.getShape(), tensorType.getElementType())); 31 return success(); 32 } 33 }; 34 35 /// Pass to bufferize Arithmetic ops. 36 struct ArithmeticBufferizePass 37 : public ArithmeticBufferizeBase<ArithmeticBufferizePass> { 38 void runOnFunction() override { 39 bufferization::BufferizeTypeConverter typeConverter; 40 RewritePatternSet patterns(&getContext()); 41 ConversionTarget target(getContext()); 42 43 target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>(); 44 45 arith::populateArithmeticBufferizePatterns(typeConverter, patterns); 46 47 target.addDynamicallyLegalOp<arith::IndexCastOp>( 48 [&](arith::IndexCastOp op) { 49 return typeConverter.isLegal(op.getType()); 50 }); 51 52 if (failed( 53 applyPartialConversion(getFunction(), target, std::move(patterns)))) 54 signalPassFailure(); 55 } 56 }; 57 58 } // namespace 59 60 void mlir::arith::populateArithmeticBufferizePatterns( 61 bufferization::BufferizeTypeConverter &typeConverter, 62 RewritePatternSet &patterns) { 63 patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext()); 64 } 65 66 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() { 67 return std::make_unique<ArithmeticBufferizePass>(); 68 } 69