1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
13 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
14 #include "mlir/IR/Operation.h"
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 
19 //===----------------------------------------------------------------------===//
20 // BufferizeTypeConverter
21 //===----------------------------------------------------------------------===//
22 
23 static Value materializeToTensor(OpBuilder &builder, TensorType type,
24                                  ValueRange inputs, Location loc) {
25   assert(inputs.size() == 1);
26   assert(inputs[0].getType().isa<BaseMemRefType>());
27   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
28 }
29 
30 /// Registers conversions into BufferizeTypeConverter
31 BufferizeTypeConverter::BufferizeTypeConverter() {
32   // Keep all types unchanged.
33   addConversion([](Type type) { return type; });
34   // Convert RankedTensorType to MemRefType.
35   addConversion([](RankedTensorType type) -> Type {
36     return MemRefType::get(type.getShape(), type.getElementType());
37   });
38   // Convert UnrankedTensorType to UnrankedMemRefType.
39   addConversion([](UnrankedTensorType type) -> Type {
40     return UnrankedMemRefType::get(type.getElementType(), 0);
41   });
42   addArgumentMaterialization(materializeToTensor);
43   addSourceMaterialization(materializeToTensor);
44   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
45                               ValueRange inputs, Location loc) -> Value {
46     assert(inputs.size() == 1);
47     assert(inputs[0].getType().isa<TensorType>());
48     return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
49   });
50 }
51 
52 void mlir::bufferization::populateBufferizeMaterializationLegality(
53     ConversionTarget &target) {
54   target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
55 }
56 
57 namespace {
58 // In a finalizing bufferize conversion, we know that all tensors have been
59 // converted to memrefs, thus, this op becomes an identity.
60 class BufferizeToTensorOp
61     : public OpConversionPattern<bufferization::ToTensorOp> {
62 public:
63   using OpConversionPattern::OpConversionPattern;
64   LogicalResult
65   matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
66                   ConversionPatternRewriter &rewriter) const override {
67     rewriter.replaceOp(op, adaptor.memref());
68     return success();
69   }
70 };
71 } // namespace
72 
73 namespace {
74 // In a finalizing bufferize conversion, we know that all tensors have been
75 // converted to memrefs, thus, this op becomes an identity.
76 class BufferizeToMemrefOp
77     : public OpConversionPattern<bufferization::ToMemrefOp> {
78 public:
79   using OpConversionPattern::OpConversionPattern;
80   LogicalResult
81   matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
82                   ConversionPatternRewriter &rewriter) const override {
83     rewriter.replaceOp(op, adaptor.tensor());
84     return success();
85   }
86 };
87 } // namespace
88 
89 void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
90     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
91   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
92                                                          patterns.getContext());
93 }
94 
95 namespace {
96 struct FinalizingBufferizePass
97     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
98   using FinalizingBufferizeBase<
99       FinalizingBufferizePass>::FinalizingBufferizeBase;
100 
101   void runOnFunction() override {
102     auto func = getFunction();
103     auto *context = &getContext();
104 
105     BufferizeTypeConverter typeConverter;
106     RewritePatternSet patterns(context);
107     ConversionTarget target(*context);
108 
109     populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
110 
111     // If all result types are legal, and all block arguments are legal (ensured
112     // by func conversion above), then all types in the program are legal.
113     //
114     // We also check that the operand types are legal to avoid creating invalid
115     // IR. For example, this prevents
116     // populateEliminateBufferizeMaterializationsPatterns from updating the
117     // types of the operands to a return op without updating the enclosing
118     // function.
119     target.markUnknownOpDynamicallyLegal(
120         [&](Operation *op) { return typeConverter.isLegal(op); });
121 
122     if (failed(applyFullConversion(func, target, std::move(patterns))))
123       signalPassFailure();
124   }
125 };
126 } // namespace
127 
128 std::unique_ptr<FunctionPass>
129 mlir::bufferization::createFinalizingBufferizePass() {
130   return std::make_unique<FinalizingBufferizePass>();
131 }
132