1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 
14 using namespace mlir;
15 
16 namespace {
17 
18 /// Bufferize arith.index_cast.
19 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
20   using OpConversionPattern::OpConversionPattern;
21 
22   LogicalResult
23   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
24                   ConversionPatternRewriter &rewriter) const override {
25     auto tensorType = op.getType().cast<RankedTensorType>();
26     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
27         op, adaptor.getIn(),
28         MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
29     return success();
30   }
31 };
32 
33 /// Pass to bufferize Arithmetic ops.
34 struct ArithmeticBufferizePass
35     : public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
36   void runOnFunction() override {
37     BufferizeTypeConverter typeConverter;
38     RewritePatternSet patterns(&getContext());
39     ConversionTarget target(getContext());
40 
41     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
42 
43     arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
44 
45     target.addDynamicallyLegalOp<arith::IndexCastOp>(
46         [&](arith::IndexCastOp op) {
47           return typeConverter.isLegal(op.getType());
48         });
49 
50     if (failed(
51             applyPartialConversion(getFunction(), target, std::move(patterns))))
52       signalPassFailure();
53   }
54 };
55 
56 } // end anonymous namespace
57 
58 void mlir::arith::populateArithmeticBufferizePatterns(
59     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
60   patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
61 }
62 
63 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
64   return std::make_unique<ArithmeticBufferizePass>();
65 }
66