//===- Bufferize.cpp - Bufferization of shape ops ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::shape;

namespace {
// Propagate tensor to memref conversions through shape.assuming ops.
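//
// For example (an illustrative sketch, not exact assembly syntax), with a
// tensor-to-memref type converter this rewrites
//
//   %0 = shape.assuming %w -> (tensor<2xf32>) { ... }
//
// into
//
//   %0 = shape.assuming %w -> (memref<2xf32>) { ... }
//
// keeping the original body; the ops inside the region are converted by their
// own patterns.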
class TypeConversionAssumingOpConverter
    : public BufferizeOpConversionPattern<shape::AssumingOp> {
public:
  using BufferizeOpConversionPattern<
      shape::AssumingOp>::BufferizeOpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
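    // Compute the converted result types for the replacement op.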
    SmallVector<Type, 2> newResultTypes;
    newResultTypes.reserve(assumingOp.getNumResults());
    for (auto result : assumingOp.getResults()) {
      auto originalType = result.getType();
      Type convertedType = converter.convertType(originalType);
      newResultTypes.push_back(convertedType);
    }

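    // Build a new shape.assuming op with the converted result types, reusing
    // the original witness.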
    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
        assumingOp.getLoc(), newResultTypes, assumingOp.witness());

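    // Replace the old op with the new op's results and move the original body
    // region (including its terminator) into the new op.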
    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
    rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
                                newAssumingOp.doRegion().end());

    return success();
  }
};

struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
  void runOnFunction() override {
    MLIRContext &ctx = getContext();

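    // Collect the patterns that propagate the tensor-to-memref conversion
    // through shape ops.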
    OwningRewritePatternList patterns;
    BufferizeTypeConverter converter;
    populateShapeTypeConversionPatterns(&ctx, converter, patterns);

    ConversionTarget target(getContext());
    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };

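    // shape.assuming is legal only once all of its results use memref types.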
    target.addDynamicallyLegalOp<AssumingOp>([&](shape::AssumingOp op) {
      return llvm::all_of(op.getResultTypes(), isMemRefType);
    });

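    // Partial conversion: ops that are not explicitly illegal are left as-is.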
    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
      signalPassFailure();
  }
};

} // namespace

/// Populates `patterns` with patterns that convert shape ops from operating on
/// tensors to operating on memrefs.
//
// TODO: Change this to work generally with any type conversions.
void mlir::populateShapeTypeConversionPatterns(
    MLIRContext *context, BufferizeTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
}

//===----------------------------------------------------------------------===//
// ShapeBufferizePass construction
//===----------------------------------------------------------------------===//

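// Example usage (an illustrative sketch; assumes the caller owns a PassManager
// over a ModuleOp containing FuncOps):
//
//   PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(createShapeBufferizePass());
//   if (failed(pm.run(module)))
//     ...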
std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
  return std::make_unique<ShapeBufferizePass>();
}