1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 namespace {
21 
22 /// Conversion patterns.
23 template <typename SrcOpTy, typename DstOpTy>
24 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
25 public:
26   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override {
31     typename SrcOpTy::OperandAdaptor adaptor(operands);
32     rewriter.replaceOpWithNewOp<DstOpTy>(op.getOperation(), adaptor.lhs(),
33                                          adaptor.rhs());
34     return success();
35   }
36 };
37 
38 class FromExtentTensorOpConversion
39     : public OpConversionPattern<shape::FromExtentTensorOp> {
40 public:
41   using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
42 
43   LogicalResult
44   matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
45                   ConversionPatternRewriter &rewriter) const override {
46     shape::FromExtentTensorOpOperandAdaptor transformed(operands);
47     rewriter.replaceOp(op.getOperation(), transformed.input());
48     return success();
49   }
50 };
51 
52 class IndexToSizeOpConversion
53     : public OpConversionPattern<shape::IndexToSizeOp> {
54 public:
55   using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
56 
57   LogicalResult
58   matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
59                   ConversionPatternRewriter &rewriter) const override {
60     shape::IndexToSizeOpOperandAdaptor transformed(operands);
61     rewriter.replaceOp(op.getOperation(), transformed.arg());
62     return success();
63   }
64 };
65 
66 class SizeToIndexOpConversion
67     : public OpConversionPattern<shape::SizeToIndexOp> {
68 public:
69   using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
70 
71   LogicalResult
72   matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
73                   ConversionPatternRewriter &rewriter) const override {
74     shape::SizeToIndexOpOperandAdaptor transformed(operands);
75     rewriter.replaceOp(op.getOperation(), transformed.arg());
76     return success();
77   }
78 };
79 
80 class ToExtentTensorOpConversion
81     : public OpConversionPattern<shape::ToExtentTensorOp> {
82 public:
83   using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
84 
85   LogicalResult
86   matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
87                   ConversionPatternRewriter &rewriter) const override {
88     shape::ToExtentTensorOpOperandAdaptor transformed(operands);
89     rewriter.replaceOp(op.getOperation(), transformed.input());
90     return success();
91   }
92 };
93 
94 /// Type conversions.
95 class ShapeTypeConverter : public TypeConverter {
96 public:
97   using TypeConverter::convertType;
98 
99   ShapeTypeConverter(MLIRContext *ctx) {
100     // Add default pass-through conversion.
101     addConversion([&](Type type) { return type; });
102 
103     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
104     addConversion([ctx](shape::ShapeType type) {
105       return RankedTensorType::get({ShapedType::kDynamicSize},
106                                    IndexType::get(ctx));
107     });
108   }
109 };
110 
111 /// Conversion pass.
112 class ConvertShapeToStandardPass
113     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
114 
115   void runOnOperation() override {
116 
117     // Setup type conversion.
118     MLIRContext &ctx = getContext();
119     ShapeTypeConverter typeConverter(&ctx);
120 
121     // Setup target legality.
122     ConversionTarget target(ctx);
123     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
124     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
125     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
126       return typeConverter.isSignatureLegal(op.getType());
127     });
128 
129     // Setup conversion patterns.
130     OwningRewritePatternList patterns;
131     populateShapeToStandardConversionPatterns(patterns, &ctx);
132     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
133 
134     // Apply conversion.
135     auto module = getOperation();
136     if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
137       signalPassFailure();
138   }
139 };
140 
141 } // namespace
142 
143 void mlir::populateShapeToStandardConversionPatterns(
144     OwningRewritePatternList &patterns, MLIRContext *ctx) {
145   // clang-format off
146   patterns.insert<
147       BinaryOpConversion<AddOp, AddIOp>,
148       BinaryOpConversion<MulOp, MulIOp>,
149       FromExtentTensorOpConversion,
150       IndexToSizeOpConversion,
151       SizeToIndexOpConversion,
152       ToExtentTensorOpConversion>(ctx);
153   // clang-format on
154 }
155 
156 std::unique_ptr<OperationPass<ModuleOp>>
157 mlir::createConvertShapeToStandardPass() {
158   return std::make_unique<ConvertShapeToStandardPass>();
159 }
160