//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
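  // E.g. (illustrative only; `shape.any` is free to return any of its
  // operands, all of which are assumed to represent the same shape):
  //   %result = shape.any(%a, %b)
  // folds away to just %a.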
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
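/// Lower a binary `shape` op to its `std` arithmetic counterpart on indices.
/// For the `AddOp` -> `AddIOp` instantiation, a `shape.add` on error-free
/// `index` operands, e.g. (illustrative, syntax approximate):
///   %sum = shape.add %a, %b
/// becomes:
///   %sum = addi %a, %b : index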
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
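/// Lower `shape.const_size` to a `std` constant of index type, e.g.
/// (illustrative):
///   %c = shape.const_size 3
/// becomes:
///   %c = constant 3 : index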
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
    return success();
  }
};
} // namespace

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

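// Lower `shape.shape_of` on a ranked tensor by assembling an extent tensor
// from `dim` and `constant` ops. Illustrative example (syntax approximate):
//   %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
// becomes roughly:
//   %c2 = constant 2 : index
//   %c1 = constant 1 : index
//   %d1 = dim %arg, %c1 : tensor<2x?xf32>
//   %e = tensor_from_elements(%c2, %d1) : tensor<2xindex>
//   %shape = tensor_cast %e : tensor<2xindex> to tensor<?xindex>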
LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For unranked tensors, `shape_of` lowers to `scf`; the corresponding
  // pattern can be found in that conversion pass.
  ShapeOfOp::Adaptor transformed(operands);
  Value tensorVal = transformed.arg();
  Type tensorTy = tensorVal.getType();
  if (tensorTy.isa<UnrankedTensorType>())
    return failure();

  // Build values for individual dimensions.
  SmallVector<Value, 8> dimValues;
  RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
  int64_t rank = rankedTensorTy.getRank();
  auto loc = op.getLoc();
  for (int64_t i = 0; i < rank; i++) {
    if (rankedTensorTy.isDynamicDim(i)) {
      // Dynamic extents are queried from the tensor at runtime.
      Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
      dimValues.push_back(dimVal);
    } else {
      // Static extents are materialized as constants.
      int64_t dim = rankedTensorTy.getDimSize(i);
      Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
      dimValues.push_back(dimVal);
    }
  }

  // Materialize extent tensor.
  Value staticExtentTensor =
      rewriter.create<TensorFromElementsOp>(loc, dimValues);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                            op.getType());
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

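// Lower `shape.const_shape` to a constant extent tensor, e.g. (illustrative,
// syntax approximate):
//   %shape = shape.const_shape [2, 3]
// becomes roughly:
//   %c2 = constant 2 : index
//   %c3 = constant 3 : index
//   %e = tensor_from_elements(%c2, %c3) : tensor<2xindex>
//   %shape = tensor_cast %e : tensor<2xindex> to tensor<?xindex>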
LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
  Type indexTy = rewriter.getIndexType();
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
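/// Lower `shape.to_extent_tensor` to a `tensor_cast`; the input already is an
/// extent tensor, so only its type changes, e.g. (illustrative):
///   %t = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<3xindex>
/// becomes:
///   %t = tensor_cast %shape : tensor<?xindex> to tensor<3xindex>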
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOp::Adaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

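  // Otherwise, fall back to extracting the extent from the materialized
  // extent tensor, e.g. (illustrative):
  //   %extent = extract_element %shape[%dim] : tensor<?xindex>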
  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

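  // The rank of a shape is the number of its extents, i.e. the size of the
  // first (and only) dimension of the extent tensor.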
  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}