//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
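/// Converts `shape.any` by forwarding its first operand; any operand would be
/// a valid substitution. A rough sketch of the effect, written in generic op
/// form since only the replacement matters (value names are illustrative):
///
///   %result = "shape.any"(%a, %b)
///       : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
///
/// is lowered by replacing all uses of %result with %a.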
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
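/// Converts the binary ops `shape.add` and `shape.mul` on `index` operands to
/// the corresponding std arithmetic ops. An illustrative sketch (assuming
/// index-typed operands; the exact assembly syntax may differ):
///
///   %sum = shape.add %a, %b : index, index -> index
///
/// becomes
///
///   %sum = addi %a, %b : index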
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
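/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body yields, for every output dimension, the extent prescribed by the
/// broadcasting rules (see `getBroadcastedDim` below). A rough sketch of the
/// generated IR, with intermediate steps elided and names illustrative:
///
///   %bcast = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
///       -> tensor<?xindex>
///
/// becomes, approximately:
///
///   %rank_a = tensor.dim %a, %c0 : tensor<?xindex>
///   %rank_b = tensor.dim %b, %c0 : tensor<?xindex>
///   ... select %max_rank and compute each operand's rank difference ...
///   %bcast = tensor.generate %max_rank {
///   ^bb0(%dim : index):
///     ... yield the broadcasted extent for dimension %dim ...
///   } : tensor<?xindex>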
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given output dimension. This is computed from
// any number of extent tensors, each offset by its rank difference from the
// maximal rank.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds =
        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                // extent from the greater-rank operand) is equal to 1,
                // then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension =
                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  BroadcastOp::Adaptor transformed(operands);

  Value zero = lb.create<ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference between each rank and the maximum rank; these
  // serve as offsets into the extent tensors later on.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
                              transformed.shapes(), rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
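/// Converts `shape.const_shape` to `tensor.from_elements` over constant
/// indices, followed by a cast to the dynamically sized extent tensor type.
/// Illustrative example:
///
///   %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
///
/// becomes
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %c3 = constant 3 : index
///   %0 = tensor.from_elements %c1, %c2, %c3 : tensor<3xindex>
///   %shape = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>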
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
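/// Converts `shape.const_size` to a `constant` of index type; sketched,
/// `%s = shape.const_size 42` becomes `%c42 = constant 42 : index`.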
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
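/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// over the output dimensions. For every dimension, it recomputes the
/// broadcasted extent and checks that each operand's in-bounds extent is
/// either 1 or equal to that broadcasted extent; the loop carries the
/// conjunction of these checks as the result.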
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (!llvm::all_of(op.shapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<ConstantIndexOp>(0);
  Value one = lb.create<ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference between each rank and the maximum rank; these
  // serve as offsets into the extent tensors later on.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Note: `lb` already carries the location, so it must not be passed again.
  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds =
              b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                       dimensionExtent, one);
                     Value equalBroadcasted =
                         b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                          dimensionExtent, broadcastedDim);
                     Value result = b.create<AndOp>(
                         loc, broadcastable,
                         b.create<OrOp>(loc, equalOne, equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
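/// Converts `shape.get_extent` to `tensor.extract`, or directly to
/// `tensor.dim` when the shape operand stems from a `shape.shape_of` on a
/// shaped value. Illustrative sketch (assembly syntax abbreviated):
///
///   %shape = shape.shape_of %t : tensor<?x?xf32> -> tensor<?xindex>
///   %extent = shape.get_extent %shape, %dim
///
/// becomes
///
///   %extent = tensor.dim %t, %dim : tensor<?x?xf32>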
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin if possible; this
  // avoids materializing the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(),
                                                 transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 transformed.shape(),
                                                 ValueRange{transformed.dim()});
  return success();
}

namespace {
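/// Converts `shape.rank` on an extent tensor to `tensor.dim` on the tensor's
/// only dimension; sketched, `%rank = shape.rank %shape` becomes
/// `%rank = tensor.dim %shape, %c0 : tensor<?xindex>`.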
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
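/// An illustrative sketch of the lowering (reduce body elided):
///
///   %r = shape.reduce(%shape, %init) : tensor<?xindex> -> index { ... }
///
/// becomes an `scf.for` from 0 to the rank of %shape that extracts each
/// extent with `tensor.extract` and feeds it, together with the induction
/// variable and the running accumulators, into a clone of the reduce body.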
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the shapes' ranks and, if they agree, checks every pair of
/// extents for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.shapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.shapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
                                            rewriter.getBoolAttr(true));
    return success();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value firstShape = transformed.shapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : transformed.shapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<ConstantIndexOp>(loc, 1);
          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lhsExtent, rhsExtent);
                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
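/// Converts `shape.shape_of`. For a ranked operand, the extents are collected
/// into a `tensor.from_elements`; for an unranked operand, a `tensor.generate`
/// over the rank is emitted instead. An illustrative sketch of the ranked
/// case (constant numbering is illustrative):
///
///   %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
///
/// becomes
///
///   %c2 = constant 2 : index
///   %d1 = tensor.dim %arg, %c1 : tensor<2x?xf32>
///   %0 = tensor.from_elements %c2, %d1 : tensor<2xindex>
///   %shape = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>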
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
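/// Converts `shape.split_at` to a pair of `tensor.extract_slice` ops. A
/// negative split index is interpreted relative to the rank, i.e. it is
/// normalized as `index < 0 ? index + rank : index` before slicing.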
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  // Use the remapped operands, not the op's original operands.
  SplitAtOp::Adaptor transformed(operands);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(transformed.operand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = transformed.index();
  Value add = b.create<AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(transformed.operand(), zero, index, one);
  Value tailSize = b.create<SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(transformed.operand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
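/// Converts `shape.to_extent_tensor` to a `tensor.cast` when the input is
/// already a ranked tensor; other inputs are rejected with a match failure.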
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.input());
    return success();
  }
};
} // namespace

namespace {
/// Import the generated Shape-to-Std conversion patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target
      .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();

  // Set up conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}