1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 /// Conversion patterns.
namespace {
/// Lowers `shape.any` by forwarding one of its operands (see the out-of-line
/// `matchAndRewrite` definition below).
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
31 
32 LogicalResult
33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
34                                  ConversionPatternRewriter &rewriter) const {
35   AnyOp::Adaptor transformed(operands);
36 
37   // Replace `any` with its first operand.
38   // Any operand would be a valid substitution.
39   rewriter.replaceOp(op, {transformed.inputs().front()});
40   return success();
41 }
42 
43 namespace {
44 template <typename SrcOpTy, typename DstOpTy>
45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
46 public:
47   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
48 
49   LogicalResult
50   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
51                   ConversionPatternRewriter &rewriter) const override {
52     typename SrcOpTy::Adaptor adaptor(operands);
53     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
54     return success();
55   }
56 };
57 } // namespace
58 
namespace {
/// Lowers `shape.shape_of` on ranked tensors to an extent tensor built from
/// `std.dim` / constant ops (see the out-of-line definition below). Unranked
/// tensors are deliberately not handled here.
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
69 
70 LogicalResult ShapeOfOpConversion::matchAndRewrite(
71     ShapeOfOp op, ArrayRef<Value> operands,
72     ConversionPatternRewriter &rewriter) const {
73   ShapeOfOp::Adaptor transformed(operands);
74   auto loc = op.getLoc();
75   auto tensorVal = transformed.arg();
76   auto tensorTy = tensorVal.getType();
77 
78   // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
79   // found in the corresponding pass.
80   if (tensorTy.isa<UnrankedTensorType>())
81     return failure();
82 
83   // Build values for individual dimensions.
84   SmallVector<Value, 8> dimValues;
85   auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
86   int64_t rank = rankedTensorTy.getRank();
87   for (int64_t i = 0; i < rank; i++) {
88     if (rankedTensorTy.isDynamicDim(i)) {
89       auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
90       dimValues.push_back(dimVal);
91     } else {
92       int64_t dim = rankedTensorTy.getDimSize(i);
93       auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
94       dimValues.push_back(dimVal);
95     }
96   }
97 
98   // Materialize extent tensor.
99   Value staticExtentTensor =
100       rewriter.create<TensorFromElementsOp>(loc, dimValues);
101   rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
102                                             op.getType());
103   return success();
104 }
105 
106 namespace {
107 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
108   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
109 
110   LogicalResult
111   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
112                   ConversionPatternRewriter &rewriter) const override;
113 };
114 } // namespace
115 
116 LogicalResult GetExtentOpConverter::matchAndRewrite(
117     GetExtentOp op, ArrayRef<Value> operands,
118     ConversionPatternRewriter &rewriter) const {
119   GetExtentOp::Adaptor transformed(operands);
120 
121   // Derive shape extent directly from shape origin if possible.
122   // This circumvents the necessity to materialize the shape in memory.
123   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
124     rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
125     return success();
126   }
127 
128   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
129                                                 transformed.shape(),
130                                                 ValueRange{transformed.dim()});
131   return success();
132 }
133 
namespace {
/// Lowers `shape.rank` to `std.dim` on the extent tensor (see the
/// out-of-line definition below).
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
144 
145 LogicalResult
146 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
147                                  ConversionPatternRewriter &rewriter) const {
148   shape::RankOp::Adaptor transformed(operands);
149   rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
150   return success();
151 }
152 
153 namespace {
154 /// Type conversions.
155 class ShapeTypeConverter : public TypeConverter {
156 public:
157   using TypeConverter::convertType;
158 
159   ShapeTypeConverter(MLIRContext *ctx) {
160     // Add default pass-through conversion.
161     addConversion([&](Type type) { return type; });
162 
163     addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
164     addConversion([ctx](ShapeType type) {
165       return RankedTensorType::get({ShapedType::kDynamicSize},
166                                    IndexType::get(ctx));
167     });
168   }
169 };
170 } // namespace
171 
namespace {
/// Conversion pass.
/// Module pass that converts shape dialect ops and types to the `scf` and
/// `std` dialects, using the patterns and type converter defined above.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace
180 
181 void ConvertShapeToStandardPass::runOnOperation() {
182   // Setup type conversion.
183   MLIRContext &ctx = getContext();
184   ShapeTypeConverter typeConverter(&ctx);
185 
186   // Setup target legality.
187   ConversionTarget target(ctx);
188   target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
189   target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
190   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
191     return typeConverter.isSignatureLegal(op.getType()) &&
192            typeConverter.isLegal(&op.getBody());
193   });
194 
195   // Setup conversion patterns.
196   OwningRewritePatternList patterns;
197   populateShapeToStandardConversionPatterns(patterns, &ctx);
198   populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
199 
200   // Apply conversion.
201   auto module = getOperation();
202   if (failed(applyFullConversion(module, target, patterns)))
203     signalPassFailure();
204 }
205 
206 void mlir::populateShapeToStandardConversionPatterns(
207     OwningRewritePatternList &patterns, MLIRContext *ctx) {
208   // clang-format off
209   patterns.insert<
210       AnyOpConversion,
211       BinaryOpConversion<AddOp, AddIOp>,
212       BinaryOpConversion<MulOp, MulIOp>,
213       GetExtentOpConverter,
214       RankOpConverter,
215       ShapeOfOpConversion>(ctx);
216   // clang-format on
217 }
218 
/// Factory for the shape-to-standard conversion pass defined above.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}
223