//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
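  // For instance (illustrative, on extent tensor operands), an op such as
  // `%r = shape.any %a, %b` is simply replaced by `%a`; replacing it with
  // `%b` would be equally valid.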
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
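/// Converts a binary shape op on error-free operands to the corresponding
/// standard arithmetic op on `index` values. The pattern registrations in
/// populateShapeToStandardConversionPatterns below instantiate this for
/// `shape.add` -> `addi` and `shape.mul` -> `muli`.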
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given output dimension. It is computed from
// any number of extent tensors, using per-operand rank differences as shifted
// offsets into them.
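// For example, if the extents gathered at one output dimension are 1, 5, and
// 5, the broadcasted extent is 5. An operand whose rank difference exceeds the
// output dimension is out of bounds there and contributes the neutral
// extent 1.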
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds =
        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if the current operand's extent at this dimension is 1,
                //   keep the extent broadcasted so far,
                // - otherwise, take the current operand's extent as-is.
                // Note that this logic remains correct in the presence of
                // dimensions of zero extent.
                Value lesserRankOperandDimension =
                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  BroadcastOp::Adaptor transformed(operands);

  Value zero = lb.create<ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because each operand is a 1-D
  // tensor of extents, the rank of the shape it represents is the extent of
  // that tensor's only dimension.
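  // E.g. the extent tensor of a rank-3 shape holds 3 elements, so the `dim`
  // built below yields 3 for it.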
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference between the maximum rank and each operand's rank;
  // these differences serve as offsets into the extent tensors later on.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  rewriter.replaceOp(
      op, lb.create<tensor::GenerateOp>(
                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value broadcastedDim = getBroadcastedDim(
                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
                      rankDiffs, args[0]);

                  b.create<tensor::YieldOp>(loc, broadcastedDim);
                })
              ->getResults());
  return success();
}

namespace {
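/// Converts `shape.const_shape` on extent tensors by materializing one
/// `constant` per extent, packing them with `tensor.from_elements`, and
/// casting the result to `tensor<?xindex>`.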
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (transformed.lhs().getType().isa<ShapeType>() ||
      transformed.rhs().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find the lesser and greater ranks and the corresponding extent tensors.
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  Type i1Ty = rewriter.getI1Type();
  Value init =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Determine if all overlapping extents are broadcastable.
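  // Two extents are compatible if they are equal or if either of them is 1,
  // e.g. the extent pairs (5, 1), (1, 7), and (4, 4) are broadcastable,
  // whereas (3, 4) is not.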
  auto reduceResult = rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, ValueRange{init},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
            loc, greaterRankOperand, ValueRange{iv});
        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
        Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
            loc, lesserRankOperand, ValueRange{ivShifted});
        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
        Value extentsAreEqual =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                             lesserRankOperandExtent);
        Value broadcastableExtents = b.create<AndOp>(
            loc, iterArgs[0],
            b.create<OrOp>(loc,
                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                          lesserRankOperandExtentIsOne),
                           extentsAreEqual));
        b.create<scf::YieldOp>(loc, broadcastableExtents);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin where possible.
  // This avoids materializing the shape as an extent tensor altogether.
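  // E.g. when the shape operand is `shape.shape_of %t` for a shaped value %t,
  // `shape.get_extent` at dimension %d lowers directly to `dim %t, %d`.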
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 transformed.shape(),
                                                 ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
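///
/// As a rough sketch (value names and exact printed form are illustrative), a
/// reduction over an extent tensor such as
///
/// %num_elems = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///   ^bb0(%i : index, %extent : index, %acc : index):
///     %next = muli %acc, %extent : index
///     shape.yield %next : index
/// }
///
/// becomes an `scf.for` from 0 to the rank of %shape that extracts each
/// extent with `tensor.extract` and clones the reduction body into the loop,
/// carrying the accumulator as an iteration argument.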
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The equality test first
/// compares the ranks and, if they match, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
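  // E.g. for a `tensor<2x?xf32>` argument, this materializes the constant
  // extent 2 and a `dim` for the dynamic dimension, packs them with
  // `tensor.from_elements`, and casts the result to the op's result type.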
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.input());
    return success();
  }
};
} // namespace

namespace {
/// Import the generated Shape-to-Standard rewrite patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target
      .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  populateWithGenerated(ctx, patterns);
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}