1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 
19 /// Bufferize arith.index_cast.
20 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
21   using OpConversionPattern::OpConversionPattern;
22 
23   LogicalResult
24   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
25                   ConversionPatternRewriter &rewriter) const override {
26     auto tensorType = op.getType().cast<RankedTensorType>();
27     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
28         op, adaptor.getIn(),
29         MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
30     return success();
31   }
32 };
33 
34 /// Pass to bufferize Arithmetic ops.
35 struct ArithmeticBufferizePass
36     : public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
37   void runOnFunction() override {
38     BufferizeTypeConverter typeConverter;
39     RewritePatternSet patterns(&getContext());
40     ConversionTarget target(getContext());
41 
42     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
43 
44     arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
45 
46     target.addDynamicallyLegalOp<arith::IndexCastOp>(
47         [&](arith::IndexCastOp op) {
48           return typeConverter.isLegal(op.getType());
49         });
50 
51     if (failed(
52             applyPartialConversion(getFunction(), target, std::move(patterns))))
53       signalPassFailure();
54   }
55 };
56 
57 } // end anonymous namespace
58 
59 void mlir::arith::populateArithmeticBufferizePatterns(
60     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
61   patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
62 }
63 
64 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
65   return std::make_unique<ArithmeticBufferizePass>();
66 }
67