//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 /// Conversion patterns.
22 class FromExtentTensorOpConversion
23     : public OpConversionPattern<shape::FromExtentTensorOp> {
24 public:
25   using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
26 
27   LogicalResult
28   matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
29                   ConversionPatternRewriter &rewriter) const override {
30     shape::FromExtentTensorOpOperandAdaptor transformed(operands);
31     rewriter.replaceOp(op.getOperation(), transformed.input());
32     return success();
33   }
34 };
35 
36 class IndexToSizeOpConversion
37     : public OpConversionPattern<shape::IndexToSizeOp> {
38 public:
39   using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
40 
41   LogicalResult
42   matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
43                   ConversionPatternRewriter &rewriter) const override {
44     shape::IndexToSizeOpOperandAdaptor transformed(operands);
45     rewriter.replaceOp(op.getOperation(), transformed.arg());
46     return success();
47   }
48 };
49 
50 class SizeToIndexOpConversion
51     : public OpConversionPattern<shape::SizeToIndexOp> {
52 public:
53   using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
54 
55   LogicalResult
56   matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
57                   ConversionPatternRewriter &rewriter) const override {
58     shape::SizeToIndexOpOperandAdaptor transformed(operands);
59     rewriter.replaceOp(op.getOperation(), transformed.arg());
60     return success();
61   }
62 };
63 
64 class ToExtentTensorOpConversion
65     : public OpConversionPattern<shape::ToExtentTensorOp> {
66 public:
67   using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
68 
69   LogicalResult
70   matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
71                   ConversionPatternRewriter &rewriter) const override {
72     shape::ToExtentTensorOpOperandAdaptor transformed(operands);
73     rewriter.replaceOp(op.getOperation(), transformed.input());
74     return success();
75   }
76 };
77 
78 /// Type conversions.
79 class ShapeTypeConverter : public TypeConverter {
80 public:
81   using TypeConverter::convertType;
82 
83   ShapeTypeConverter(MLIRContext *ctx) {
84     // Add default pass-through conversion.
85     addConversion([&](Type type) { return type; });
86 
87     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
88     addConversion([ctx](shape::ShapeType type) {
89       return RankedTensorType::get({ShapedType::kDynamicSize},
90                                    IndexType::get(ctx));
91     });
92   }
93 };
94 
95 /// Conversion pass.
96 class ConvertShapeToStandardPass
97     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
98 
99   void runOnOperation() override {
100 
101     // Setup type conversion.
102     MLIRContext &ctx = getContext();
103     ShapeTypeConverter typeConverter(&ctx);
104 
105     // Setup target legality.
106     ConversionTarget target(ctx);
107     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
108     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
109     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
110       return typeConverter.isSignatureLegal(op.getType());
111     });
112 
113     // Setup conversion patterns.
114     OwningRewritePatternList patterns;
115     populateShapeToStandardConversionPatterns(patterns, &ctx);
116     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
117 
118     // Apply conversion.
119     auto module = getOperation();
120     if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
121       signalPassFailure();
122   }
123 };
124 
125 } // namespace
126 
127 void mlir::populateShapeToStandardConversionPatterns(
128     OwningRewritePatternList &patterns, MLIRContext *ctx) {
129   // clang-format off
130   patterns.insert<
131       FromExtentTensorOpConversion,
132       IndexToSizeOpConversion,
133       SizeToIndexOpConversion,
134       ToExtentTensorOpConversion>(ctx);
135   // clang-format on
136 }
137 
138 std::unique_ptr<OperationPass<ModuleOp>>
139 mlir::createConvertShapeToStandardPass() {
140   return std::make_unique<ConvertShapeToStandardPass>();
141 }
142