1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 namespace {
21 
22 /// Generated conversion patterns.
23 #include "ShapeToStandardPatterns.inc"
24 
25 /// Conversion patterns.
26 template <typename SrcOpTy, typename DstOpTy>
27 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
28 public:
29   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
30 
31   LogicalResult
32   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
33                   ConversionPatternRewriter &rewriter) const override {
34     typename SrcOpTy::Adaptor adaptor(operands);
35     rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
36                                          adaptor.rhs());
37     return success();
38   }
39 };
40 
41 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
42 public:
43   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
44 
45   LogicalResult
46   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
47                   ConversionPatternRewriter &rewriter) const override {
48     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
49                                                  op.value().getSExtValue());
50     return success();
51   }
52 };
53 
54 /// Type conversions.
55 class ShapeTypeConverter : public TypeConverter {
56 public:
57   using TypeConverter::convertType;
58 
59   ShapeTypeConverter(MLIRContext *ctx) {
60     // Add default pass-through conversion.
61     addConversion([&](Type type) { return type; });
62 
63     addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
64     addConversion([ctx](ShapeType type) {
65       return RankedTensorType::get({ShapedType::kDynamicSize},
66                                    IndexType::get(ctx));
67     });
68   }
69 };
70 
71 /// Conversion pass.
72 class ConvertShapeToStandardPass
73     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
74 
75   void runOnOperation() override {
76     // Setup type conversion.
77     MLIRContext &ctx = getContext();
78     ShapeTypeConverter typeConverter(&ctx);
79 
80     // Setup target legality.
81     ConversionTarget target(ctx);
82     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
83     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
84     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
85       return typeConverter.isSignatureLegal(op.getType());
86     });
87 
88     // Setup conversion patterns.
89     OwningRewritePatternList patterns;
90     populateShapeToStandardConversionPatterns(patterns, &ctx);
91     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
92 
93     // Apply conversion.
94     auto module = getOperation();
95     if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
96       signalPassFailure();
97   }
98 };
99 
100 } // namespace
101 
102 void mlir::populateShapeToStandardConversionPatterns(
103     OwningRewritePatternList &patterns, MLIRContext *ctx) {
104   populateWithGenerated(ctx, &patterns);
105   // clang-format off
106   patterns.insert<
107       BinaryOpConversion<AddOp, AddIOp>,
108       BinaryOpConversion<MulOp, MulIOp>,
109       ConstSizeOpConverter>(ctx);
110   // clang-format on
111 }
112 
113 std::unique_ptr<OperationPass<ModuleOp>>
114 mlir::createConvertShapeToStandardPass() {
115   return std::make_unique<ConvertShapeToStandardPass>();
116 }
117