1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 /// Conversion patterns.
21 namespace {
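/// Lower `shape.any` by forwarding one of its shape operands. All operands
/// are expected to describe the same runtime shape, so any of them is a
/// valid result.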
22 class AnyOpConversion : public OpConversionPattern<AnyOp> {
23 public:
24   using OpConversionPattern<AnyOp>::OpConversionPattern;
25 
26   LogicalResult
27   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
28                   ConversionPatternRewriter &rewriter) const override;
29 };
30 } // namespace
31 
32 LogicalResult
33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
34                                  ConversionPatternRewriter &rewriter) const {
35   AnyOp::Adaptor transformed(operands);
36 
37   // Replace `any` with its first operand.
38   // Any operand would be a valid substitution.
39   rewriter.replaceOp(op, {transformed.inputs().front()});
40   return success();
41 }
42 
43 namespace {
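/// Lower a binary shape op to its standard dialect counterpart, e.g.
/// `shape.add` to `std.addi` and `shape.mul` to `std.muli` (see the pattern
/// registration at the bottom of this file). Ops with a `!shape.size` result,
/// which may carry an error state, are rejected.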
44 template <typename SrcOpTy, typename DstOpTy>
45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
46 public:
47   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
48 
49   LogicalResult
50   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
51                   ConversionPatternRewriter &rewriter) const override {
52     typename SrcOpTy::Adaptor transformed(operands);
53 
54     // For now, only error-free types are supported by this lowering.
55     if (op.getType().template isa<SizeType>())
56       return failure();
57 
58     rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
59                                          transformed.rhs());
60     return success();
61   }
62 };
63 } // namespace
64 
65 namespace {
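/// Lower `shape.const_size` to a constant of index type; schematically,
/// `shape.const_size 4` becomes `constant 4 : index` (exact assembly may
/// differ across MLIR versions).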
66 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
67 public:
68   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
72                   ConversionPatternRewriter &rewriter) const override {
73     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
74     return success();
75   }
76 };
77 } // namespace
78 
79 namespace {
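/// Lower `shape.shape_of` on a ranked tensor to an extent tensor built with
/// `tensor_from_elements`: static extents become constants, dynamic extents
/// become `dim` ops. Schematic pseudo-IR for `%arg : tensor<2x?xf32>` (exact
/// op spellings differ across MLIR versions):
///
///   %c2  = constant 2 : index
///   %d1  = dim %arg, 1 : tensor<2x?xf32>
///   %ext = tensor_from_elements(%c2, %d1) : tensor<2xindex>
///   %res = tensor_cast %ext : tensor<2xindex> to tensor<?xindex>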
80 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
81 public:
82   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
83 
84   LogicalResult
85   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
86                   ConversionPatternRewriter &rewriter) const override;
87 };
88 } // namespace
89 
90 LogicalResult ShapeOfOpConversion::matchAndRewrite(
91     ShapeOfOp op, ArrayRef<Value> operands,
92     ConversionPatternRewriter &rewriter) const {
93 
94   // For now, only error-free types are supported by this lowering.
95   if (op.getType().isa<ShapeType>())
96     return failure();
97 
  // For unranked tensors, `shape_of` lowers to `scf` ops; that pattern lives
  // in the corresponding (ShapeToSCF) conversion pass.
100   ShapeOfOp::Adaptor transformed(operands);
101   Value tensorVal = transformed.arg();
102   Type tensorTy = tensorVal.getType();
103   if (tensorTy.isa<UnrankedTensorType>())
104     return failure();
105 
106   // Build values for individual dimensions.
107   SmallVector<Value, 8> dimValues;
108   RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
109   int64_t rank = rankedTensorTy.getRank();
110   auto loc = op.getLoc();
111   for (int64_t i = 0; i < rank; i++) {
112     if (rankedTensorTy.isDynamicDim(i)) {
113       Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
114       dimValues.push_back(dimVal);
115     } else {
116       int64_t dim = rankedTensorTy.getDimSize(i);
117       Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
118       dimValues.push_back(dimVal);
119     }
120   }
121 
122   // Materialize extent tensor.
123   Value staticExtentTensor =
124       rewriter.create<TensorFromElementsOp>(loc, dimValues);
125   rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
126                                             op.getType());
127   return success();
128 }
129 
130 namespace {
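/// Lower `shape.const_shape` to a statically shaped extent tensor that is
/// then cast to `tensor<?xindex>`: each extent becomes an index constant,
/// and the constants are assembled with `tensor_from_elements`.
/// Schematically, `shape.const_shape [2, 3]` becomes a two-element extent
/// tensor holding the constants 2 and 3.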
131 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
132 public:
133   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
134 
135   LogicalResult
136   matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
137                   ConversionPatternRewriter &rewriter) const override;
138 };
139 } // namespace
140 
141 LogicalResult ConstShapeOpConverter::matchAndRewrite(
142     ConstShapeOp op, ArrayRef<Value> operands,
143     ConversionPatternRewriter &rewriter) const {
144 
145   // For now, this lowering supports only extent tensors, not `shape.shape`
146   // types.
147   if (op.getType().isa<ShapeType>())
148     return failure();
149 
150   auto loc = op.getLoc();
151   SmallVector<Value, 4> extentOperands;
152   for (auto extent : op.shape()) {
153     extentOperands.push_back(
154         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
155   }
156   Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
157   Type indexTy = rewriter.getIndexType();
158   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
159   rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
160   return success();
161 }
162 
163 namespace {
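/// Lower `shape.to_extent_tensor` to a `tensor_cast`. Once the operand has
/// been materialized as a ranked extent tensor, the conversion is a mere
/// cast; operands that are not ranked tensors are rejected.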
164 class ToExtentTensorOpConversion
165     : public OpConversionPattern<ToExtentTensorOp> {
166 public:
167   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
168 
169   LogicalResult
170   matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
171                   ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOp::Adaptor adaptor(operands);
173 
174     if (!adaptor.input().getType().isa<RankedTensorType>())
175       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
176 
177     rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
178                                               op.getType());
179     return success();
180   }
181 };
182 } // namespace
183 
184 namespace {
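/// Lower `shape.get_extent` to either `dim` or `extract_element`: if the
/// shape operand is produced by `shape.shape_of` of a shaped value, the
/// extent is taken directly from that value with a `dim` op; otherwise it is
/// read out of the materialized extent tensor with `extract_element`.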
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
187 
188   LogicalResult
189   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
190                   ConversionPatternRewriter &rewriter) const override;
191 };
192 } // namespace
193 
194 LogicalResult GetExtentOpConverter::matchAndRewrite(
195     GetExtentOp op, ArrayRef<Value> operands,
196     ConversionPatternRewriter &rewriter) const {
197   GetExtentOp::Adaptor transformed(operands);
198 
199   // For now, only error-free types are supported by this lowering.
200   if (op.getType().isa<SizeType>())
201     return failure();
202 
  // Derive the shape extent directly from the shape's origin where possible.
  // This avoids materializing the shape as an extent tensor.
205   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
206     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
207       rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
208                                          transformed.dim());
209       return success();
210     }
211   }
212 
213   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
214                                                 transformed.shape(),
215                                                 ValueRange{transformed.dim()});
216   return success();
217 }
218 
219 namespace {
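/// Lower `shape.rank` to a `dim` op on the extent tensor. An extent tensor
/// of type `tensor<?xindex>` has exactly one dimension, whose size equals
/// the rank of the underlying shape, so querying dimension 0 yields the
/// rank.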
220 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
221 public:
222   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
223 
224   LogicalResult
225   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
226                   ConversionPatternRewriter &rewriter) const override;
227 };
228 } // namespace
229 
230 LogicalResult
231 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
232                                  ConversionPatternRewriter &rewriter) const {
233   // For now, this lowering supports only error-free types.
234   if (op.getType().isa<SizeType>())
235     return failure();
236 
237   shape::RankOp::Adaptor transformed(operands);
238   rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
239   return success();
240 }
241 
242 namespace {
/// Conversion pass that rewrites the error-free subset of the shape dialect
/// into standard dialect ops using the patterns above.
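///
/// Assuming the usual tablegen pass registration (see the generated
/// ConvertShapeToStandardBase), the pass is typically exercised as, e.g.,
/// `mlir-opt -convert-shape-to-std input.mlir`; the exact flag is defined in
/// Passes.td.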
244 class ConvertShapeToStandardPass
245     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
246 
247   void runOnOperation() override;
248 };
249 } // namespace
250 
251 void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
253   MLIRContext &ctx = getContext();
254   ConversionTarget target(ctx);
255   target.addLegalDialect<StandardOpsDialect>();
256   target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
257 
  // Set up conversion patterns.
259   OwningRewritePatternList patterns;
260   populateShapeToStandardConversionPatterns(patterns, &ctx);
261 
262   // Apply conversion.
263   auto module = getOperation();
264   if (failed(applyPartialConversion(module, target, patterns)))
265     signalPassFailure();
266 }
267 
268 void mlir::populateShapeToStandardConversionPatterns(
269     OwningRewritePatternList &patterns, MLIRContext *ctx) {
270   // clang-format off
271   patterns.insert<
272       AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
277       GetExtentOpConverter,
278       RankOpConverter,
279       ShapeOfOpConversion,
280       ToExtentTensorOpConversion>(ctx);
281   // clang-format on
282 }
283 
284 std::unique_ptr<OperationPass<ModuleOp>>
285 mlir::createConvertShapeToStandardPass() {
286   return std::make_unique<ConvertShapeToStandardPass>();
287 }
288