1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 using namespace mlir::scf;
25 
26 /// Conversion patterns.
27 namespace {
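/// Converts `shape.any` by forwarding one of its operands, since any operand
/// is a valid substitution. Illustrative example (names made up): an any-op
/// over %a and %b is simply replaced by %a.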
28 class AnyOpConversion : public OpConversionPattern<AnyOp> {
29 public:
30   using OpConversionPattern<AnyOp>::OpConversionPattern;
31 
32   LogicalResult
33   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
34                   ConversionPatternRewriter &rewriter) const override;
35 };
36 } // namespace
37 
38 LogicalResult
39 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
40                                  ConversionPatternRewriter &rewriter) const {
41   AnyOp::Adaptor transformed(operands);
42 
43   // Replace `any` with its first operand.
44   // Any operand would be a valid substitution.
45   rewriter.replaceOp(op, {transformed.inputs().front()});
46   return success();
47 }
48 
49 namespace {
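/// Converts binary shape ops (`shape.add`, `shape.mul`) on error-free `index`
/// operands to the corresponding standard arithmetic op. Illustrative example
/// (names made up, assuming `index`-typed operands):
///
///   %sum = addi %lhs, %rhs : index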
50 template <typename SrcOpTy, typename DstOpTy>
51 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
52 public:
53   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
54 
55   LogicalResult
56   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
57                   ConversionPatternRewriter &rewriter) const override {
58     typename SrcOpTy::Adaptor transformed(operands);
59 
60     // For now, only error-free types are supported by this lowering.
61     if (op.getType().template isa<SizeType>())
62       return failure();
63 
64     rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
65                                          transformed.rhs());
66     return success();
67   }
68 };
69 } // namespace
70 
71 namespace {
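/// Converts `shape.broadcast` on `tensor<?xindex>` operands to a
/// `tensor.generate` over the maximum rank, whose body computes the
/// broadcasted extent of every output dimension (see `getBroadcastedDim`
/// below). Illustrative sketch for two shapes (names made up):
///
///   %rank0 = memref.dim %shape0, %c0 : tensor<?xindex>
///   %rank1 = memref.dim %shape1, %c0 : tensor<?xindex>
///   ... %maxRank = maximum of the ranks ...
///   %result = tensor.generate %maxRank {
///   ^bb0(%dim : index):
///     ... yield the broadcasted extent for output dimension %dim ...
///   } : tensor<?xindex>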
72 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
73   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
74 
75   LogicalResult
76   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
77                   ConversionPatternRewriter &rewriter) const override;
78 };
79 
// Get the resulting extent in a given output dimension. This is computed from
// any number of extent tensors, using the per-operand rank differences as
// shifted offsets into them.
82 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
83                         ValueRange rankDiffs, Value outputDimension) {
84   Value one = lb.create<ConstantIndexOp>(1);
85   Value broadcastedDim = one;
86   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
87     Value shape = std::get<0>(tup);
88     Value rankDiff = std::get<1>(tup);
89     Value outOfBounds =
90         lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
91     Type indexTy = lb.getIndexType();
92     broadcastedDim =
93         lb.create<IfOp>(
94               TypeRange{indexTy}, outOfBounds,
95               [&](OpBuilder &b, Location loc) {
96                 b.create<scf::YieldOp>(loc, broadcastedDim);
97               },
98               [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if the extent of the current operand in this dimension
                //   is equal to 1, take the extent computed so far from
                //   the other operands,
                // - otherwise, take the current extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
106                 Value lesserRankOperandDimension =
107                     b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
108                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
109                     loc, shape, ValueRange{lesserRankOperandDimension});
110 
111                 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
112                                                   lesserRankOperandExtent, one);
113                 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
114                                                lesserRankOperandExtent);
115                 b.create<scf::YieldOp>(loc, dim);
116               })
117             .getResult(0);
118   }
119   return broadcastedDim;
120 }
121 } // namespace
122 
123 LogicalResult BroadcastOpConverter::matchAndRewrite(
124     BroadcastOp op, ArrayRef<Value> operands,
125     ConversionPatternRewriter &rewriter) const {
126   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
127   // on shapes.
128   if (op.getType().isa<ShapeType>())
129     return failure();
130 
131   auto loc = op.getLoc();
132   ImplicitLocOpBuilder lb(loc, rewriter);
133   BroadcastOp::Adaptor transformed(operands);
134 
135   Value zero = lb.create<ConstantIndexOp>(0);
136   Type indexTy = lb.getIndexType();
137 
  // Save all the ranks for bounds checking. Because each shape is a rank-1
  // tensor of extents, its rank is the extent of that tensor's only
  // dimension.
141   SmallVector<Value> ranks, rankDiffs;
142   llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
143                        return lb.create<memref::DimOp>(v, zero);
144                      }));
145 
  // Find the maximum rank.
147   Value maxRank = ranks.front();
148   for (Value v : llvm::drop_begin(ranks, 1)) {
149     Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
150     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
151   }
152 
153   // Calculate the difference of ranks and the maximum rank for later offsets.
154   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
155                        return lb.create<SubIOp>(indexTy, maxRank, v);
156                      }));
157 
158   rewriter.replaceOp(
159       op, lb.create<tensor::GenerateOp>(
160                 getExtentTensorType(lb.getContext()), ValueRange{maxRank},
161                 [&](OpBuilder &b, Location loc, ValueRange args) {
162                   Value broadcastedDim = getBroadcastedDim(
163                       ImplicitLocOpBuilder(loc, b), transformed.shapes(),
164                       rankDiffs, args[0]);
165 
166                   b.create<tensor::YieldOp>(loc, broadcastedDim);
167                 })
168               ->getResults());
169   return success();
170 }
171 
172 namespace {
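/// Converts `shape.const_shape` to a `tensor.from_elements` of constant
/// extents, cast to the dynamically shaped extent tensor type. Illustrative
/// example (assuming the constant shape [1, 2]):
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %0 = tensor.from_elements %c1, %c2 : tensor<2xindex>
///   %1 = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>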
173 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
174 public:
175   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
176 
177   LogicalResult
178   matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
179                   ConversionPatternRewriter &rewriter) const override;
180 };
181 } // namespace
182 
183 LogicalResult ConstShapeOpConverter::matchAndRewrite(
184     ConstShapeOp op, ArrayRef<Value> operands,
185     ConversionPatternRewriter &rewriter) const {
186 
187   // For now, this lowering supports only extent tensors, not `shape.shape`
188   // types.
189   if (op.getType().isa<ShapeType>())
190     return failure();
191 
192   auto loc = op.getLoc();
193   SmallVector<Value, 4> extentOperands;
194   for (auto extent : op.shape()) {
195     extentOperands.push_back(
196         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
197   }
198   Type indexTy = rewriter.getIndexType();
199   Value tensor =
200       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
201   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
202   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
203   return success();
204 }
205 
206 namespace {
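/// Converts `shape.const_size` to a `constant` of index type. Illustrative
/// example: a constant size 42 becomes `%c42 = constant 42 : index`.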
207 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
208 public:
209   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
210 
211   LogicalResult
212   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
213                   ConversionPatternRewriter &rewriter) const override;
214 };
215 } // namespace
216 
217 LogicalResult ConstSizeOpConversion::matchAndRewrite(
218     ConstSizeOp op, ArrayRef<Value> operands,
219     ConversionPatternRewriter &rewriter) const {
220   rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
221   return success();
222 }
223 
224 namespace {
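/// Converts `shape.is_broadcastable` on `tensor<?xindex>` operands to an
/// `scf.for` loop over the output dimensions that reduces an `i1`
/// conjunction: in every dimension, each operand extent must either be 1 or
/// equal to the broadcasted extent. Illustrative sketch (names made up):
///
///   %result = scf.for %i = %c0 to %maxRank step %c1
///       iter_args(%ok = %true) -> (i1) {
///     ... %next = %ok AND-ed with the per-operand checks for dim %i ...
///     scf.yield %next : i1
///   }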
225 struct IsBroadcastableOpConverter
226     : public OpConversionPattern<IsBroadcastableOp> {
227   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
228 
229   LogicalResult
230   matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
231                   ConversionPatternRewriter &rewriter) const override;
232 };
233 } // namespace
234 
235 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
236     IsBroadcastableOp op, ArrayRef<Value> operands,
237     ConversionPatternRewriter &rewriter) const {
238   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
239   // on shapes.
240   IsBroadcastableOp::Adaptor transformed(operands);
241   if (!llvm::all_of(op.shapes(),
242                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
243     return failure();
244 
245   auto loc = op.getLoc();
246   ImplicitLocOpBuilder lb(loc, rewriter);
247   Value zero = lb.create<ConstantIndexOp>(0);
248   Value one = lb.create<ConstantIndexOp>(1);
249   Type indexTy = lb.getIndexType();
250 
  // Save all the ranks for bounds checking. Because each shape is a rank-1
  // tensor of extents, its rank is the extent of that tensor's only
  // dimension.
254   SmallVector<Value> ranks, rankDiffs;
255   llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
256                        return lb.create<memref::DimOp>(v, zero);
257                      }));
258 
  // Find the maximum rank.
260   Value maxRank = ranks.front();
261   for (Value v : llvm::drop_begin(ranks, 1)) {
262     Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
263     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
264   }
265 
266   // Calculate the difference of ranks and the maximum rank for later offsets.
267   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
268                        return lb.create<SubIOp>(indexTy, maxRank, v);
269                      }));
270 
271   Type i1Ty = rewriter.getI1Type();
272   Value trueVal =
273       rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
274 
  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
277       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that this computation could
        // reuse the broadcast lowering entirely, but we redo the work here
        // to make optimization across the two loops easier.
281         Value broadcastedDim = getBroadcastedDim(
282             ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
283 
284         Value broadcastable = iterArgs[0];
285         for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
286           Value shape, rankDiff;
287           std::tie(shape, rankDiff) = tup;
288           Value outOfBounds =
289               b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
290           broadcastable =
291               b.create<IfOp>(
292                    loc, TypeRange{i1Ty}, outOfBounds,
293                    [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
295                      b.create<scf::YieldOp>(loc, broadcastable);
296                    },
297                    [&](OpBuilder &b, Location loc) {
298                      // Every value needs to be either 1, or the same non-1
299                      // value to be broadcastable in this dim.
300                      Value operandDimension =
301                          b.create<SubIOp>(loc, indexTy, iv, rankDiff);
302                      Value dimensionExtent = b.create<tensor::ExtractOp>(
303                          loc, shape, ValueRange{operandDimension});
304 
305                      Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
306                                                        dimensionExtent, one);
307                      Value equalBroadcasted =
308                          b.create<CmpIOp>(loc, CmpIPredicate::eq,
309                                           dimensionExtent, broadcastedDim);
310                      Value result = b.create<AndOp>(
311                          loc, broadcastable,
312                          b.create<OrOp>(loc, equalOne, equalBroadcasted));
313                      b.create<scf::YieldOp>(loc, result);
314                    })
315                   .getResult(0);
316         }
317 
318         b.create<scf::YieldOp>(loc, broadcastable);
319       });
320 
321   rewriter.replaceOp(op, reduceResult.results().front());
322   return success();
323 }
324 
325 namespace {
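/// Converts `shape.get_extent`. If the shape is produced directly by
/// `shape.shape_of` on a ranked tensor, the extent is taken from the tensor
/// itself without materializing the shape; otherwise it is extracted from the
/// extent tensor. Illustrative examples (names made up):
///
///   %extent = memref.dim %tensor, %dim : tensor<?x?xf32>
///   %extent = tensor.extract %shape[%dim] : tensor<?xindex>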
326 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
327   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
331                   ConversionPatternRewriter &rewriter) const override;
332 };
333 } // namespace
334 
335 LogicalResult GetExtentOpConverter::matchAndRewrite(
336     GetExtentOp op, ArrayRef<Value> operands,
337     ConversionPatternRewriter &rewriter) const {
338   GetExtentOp::Adaptor transformed(operands);
339 
340   // For now, only error-free types are supported by this lowering.
341   if (op.getType().isa<SizeType>())
342     return failure();
343 
344   // Derive shape extent directly from shape origin if possible. This
345   // circumvents the necessity to materialize the shape in memory.
346   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
347     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
348       rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(),
349                                                  transformed.dim());
350       return success();
351     }
352   }
353 
354   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
355                                                  transformed.shape(),
356                                                  ValueRange{transformed.dim()});
357   return success();
358 }
359 
360 namespace {
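/// Converts `shape.rank` on a `tensor<?xindex>` operand to a query of the
/// extent tensor's single dimension. Illustrative example (names made up):
///
///   %rank = memref.dim %shape, %c0 : tensor<?xindex>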
361 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
362 public:
363   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
364 
365   LogicalResult
366   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
367                   ConversionPatternRewriter &rewriter) const override;
368 };
369 } // namespace
370 
371 LogicalResult
372 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
373                                  ConversionPatternRewriter &rewriter) const {
374   // For now, this lowering supports only error-free types.
375   if (op.getType().isa<SizeType>())
376     return failure();
377 
378   shape::RankOp::Adaptor transformed(operands);
379   rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0);
380   return success();
381 }
382 
383 namespace {
384 /// Converts `shape.reduce` to `scf.for`.
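/// Illustrative sketch (names made up, assuming a single `index`
/// accumulator):
///
///   %rank = memref.dim %shape, %c0 : tensor<?xindex>
///   %result = scf.for %i = %c0 to %rank step %c1
///       iter_args(%acc = %init) -> (index) {
///     %extent = tensor.extract %shape[%i] : tensor<?xindex>
///     ... clone of the reduction body ...
///     scf.yield %next : index
///   }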
385 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
386 public:
387   using OpConversionPattern::OpConversionPattern;
388 
389   LogicalResult
390   matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
391                   ConversionPatternRewriter &rewriter) const final;
392 };
393 } // namespace
394 
395 LogicalResult
396 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
397                                    ConversionPatternRewriter &rewriter) const {
398   // For now, this lowering is only defined on `tensor<?xindex>` operands.
399   if (op.shape().getType().isa<ShapeType>())
400     return failure();
401 
402   auto loc = op.getLoc();
403   shape::ReduceOp::Adaptor transformed(operands);
404 
405   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
406   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
407   Type indexTy = rewriter.getIndexType();
408   Value rank =
409       rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);
410 
411   auto loop = rewriter.create<scf::ForOp>(
412       loc, zero, rank, one, op.initVals(),
413       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
414         Value extent =
415             b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);
416 
417         SmallVector<Value, 2> mappedValues{iv, extent};
418         mappedValues.append(args.begin(), args.end());
419 
420         BlockAndValueMapping mapping;
421         Block *reduceBody = op.getBody();
422         mapping.map(reduceBody->getArguments(), mappedValues);
423         for (auto &nested : reduceBody->without_terminator())
424           b.clone(nested, mapping);
425 
426         SmallVector<Value, 2> mappedResults;
427         for (auto result : reduceBody->getTerminator()->getOperands())
428           mappedResults.push_back(mapping.lookup(result));
429         b.create<scf::YieldOp>(loc, mappedResults);
430       });
431 
432   rewriter.replaceOp(op, loop.getResults());
433   return success();
434 }
435 
436 namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every extent for
/// equality.
440 ///
441 /// Example:
442 ///
443 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
444 ///
445 /// becomes
446 ///
/// %c0 = constant 0 : index
/// %0 = memref.dim %arg0, %c0 : tensor<?xindex>
/// %1 = memref.dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi eq, %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi eq, %5, %6 : index
458 ///     %8 = and %arg3, %7 : i1
459 ///     scf.yield %8 : i1
460 ///   }
461 ///   scf.yield %4 : i1
462 /// } else {
463 ///   %false = constant false
464 ///   scf.yield %false : i1
465 /// }
466 ///
467 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
468   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
469 
470   LogicalResult
471   matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
472                   ConversionPatternRewriter &rewriter) const override;
473 };
474 } // namespace
475 
476 LogicalResult
477 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
478                                     ConversionPatternRewriter &rewriter) const {
479   if (!llvm::all_of(op.shapes(),
480                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
481     return failure();
482 
483   Type i1Ty = rewriter.getI1Type();
484   if (op.shapes().size() <= 1) {
485     rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
486                                             rewriter.getBoolAttr(true));
487     return success();
488   }
489 
490   ShapeEqOp::Adaptor transformed(operands);
491   auto loc = op.getLoc();
492   Type indexTy = rewriter.getIndexType();
493   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
494   Value firstShape = transformed.shapes().front();
495   Value firstRank =
496       rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
497   Value result = nullptr;
498   // Generate a linear sequence of compares, all with firstShape as lhs.
499   for (Value shape : transformed.shapes().drop_front(1)) {
500     Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
501     Value eqRank =
502         rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
503     auto same = rewriter.create<IfOp>(
504         loc, i1Ty, eqRank,
505         [&](OpBuilder &b, Location loc) {
506           Value one = b.create<ConstantIndexOp>(loc, 1);
507           Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
508           auto loop = b.create<scf::ForOp>(
509               loc, zero, firstRank, one, ValueRange{init},
510               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
511                 Value conj = args[0];
512                 Value lhsExtent =
513                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
514                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
515                 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
516                                                   lhsExtent, rhsExtent);
517                 Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
518                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
519               });
520           b.create<scf::YieldOp>(loc, loop.getResults());
521         },
522         [&](OpBuilder &b, Location loc) {
523           Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
524           b.create<scf::YieldOp>(loc, result);
525         });
526     result = !result ? same.getResult(0)
527                      : rewriter.create<AndOp>(loc, result, same.getResult(0));
528   }
529   rewriter.replaceOp(op, result);
530   return success();
531 }
532 
533 namespace {
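/// Converts `shape.shape_of`. For ranked tensor operands, the extents are
/// materialized with `tensor.from_elements`; for unranked operands, a
/// `tensor.generate` queries each dimension dynamically. Illustrative example
/// for a ranked operand %t of type tensor<2x?xf32>:
///
///   %c2 = constant 2 : index
///   %c1 = constant 1 : index
///   %d1 = memref.dim %t, %c1 : tensor<2x?xf32>
///   %0 = tensor.from_elements %c2, %d1 : tensor<2xindex>
///   %shape = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>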
534 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
535 public:
536   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
537 
538   LogicalResult
539   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
540                   ConversionPatternRewriter &rewriter) const override;
541 };
542 } // namespace
543 
544 LogicalResult ShapeOfOpConversion::matchAndRewrite(
545     ShapeOfOp op, ArrayRef<Value> operands,
546     ConversionPatternRewriter &rewriter) const {
547 
548   // For now, only error-free types are supported by this lowering.
549   if (op.getType().isa<ShapeType>())
550     return failure();
551 
552   // For ranked tensor arguments, lower to `tensor.from_elements`.
553   auto loc = op.getLoc();
554   ShapeOfOp::Adaptor transformed(operands);
555   Value tensor = transformed.arg();
556   Type tensorTy = tensor.getType();
557   if (tensorTy.isa<RankedTensorType>()) {
558 
559     // Build values for individual extents.
560     SmallVector<Value, 8> extentValues;
561     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
562     int64_t rank = rankedTensorTy.getRank();
563     for (int64_t i = 0; i < rank; i++) {
564       if (rankedTensorTy.isDynamicDim(i)) {
565         Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
566         extentValues.push_back(extent);
567       } else {
568         Value extent =
569             rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
570         extentValues.push_back(extent);
571       }
572     }
573 
574     // Materialize extent tensor.
575     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
576         loc, rewriter.getIndexType(), extentValues);
577     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
578                                                 staticExtentTensor);
579     return success();
580   }
581 
582   // Lower to `tensor.generate` otherwise.
583   auto *ctx = rewriter.getContext();
584   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
585   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
586       op, getExtentTensorType(ctx), ValueRange{rank},
587       [&](OpBuilder &b, Location loc, ValueRange args) {
588         Value dim = args.front();
589         Value extent = b.create<memref::DimOp>(loc, tensor, dim);
590         b.create<tensor::YieldOp>(loc, extent);
591       });
592 
593   return success();
594 }
595 
596 namespace {
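/// Converts `shape.split_at` on extent tensors to a pair of `subtensor` ops,
/// normalizing a negative split index by adding the rank first. Illustrative
/// sketch (names made up):
///
///   %head = subtensor %shape[%c0] [%index] [%c1]
///       : tensor<?xindex> to tensor<?xindex>
///   %tail = subtensor %shape[%index] [%tailSize] [%c1]
///       : tensor<?xindex> to tensor<?xindex>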
597 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
598 public:
599   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
600 
601   LogicalResult
602   matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
603                   ConversionPatternRewriter &rewriter) const override;
604 };
605 } // namespace
606 
607 LogicalResult SplitAtOpConversion::matchAndRewrite(
608     SplitAtOp op, ArrayRef<Value> operands,
609     ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; lower only if all operands and
  // results are extent tensors.
612   if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
613                    [](Value v) { return v.getType().isa<ShapeType>(); }))
614     return failure();
615 
  SplitAtOp::Adaptor transformed(operands);
617   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
618   Value zero = b.create<ConstantIndexOp>(0);
619   Value rank = b.create<memref::DimOp>(transformed.operand(), zero);
620 
621   // index < 0 ? index + rank : index
622   Value originalIndex = transformed.index();
623   Value add = b.create<AddIOp>(originalIndex, rank);
624   Value indexIsNegative =
625       b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
626   Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
627 
628   Value one = b.create<ConstantIndexOp>(1);
629   Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
630   Value tailSize = b.create<SubIOp>(rank, index);
631   Value tail =
632       b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
633   rewriter.replaceOp(op, {head, tail});
634   return success();
635 }
636 
637 namespace {
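/// Converts `shape.to_extent_tensor` on a ranked tensor input to a
/// `tensor.cast` to the result type. Illustrative example:
///
///   %0 = tensor.cast %input : tensor<3xindex> to tensor<?xindex>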
638 class ToExtentTensorOpConversion
639     : public OpConversionPattern<ToExtentTensorOp> {
640 public:
641   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
642 
643   LogicalResult
644   matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
645                   ConversionPatternRewriter &rewriter) const override {
646     ToExtentTensorOpAdaptor adaptor(operands);
647 
648     if (!adaptor.input().getType().isa<RankedTensorType>())
649       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
650 
651     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
652                                                 adaptor.input());
653     return success();
654   }
655 };
656 } // namespace
657 
658 namespace {
/// Import the generated Shape-to-Standard conversion patterns.
660 #include "ShapeToStandard.cpp.inc"
661 } // namespace
662 
663 namespace {
664 /// Conversion pass.
665 class ConvertShapeToStandardPass
666     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
667 
668   void runOnOperation() override;
669 };
670 } // namespace
671 
672 void ConvertShapeToStandardPass::runOnOperation() {
673   // Setup target legality.
674   MLIRContext &ctx = getContext();
675   ConversionTarget target(ctx);
676   target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
677                          tensor::TensorDialect>();
678   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
679 
680   // Setup conversion patterns.
681   RewritePatternSet patterns(&ctx);
682   populateShapeToStandardConversionPatterns(patterns);
683 
684   // Apply conversion.
685   auto module = getOperation();
686   if (failed(applyPartialConversion(module, target, std::move(patterns))))
687     signalPassFailure();
688 }
689 
690 void mlir::populateShapeToStandardConversionPatterns(
691     RewritePatternSet &patterns) {
692   // clang-format off
693   populateWithGenerated(patterns);
694   patterns.add<
695       AnyOpConversion,
696       BinaryOpConversion<AddOp, AddIOp>,
697       BinaryOpConversion<MulOp, MulIOp>,
698       BroadcastOpConverter,
699       ConstShapeOpConverter,
700       ConstSizeOpConversion,
701       IsBroadcastableOpConverter,
702       GetExtentOpConverter,
703       RankOpConverter,
704       ReduceOpConverter,
705       ShapeEqOpConverter,
706       ShapeOfOpConversion,
707       SplitAtOpConversion,
708       ToExtentTensorOpConversion>(patterns.getContext());
709   // clang-format on
710 }
711 
712 std::unique_ptr<OperationPass<ModuleOp>>
713 mlir::createConvertShapeToStandardPass() {
714   return std::make_unique<ConvertShapeToStandardPass>();
715 }
716