1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::shape;
19 
20 /// Conversion patterns.
namespace {
/// Converts `shape.any` by forwarding one of its operands; the op's contract
/// allows any operand to be chosen as the result.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
31 
32 LogicalResult
33 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
34                                  ConversionPatternRewriter &rewriter) const {
35   AnyOp::Adaptor transformed(operands);
36 
37   // Replace `any` with its first operand.
38   // Any operand would be a valid substitution.
39   rewriter.replaceOp(op, {transformed.inputs().front()});
40   return success();
41 }
42 
43 namespace {
44 template <typename SrcOpTy, typename DstOpTy>
45 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
46 public:
47   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
48 
49   LogicalResult
50   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
51                   ConversionPatternRewriter &rewriter) const override {
52     typename SrcOpTy::Adaptor adaptor(operands);
53     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
54     return success();
55   }
56 };
57 } // namespace
58 
namespace {
/// Converts `shape.shape_of` on ranked tensors to an extent tensor built from
/// `dim` ops and index constants; unranked tensors are left for the scf-based
/// lowering (see the matchAndRewrite definition below).
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
69 
70 LogicalResult ShapeOfOpConversion::matchAndRewrite(
71     ShapeOfOp op, ArrayRef<Value> operands,
72     ConversionPatternRewriter &rewriter) const {
73   ShapeOfOp::Adaptor transformed(operands);
74   auto loc = op.getLoc();
75   auto tensorVal = transformed.arg();
76   auto tensorTy = tensorVal.getType();
77 
78   // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
79   // found in the corresponding pass.
80   if (tensorTy.isa<UnrankedTensorType>())
81     return failure();
82 
83   // Build values for individual dimensions.
84   SmallVector<Value, 8> dimValues;
85   auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
86   int64_t rank = rankedTensorTy.getRank();
87   for (int64_t i = 0; i < rank; i++) {
88     if (rankedTensorTy.isDynamicDim(i)) {
89       auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
90       dimValues.push_back(dimVal);
91     } else {
92       int64_t dim = rankedTensorTy.getDimSize(i);
93       auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
94       dimValues.push_back(dimVal);
95     }
96   }
97 
98   // Materialize extent tensor.
99   Value staticExtentTensor =
100       rewriter.create<TensorFromElementsOp>(loc, dimValues);
101   rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
102                                             op.getType());
103   return success();
104 }
105 
namespace {
/// Converts `shape.const_shape` to a tensor of index constants. Only the
/// extent-tensor result type is supported; `shape.shape` results are rejected
/// (see the matchAndRewrite definition below).
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
116 
117 LogicalResult ConstShapeOpConverter::matchAndRewrite(
118     ConstShapeOp op, ArrayRef<Value> operands,
119     ConversionPatternRewriter &rewriter) const {
120 
121   // For now, this lowering supports only extent tensors, not `shape.shape`
122   // types.
123   if (op.getType().isa<ShapeType>())
124     return failure();
125 
126   auto loc = op.getLoc();
127   SmallVector<Value, 4> extentOperands;
128   for (auto extent : op.shape()) {
129     extentOperands.push_back(
130         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
131   }
132   Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
133   Type indexTy = rewriter.getIndexType();
134   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
135   rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
136   return success();
137 }
138 
139 namespace {
140 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
141   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
142 
143   LogicalResult
144   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
145                   ConversionPatternRewriter &rewriter) const override;
146 };
147 } // namespace
148 
149 LogicalResult GetExtentOpConverter::matchAndRewrite(
150     GetExtentOp op, ArrayRef<Value> operands,
151     ConversionPatternRewriter &rewriter) const {
152   GetExtentOp::Adaptor transformed(operands);
153 
154   // Derive shape extent directly from shape origin if possible.
155   // This circumvents the necessity to materialize the shape in memory.
156   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
157     rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
158     return success();
159   }
160 
161   rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
162                                                 transformed.shape(),
163                                                 ValueRange{transformed.dim()});
164   return success();
165 }
166 
namespace {
/// Converts `shape.rank` to a `dim` on the extent tensor (see the
/// matchAndRewrite definition below).
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
177 
178 LogicalResult
179 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
180                                  ConversionPatternRewriter &rewriter) const {
181   shape::RankOp::Adaptor transformed(operands);
182   rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
183   return success();
184 }
185 
186 namespace {
187 /// Type conversions.
188 class ShapeTypeConverter : public TypeConverter {
189 public:
190   using TypeConverter::convertType;
191 
192   ShapeTypeConverter(MLIRContext *ctx) {
193     // Add default pass-through conversion.
194     addConversion([&](Type type) { return type; });
195 
196     addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
197     addConversion([ctx](ShapeType type) {
198       return RankedTensorType::get({ShapedType::kDynamicSize},
199                                    IndexType::get(ctx));
200     });
201   }
202 };
203 } // namespace
204 
namespace {
/// Conversion pass.
/// Lowers shape dialect ops to the scf and standard dialects, rewriting
/// function signatures that mention shape types along the way.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace
213 
214 void ConvertShapeToStandardPass::runOnOperation() {
215   // Setup type conversion.
216   MLIRContext &ctx = getContext();
217   ShapeTypeConverter typeConverter(&ctx);
218 
219   // Setup target legality.
220   ConversionTarget target(ctx);
221   target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
222   target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
223   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
224     return typeConverter.isSignatureLegal(op.getType()) &&
225            typeConverter.isLegal(&op.getBody());
226   });
227 
228   // Setup conversion patterns.
229   OwningRewritePatternList patterns;
230   populateShapeToStandardConversionPatterns(patterns, &ctx);
231   populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
232 
233   // Apply conversion.
234   auto module = getOperation();
235   if (failed(applyFullConversion(module, target, patterns)))
236     signalPassFailure();
237 }
238 
239 void mlir::populateShapeToStandardConversionPatterns(
240     OwningRewritePatternList &patterns, MLIRContext *ctx) {
241   // clang-format off
242   patterns.insert<
243       AnyOpConversion,
244       BinaryOpConversion<AddOp, AddIOp>,
245       ConstShapeOpConverter,
246       BinaryOpConversion<MulOp, MulIOp>,
247       GetExtentOpConverter,
248       RankOpConverter,
249       ShapeOfOpConversion>(ctx);
250   // clang-format on
251 }
252 
253 std::unique_ptr<OperationPass<ModuleOp>>
254 mlir::createConvertShapeToStandardPass() {
255   return std::make_unique<ConvertShapeToStandardPass>();
256 }
257