//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
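/// Converts `shape.any` by forwarding its first operand. Since the op may
/// return any of its operands, this is a valid (and trivially cheap) choice.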
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
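/// Converts binary shape arithmetic (e.g. `shape.add`, `shape.mul`; see the
/// pattern list at the bottom of this file) to the corresponding standard
/// arithmetic op on `index` operands. Only the error-free `index` case is
/// handled; operations producing `!shape.size` values are not converted.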
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
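/// Converts `shape.broadcast` on `tensor<?xindex>` operands. The result is
/// built with a `dynamic_tensor_from_elements` op of length max(lhs rank,
/// rhs rank): leading dimensions that exist only in the greater-rank operand
/// are taken as-is, and for the remaining dimensions an `scf.if` picks the
/// lesser-rank operand's extent whenever the greater-rank operand's extent
/// is 1.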
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);

  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value outputDimension = args[0];
        Value isUnchallengedDimension = b.create<CmpIOp>(
            loc, CmpIPredicate::ult, outputDimension, rankDiff);
        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
            loc, greaterRankOperand, outputDimension);
        // The initial dimensions of the greater-rank operand are unchallenged,
        // so we can take them as-is. Otherwise, we need to do a comparison.
        // We need an actual branch here (instead of a select) because the
        // lesser-rank operand might be rank 0, so any tensor.extract would be
        // invalid.
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, isUnchallengedDimension,
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              // The broadcasting logic is:
              // - if one extent (here we arbitrarily choose the extent from
              // the greater-rank operand) is equal to 1, then take the extent
              // from the other operand
              // - otherwise, take the extent as-is.
              // Note that this logic remains correct in the presence of
              // dimensions of zero extent.
              Value lesserRankOperandDimension =
                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                  loc, lesserRankOperand,
                  ValueRange{lesserRankOperandDimension});
              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
              Value broadcastedExtent = b.create<SelectOp>(
                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
                  greaterRankOperandExtent);
              b.create<scf::YieldOp>(loc, broadcastedExtent);
            });
        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
      });
  return success();
}

namespace {
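/// Converts `shape.const_shape` to a `tensor_from_elements` op whose elements
/// are `constant` index values, followed by a `tensor_cast` to the dynamically
/// sized `tensor<?xindex>` result type.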
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
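/// Converts `shape.const_size` to a `constant` op of index type.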
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
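/// Converts `shape.is_broadcastable` on `tensor<?xindex>` operands to an
/// `scf.for` loop over the overlapping dimensions that and-accumulates, in an
/// i1 iteration argument, whether each pair of extents is compatible, i.e.
/// equal or with at least one of them equal to 1.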
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (transformed.lhs().getType().isa<ShapeType>() ||
      transformed.rhs().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  Type i1Ty = rewriter.getI1Type();
  Value init =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Determine if all overlapping extents are broadcastable.
  auto reduceResult = rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, ValueRange{init},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
            loc, greaterRankOperand, ValueRange{iv});
        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
        Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
            loc, lesserRankOperand, ValueRange{ivShifted});
        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
        Value extentsAreEqual =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                             lesserRankOperandExtent);
        Value broadcastableExtents = b.create<AndOp>(
            loc, iterArgs[0],
            b.create<OrOp>(loc,
                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                          lesserRankOperandExtentIsOne),
                           extentsAreEqual));
        b.create<scf::YieldOp>(loc, broadcastableExtents);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
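/// Converts `shape.get_extent`. When the shape operand comes directly from a
/// `shape.shape_of` on a shaped value, the extent is obtained with a `dim` op
/// on that value, avoiding materialization of the extent tensor; otherwise the
/// extent is read from the extent tensor with `tensor.extract`.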
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 transformed.shape(),
                                                 ValueRange{transformed.dim()});
  return success();
}

namespace {
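/// Converts `shape.rank` on a `tensor<?xindex>` operand to a `dim` op on its
/// only dimension, i.e. the number of extents.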
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
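/// Converts `shape.shape_of`. For ranked tensor arguments, the extents are
/// collected (via `dim` for dynamic dimensions and `constant` otherwise) into
/// a `tensor_from_elements` op; for all other arguments, a
/// `dynamic_tensor_from_elements` op of length `rank` reads each extent with
/// a `dim` op.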
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
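/// Converts `shape.to_extent_tensor` to a `tensor_cast`, provided the operand
/// has already been converted to a ranked tensor.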
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target
      .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  populateWithGenerated(ctx, patterns);
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}