1 //===- Bufferize.cpp - Bufferization for Arithmetic ops ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 /// Bufferize arith.index_cast.
21 struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
22   using OpConversionPattern::OpConversionPattern;
23 
24   LogicalResult
25   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
26                   ConversionPatternRewriter &rewriter) const override {
27     auto tensorType = op.getType().cast<RankedTensorType>();
28     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
29         op, adaptor.getIn(),
30         MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
31     return success();
32   }
33 };
34 
35 /// Pass to bufferize Arithmetic ops.
36 struct ArithmeticBufferizePass
37     : public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
38   void runOnFunction() override {
39     bufferization::BufferizeTypeConverter typeConverter;
40     RewritePatternSet patterns(&getContext());
41     ConversionTarget target(getContext());
42 
43     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
44 
45     arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
46 
47     target.addDynamicallyLegalOp<arith::IndexCastOp>(
48         [&](arith::IndexCastOp op) {
49           return typeConverter.isLegal(op.getType());
50         });
51 
52     if (failed(
53             applyPartialConversion(getFunction(), target, std::move(patterns))))
54       signalPassFailure();
55   }
56 };
57 
58 } // end anonymous namespace
59 
60 void mlir::arith::populateArithmeticBufferizePatterns(
61     bufferization::BufferizeTypeConverter &typeConverter,
62     RewritePatternSet &patterns) {
63   patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
64 }
65 
66 std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
67   return std::make_unique<ArithmeticBufferizePass>();
68 }
69