//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand. Any operand would be a valid
  // substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}
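
// For illustration, a sketch of the rewrite (assumed MLIR assembly syntax):
//
//   %result = shape.any %a, %b : tensor<?xindex>, tensor<?xindex>
//
// is replaced by its first operand %a; all uses of %result become uses of %a.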

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
    return success();
  }
};
} // namespace
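
// Sketch of the rewrites, assuming `index`-typed extents (exact shape dialect
// assembly syntax may differ between revisions):
//
//   %c = shape.add %a, %b   // becomes:  %c = addi %a, %b : index
//   %d = shape.mul %a, %b   // becomes:  %d = muli %a, %b : index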

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
    return success();
  }
};
} // namespace
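
// Sketch of the rewrite (assumed assembly syntax):
//
//   %s = shape.const_size 42   // becomes:  %s = constant 42 : index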

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  ShapeOfOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  auto tensorVal = transformed.arg();
  auto tensorTy = tensorVal.getType();

  // For unranked tensors, `shape_of` lowers to `scf` ops; that pattern can be
  // found in the corresponding ShapeToSCF pass.
  if (tensorTy.isa<UnrankedTensorType>())
    return failure();

  // Build values for the individual dimensions: `dim` ops for dynamic extents,
  // constants for static ones.
  SmallVector<Value, 8> dimValues;
  auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
  int64_t rank = rankedTensorTy.getRank();
  for (int64_t i = 0; i < rank; i++) {
    if (rankedTensorTy.isDynamicDim(i)) {
      auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
      dimValues.push_back(dimVal);
    } else {
      int64_t dim = rankedTensorTy.getDimSize(i);
      auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
      dimValues.push_back(dimVal);
    }
  }

  // Materialize the extent tensor and cast it to the result type.
  Value staticExtentTensor =
      rewriter.create<TensorFromElementsOp>(loc, dimValues);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                            op.getType());
  return success();
}
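
// Sketch of the lowering for a ranked operand (assumed assembly syntax):
//
//   %shape = shape.shape_of %t : tensor<2x?xf32>
//
// becomes
//
//   %c2 = constant 2 : index
//   %c1 = constant 1 : index
//   %d1 = dim %t, %c1 : tensor<2x?xf32>
//   %e = tensor_from_elements(%c2, %d1) : tensor<2xindex>
//   %shape = tensor_cast %e : tensor<2xindex> to tensor<?xindex>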

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
  Type indexTy = rewriter.getIndexType();
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}
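
// Sketch of the rewrite (assumed assembly syntax):
//
//   %shape = shape.const_shape [2, 3]
//
// becomes
//
//   %c2 = constant 2 : index
//   %c3 = constant 3 : index
//   %e = tensor_from_elements(%c2, %c3) : tensor<2xindex>
//   %shape = tensor_cast %e : tensor<2xindex> to tensor<?xindex>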

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "input needs to be a ranked tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace
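
// Once the input is a ranked extent tensor, the op is a pure type change
// (sketch, assumed assembly syntax):
//
//   %e = shape.to_extent_tensor %s
//
// becomes
//
//   %e = tensor_cast %s : tensor<?xindex> to tensor<3xindex>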

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // Derive the shape extent directly from the shape's origin if possible. This
  // avoids materializing the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
    return success();
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}
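
// Sketch of the two lowering paths (assumed assembly syntax):
//
//   %x = shape.get_extent %shape, %i
//
// becomes, when %shape is produced by a `shape_of` of tensor %t,
//
//   %x = dim %t, %i : tensor<?x?xf32>
//
// and otherwise
//
//   %x = extract_element %shape[%i] : tensor<?xindex>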

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  shape::RankOp::Adaptor transformed(operands);

  // The rank is the number of extents, i.e. the first (and only) dimension of
  // the extent tensor.
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}
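
// Sketch of the rewrite (assumed assembly syntax):
//
//   %r = shape.rank %shape
//
// becomes
//
//   %c0 = constant 0 : index
//   %r = dim %shape, %c0 : tensor<?xindex>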

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion. Use a partial conversion because some patterns, such as
  // the lowering of `shape_of` for unranked tensors, bail out deliberately and
  // leave the op for other passes.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}
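
// Typical invocation, assuming the pass is registered under the
// `convert-shape-to-std` flag (flag name not verified against this revision):
//
//   mlir-opt -convert-shape-to-std input.mlir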

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}