//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 namespace mlir {
18 namespace {
19 
20 /// Conversion patterns.
21 class SizeToIndexOpConversion
22     : public OpConversionPattern<shape::SizeToIndexOp> {
23 public:
24   using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
25 
26   LogicalResult
27   matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
28                   ConversionPatternRewriter &rewriter) const override {
29     shape::SizeToIndexOpOperandAdaptor transformed(operands);
30     rewriter.replaceOp(op.getOperation(), transformed.arg());
31     return success();
32   }
33 };
34 
35 class IndexToSizeOpConversion
36     : public OpConversionPattern<shape::IndexToSizeOp> {
37 public:
38   using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
39 
40   LogicalResult
41   matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
42                   ConversionPatternRewriter &rewriter) const override {
43     shape::IndexToSizeOpOperandAdaptor transformed(operands);
44     rewriter.replaceOp(op.getOperation(), transformed.arg());
45     return success();
46   }
47 };
48 
49 /// Type conversions.
50 class ShapeTypeConverter : public TypeConverter {
51 public:
52   using TypeConverter::convertType;
53 
54   ShapeTypeConverter(MLIRContext *ctx) {
55     // Add default pass-through conversion.
56     addConversion([&](Type type) { return type; });
57     addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
58   }
59 };
60 
61 /// Conversion pass.
62 class ConvertShapeToStandardPass
63     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
64 
65   void runOnOperation() override {
66 
67     // Setup type conversion.
68     MLIRContext &ctx = getContext();
69     ShapeTypeConverter typeConverter(&ctx);
70 
71     // Setup target legality.
72     ConversionTarget target(ctx);
73     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
74     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
75     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
76       return typeConverter.isSignatureLegal(op.getType());
77     });
78 
79     // Setup conversion patterns.
80     OwningRewritePatternList patterns;
81     populateShapeToStandardConversionPatterns(patterns, &ctx);
82     populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
83 
84     // Apply conversion.
85     auto module = getOperation();
86     if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
87       signalPassFailure();
88   }
89 };
90 
91 } // namespace
92 
93 void populateShapeToStandardConversionPatterns(
94     OwningRewritePatternList &patterns, MLIRContext *ctx) {
95   // clang-format off
96   patterns.insert<
97       IndexToSizeOpConversion,
98       SizeToIndexOpConversion>(ctx);
99   // clang-format on
100 }
101 
102 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass() {
103   return std::make_unique<ConvertShapeToStandardPass>();
104 }
105 
106 } // namespace mlir
107