1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 15 using namespace mlir; 16 17 namespace { 18 19 /// Bufferize arith.index_cast. 20 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> { 21 using OpConversionPattern::OpConversionPattern; 22 23 LogicalResult 24 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 25 ConversionPatternRewriter &rewriter) const override { 26 auto tensorType = op.getType().cast<RankedTensorType>(); 27 rewriter.replaceOpWithNewOp<arith::IndexCastOp>( 28 op, adaptor.getIn(), 29 MemRefType::get(tensorType.getShape(), tensorType.getElementType())); 30 return success(); 31 } 32 }; 33 34 /// Pass to bufferize Arithmetic ops. 35 struct ArithmeticBufferizePass 36 : public ArithmeticBufferizeBase<ArithmeticBufferizePass> { 37 void runOnFunction() override { 38 BufferizeTypeConverter typeConverter; 39 RewritePatternSet patterns(&getContext()); 40 ConversionTarget target(getContext()); 41 42 target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>(); 43 44 arith::populateArithmeticBufferizePatterns(typeConverter, patterns); 45 46 target.addDynamicallyLegalOp<arith::IndexCastOp>( 47 [&](arith::IndexCastOp op) { 48 return typeConverter.isLegal(op.getType()); 49 }); 50 51 if (failed( 52 applyPartialConversion(getFunction(), target, std::move(patterns)))) 53 signalPassFailure(); 54 } 55 }; 56 57 } // end anonymous namespace 58 59 void mlir::arith::populateArithmeticBufferizePatterns( 60 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 61 patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext()); 62 } 63 64 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() { 65 return std::make_unique<ArithmeticBufferizePass>(); 66 } 67