1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 14 using namespace mlir; 15 16 namespace { 17 18 /// Bufferize arith.index_cast. 19 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> { 20 using OpConversionPattern::OpConversionPattern; 21 22 LogicalResult 23 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 24 ConversionPatternRewriter &rewriter) const override { 25 auto tensorType = op.getType().cast<RankedTensorType>(); 26 rewriter.replaceOpWithNewOp<arith::IndexCastOp>( 27 op, adaptor.getIn(), 28 MemRefType::get(tensorType.getShape(), tensorType.getElementType())); 29 return success(); 30 } 31 }; 32 33 /// Pass to bufferize Arithmetic ops. 34 struct ArithmeticBufferizePass 35 : public ArithmeticBufferizeBase<ArithmeticBufferizePass> { 36 void runOnFunction() override { 37 BufferizeTypeConverter typeConverter; 38 RewritePatternSet patterns(&getContext()); 39 ConversionTarget target(getContext()); 40 41 target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>(); 42 43 arith::populateArithmeticBufferizePatterns(typeConverter, patterns); 44 45 target.addDynamicallyLegalOp<arith::IndexCastOp>( 46 [&](arith::IndexCastOp op) { 47 return typeConverter.isLegal(op.getType()); 48 }); 49 50 if (failed( 51 applyPartialConversion(getFunction(), target, std::move(patterns)))) 52 signalPassFailure(); 53 } 54 }; 55 56 } // end anonymous namespace 57 58 void mlir::arith::populateArithmeticBufferizePatterns( 59 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 60 patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext()); 61 } 62 63 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() { 64 return std::make_unique<ArithmeticBufferizePass>(); 65 } 66