1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 namespace {
21 
22 /// Generated conversion patterns.
23 #include "ShapeToStandardPatterns.inc"
24 
25 /// Conversion patterns.
26 template <typename SrcOpTy, typename DstOpTy>
27 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
28 public:
29   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
30 
31   LogicalResult
32   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
33                   ConversionPatternRewriter &rewriter) const override {
34     typename SrcOpTy::Adaptor adaptor(operands);
35     rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
36                                          adaptor.rhs());
37     return success();
38   }
39 };
40 
41 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
42 public:
43   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
44 
45   LogicalResult
46   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
47                   ConversionPatternRewriter &rewriter) const override {
48     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
49                                                  op.value().getSExtValue());
50     return success();
51   }
52 };
53 
54 /// Type conversions.
55 class ShapeTypeConverter : public TypeConverter {
56 public:
57   using TypeConverter::convertType;
58 
59   ShapeTypeConverter(MLIRContext *ctx) {
60     // Add default pass-through conversion.
61     addConversion([&](Type type) { return type; });
62 
63     addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
64     addConversion([ctx](ShapeType type) {
65       return RankedTensorType::get({ShapedType::kDynamicSize},
66                                    IndexType::get(ctx));
67     });
68   }
69 };
70 
71 /// Conversion pass.
72 class ConvertShapeToStandardPass
73     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
74 
75   void runOnOperation() override {
76     // Setup type conversion.
77     MLIRContext &ctx = getContext();
78     ShapeTypeConverter typeConverter(&ctx);
79 
80     // Setup target legality.
81     ConversionTarget target(ctx);
82     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
83     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
84     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
85       return typeConverter.isSignatureLegal(op.getType()) &&
86              typeConverter.isLegal(&op.getBody());
87     });
88 
89     // Setup conversion patterns.
90     OwningRewritePatternList patterns;
91     populateShapeToStandardConversionPatterns(patterns, &ctx);
92     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
93 
94     // Apply conversion.
95     auto module = getOperation();
96     if (failed(applyFullConversion(module, target, patterns)))
97       signalPassFailure();
98   }
99 };
100 
101 } // namespace
102 
103 void mlir::populateShapeToStandardConversionPatterns(
104     OwningRewritePatternList &patterns, MLIRContext *ctx) {
105   populateWithGenerated(ctx, &patterns);
106   // clang-format off
107   patterns.insert<
108       BinaryOpConversion<AddOp, AddIOp>,
109       BinaryOpConversion<MulOp, MulIOp>,
110       ConstSizeOpConverter>(ctx);
111   // clang-format on
112 }
113 
114 std::unique_ptr<OperationPass<ModuleOp>>
115 mlir::createConvertShapeToStandardPass() {
116   return std::make_unique<ConvertShapeToStandardPass>();
117 }
118