1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 namespace {
21 
22 /// Conversion patterns.
23 template <typename SrcOpTy, typename DstOpTy>
24 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
25 public:
26   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override {
31     typename SrcOpTy::Adaptor adaptor(operands);
32     rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
33                                          adaptor.rhs());
34     return success();
35   }
36 };
37 
38 class FromExtentTensorOpConversion
39     : public OpConversionPattern<FromExtentTensorOp> {
40 public:
41   using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
42 
43   LogicalResult
44   matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
45                   ConversionPatternRewriter &rewriter) const override {
46     FromExtentTensorOp::Adaptor transformed(operands);
47     rewriter.replaceOp(op.getOperation(), transformed.input());
48     return success();
49   }
50 };
51 
52 class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
53 public:
54   using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
58                   ConversionPatternRewriter &rewriter) const override {
59     IndexToSizeOp::Adaptor transformed(operands);
60     rewriter.replaceOp(op.getOperation(), transformed.arg());
61     return success();
62   }
63 };
64 
65 class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
66 public:
67   using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
68 
69   LogicalResult
70   matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
71                   ConversionPatternRewriter &rewriter) const override {
72     SizeToIndexOp::Adaptor transformed(operands);
73     rewriter.replaceOp(op.getOperation(), transformed.arg());
74     return success();
75   }
76 };
77 
78 class ToExtentTensorOpConversion
79     : public OpConversionPattern<ToExtentTensorOp> {
80 public:
81   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
82 
83   LogicalResult
84   matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
85                   ConversionPatternRewriter &rewriter) const override {
86     ToExtentTensorOp::Adaptor transformed(operands);
87     rewriter.replaceOp(op.getOperation(), transformed.input());
88     return success();
89   }
90 };
91 
92 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
93 public:
94   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
95 
96   LogicalResult
97   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
98                   ConversionPatternRewriter &rewriter) const override {
99     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
100                                                  op.value().getSExtValue());
101     return success();
102   }
103 };
104 
105 /// Type conversions.
106 class ShapeTypeConverter : public TypeConverter {
107 public:
108   using TypeConverter::convertType;
109 
110   ShapeTypeConverter(MLIRContext *ctx) {
111     // Add default pass-through conversion.
112     addConversion([&](Type type) { return type; });
113 
114     addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
115     addConversion([ctx](ShapeType type) {
116       return RankedTensorType::get({ShapedType::kDynamicSize},
117                                    IndexType::get(ctx));
118     });
119   }
120 };
121 
122 /// Conversion pass.
123 class ConvertShapeToStandardPass
124     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
125   void runOnOperation() override {
126     // Setup type conversion.
127     MLIRContext &ctx = getContext();
128     ShapeTypeConverter typeConverter(&ctx);
129 
130     // Setup target legality.
131     ConversionTarget target(ctx);
132     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
133     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
134     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
135       return typeConverter.isSignatureLegal(op.getType());
136     });
137 
138     // Setup conversion patterns.
139     OwningRewritePatternList patterns;
140     populateShapeToStandardConversionPatterns(patterns, &ctx);
141     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
142 
143     // Apply conversion.
144     auto module = getOperation();
145     if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
146       signalPassFailure();
147   }
148 };
149 
150 } // namespace
151 
152 void mlir::populateShapeToStandardConversionPatterns(
153     OwningRewritePatternList &patterns, MLIRContext *ctx) {
154   // clang-format off
155   patterns.insert<
156       BinaryOpConversion<AddOp, AddIOp>,
157       BinaryOpConversion<MulOp, MulIOp>,
158       ConstSizeOpConverter,
159       FromExtentTensorOpConversion,
160       IndexToSizeOpConversion,
161       SizeToIndexOpConversion,
162       ToExtentTensorOpConversion>(ctx);
163   // clang-format on
164 }
165 
166 std::unique_ptr<OperationPass<ModuleOp>>
167 mlir::createConvertShapeToStandardPass() {
168   return std::make_unique<ConvertShapeToStandardPass>();
169 }
170