1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 using namespace mlir;
22 using namespace mlir::shape;
23 using namespace mlir::scf;
24 
25 /// Conversion patterns.
26 namespace {
27 class AnyOpConversion : public OpConversionPattern<AnyOp> {
28 public:
29   using OpConversionPattern<AnyOp>::OpConversionPattern;
30 
31   LogicalResult
32   matchAndRewrite(AnyOp op, OpAdaptor adaptor,
33                   ConversionPatternRewriter &rewriter) const override;
34 };
35 } // namespace
36 
37 LogicalResult
38 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
39                                  ConversionPatternRewriter &rewriter) const {
40   // Replace `any` with its first operand.
41   // Any operand would be a valid substitution.
42   rewriter.replaceOp(op, {adaptor.getInputs().front()});
43   return success();
44 }
45 
46 namespace {
47 template <typename SrcOpTy, typename DstOpTy>
48 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
49 public:
50   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
51 
52   LogicalResult
53   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
54                   ConversionPatternRewriter &rewriter) const override {
55     // For now, only error-free types are supported by this lowering.
56     if (op.getType().template isa<SizeType>())
57       return failure();
58 
59     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
60                                          adaptor.getRhs());
61     return success();
62   }
63 };
64 } // namespace
65 
66 namespace {
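/// Converts `shape.broadcast` on extent tensors into a `tensor.generate` op
/// whose body computes each output extent with the `getBroadcastedDim` helper
/// below. An illustrative sketch for two operands (SSA names and op order
/// simplified):
///
///   %c0 = arith.constant 0 : index
///   %rank0 = tensor.dim %a, %c0 : tensor<?xindex>
///   %rank1 = tensor.dim %b, %c0 : tensor<?xindex>
///   %ugt = arith.cmpi ugt, %rank1, %rank0 : index
///   %max_rank = arith.select %ugt, %rank1, %rank0 : index
///   %diff0 = arith.subi %max_rank, %rank0 : index
///   %diff1 = arith.subi %max_rank, %rank1 : index
///   %result = tensor.generate %max_rank {
///   ^bb0(%dim: index):
///     // Per operand: keep the running extent if %dim is out of bounds for
///     // it; otherwise combine the extents as in `getBroadcastedDim`.
///     ...
///     tensor.yield %extent : index
///   } : tensor<?xindex>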
67 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
68   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override;
73 };
74 
// Get the resulting extent for a given output dimension. This is computed
// from any number of extent tensors, each paired with its rank difference
// (offset) relative to the output rank.
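// For example, broadcasting the extent tensors [5, 1] and [2, 1, 3] (ranks 2
// and 3, hence rank differences 1 and 0) yields the extents 2, 5 and 3 for
// output dimensions 0, 1 and 2, respectively.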
77 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
78                         ValueRange rankDiffs, Value outputDimension) {
79   Value one = lb.create<arith::ConstantIndexOp>(1);
80   Value broadcastedDim = one;
81   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
82     Value shape = std::get<0>(tup);
83     Value rankDiff = std::get<1>(tup);
84     Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
85                                                  outputDimension, rankDiff);
86     Type indexTy = lb.getIndexType();
87     broadcastedDim =
88         lb.create<IfOp>(
89               TypeRange{indexTy}, outOfBounds,
90               [&](OpBuilder &b, Location loc) {
91                 b.create<scf::YieldOp>(loc, broadcastedDim);
92               },
93               [&](OpBuilder &b, Location loc) {
94                 // The broadcasting logic is:
95                 // - if one extent (here we arbitrarily choose the
96                 // extent from the greater-rank operand) is equal to 1,
97                 // then take the extent from the other operand
98                 // - otherwise, take the extent as-is.
99                 // Note that this logic remains correct in the presence
100                 // of dimensions of zero extent.
101                 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
102                     loc, indexTy, outputDimension, rankDiff);
103                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
104                     loc, shape, ValueRange{lesserRankOperandDimension});
105 
106                 Value dimIsOne =
107                     b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
108                                             lesserRankOperandExtent, one);
109                 Value dim = b.create<arith::SelectOp>(
110                     loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
111                 b.create<scf::YieldOp>(loc, dim);
112               })
113             .getResult(0);
114   }
115   return broadcastedDim;
116 }
117 } // namespace
118 
119 LogicalResult BroadcastOpConverter::matchAndRewrite(
120     BroadcastOp op, OpAdaptor adaptor,
121     ConversionPatternRewriter &rewriter) const {
122   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
123   // on shapes.
124   if (op.getType().isa<ShapeType>())
125     return failure();
126 
127   auto loc = op.getLoc();
128   ImplicitLocOpBuilder lb(loc, rewriter);
129 
130   Value zero = lb.create<arith::ConstantIndexOp>(0);
131   Type indexTy = lb.getIndexType();
132 
133   // Save all the ranks for bounds checking. Because this is a tensor
134   // representing the shape extents, the rank is the extent of the only
135   // dimension in the tensor.
136   SmallVector<Value> ranks, rankDiffs;
137   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
138                        return lb.create<tensor::DimOp>(v, zero);
139                      }));
140 
  // Find the maximum rank.
142   Value maxRank = ranks.front();
143   for (Value v : llvm::drop_begin(ranks, 1)) {
144     Value rankIsGreater =
145         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
146     maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
147   }
148 
  // Calculate the difference between each rank and the maximum rank; these
  // differences serve as offsets into the extent tensors later.
150   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
151                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
152                      }));
153 
154   Value replacement = lb.create<tensor::GenerateOp>(
155       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
156       [&](OpBuilder &b, Location loc, ValueRange args) {
157         Value broadcastedDim =
158             getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
159                               rankDiffs, args[0]);
160 
161         b.create<tensor::YieldOp>(loc, broadcastedDim);
162       });
163   if (replacement.getType() != op.getType())
164     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
165   rewriter.replaceOp(op, replacement);
166   return success();
167 }
168 
169 namespace {
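/// Converts `shape.const_shape` to an index-typed constant extent tensor.
/// A sketch for an all-static shape (SSA names illustrative):
///
///   %0 = shape.const_shape [1, 2, 3] : tensor<3xindex>
///
/// becomes
///
///   %c1 = arith.constant 1 : index
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = tensor.from_elements %c1, %c2, %c3 : tensor<3xindex>
///
/// The pattern additionally inserts a (trivial) `tensor.cast` of the result,
/// as in the code below.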
170 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
171 public:
172   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
173 
174   LogicalResult
175   matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
176                   ConversionPatternRewriter &rewriter) const override;
177 };
178 } // namespace
179 
180 LogicalResult ConstShapeOpConverter::matchAndRewrite(
181     ConstShapeOp op, OpAdaptor adaptor,
182     ConversionPatternRewriter &rewriter) const {
183 
184   // For now, this lowering supports only extent tensors, not `shape.shape`
185   // types.
186   if (op.getType().isa<ShapeType>())
187     return failure();
188 
189   auto loc = op.getLoc();
190   SmallVector<Value, 4> extentOperands;
191   for (auto extent : op.getShape()) {
192     extentOperands.push_back(
193         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
194   }
195   Type resultTy =
196       RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
197   Value tensor =
198       rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
199   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
200   return success();
201 }
202 
203 namespace {
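/// Converts `shape.const_size` to an `arith.constant` of index type, e.g.
/// `shape.const_size 42` becomes `arith.constant 42 : index`.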
204 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
205 public:
206   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
207 
208   LogicalResult
209   matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
210                   ConversionPatternRewriter &rewriter) const override;
211 };
212 } // namespace
213 
214 LogicalResult ConstSizeOpConversion::matchAndRewrite(
215     ConstSizeOp op, OpAdaptor adaptor,
216     ConversionPatternRewriter &rewriter) const {
217   rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
218       op, op.getValue().getSExtValue());
219   return success();
220 }
221 
222 namespace {
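/// Lowers `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// that accumulates a single i1 result. A rough sketch of the generated IR
/// (SSA names illustrative, per-shape checks elided):
///
///   %true = arith.constant true
///   %ok = scf.for %i = %c0 to %max_rank step %c1
///       iter_args(%acc = %true) -> (i1) {
///     // %bcast = broadcasted extent for dimension %i (getBroadcastedDim).
///     // For every shape, the dimension must be out of bounds, 1, or equal
///     // to %bcast; the checks are AND-ed into %acc.
///     scf.yield %acc_next : i1
///   }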
223 struct IsBroadcastableOpConverter
224     : public OpConversionPattern<IsBroadcastableOp> {
225   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
226 
227   LogicalResult
228   matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
229                   ConversionPatternRewriter &rewriter) const override;
230 };
231 } // namespace
232 
233 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
234     IsBroadcastableOp op, OpAdaptor adaptor,
235     ConversionPatternRewriter &rewriter) const {
236   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
237   // on shapes.
238   if (!llvm::all_of(op.getShapes(),
239                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
240     return failure();
241 
242   auto loc = op.getLoc();
243   ImplicitLocOpBuilder lb(loc, rewriter);
244   Value zero = lb.create<arith::ConstantIndexOp>(0);
245   Value one = lb.create<arith::ConstantIndexOp>(1);
246   Type indexTy = lb.getIndexType();
247 
248   // Save all the ranks for bounds checking. Because this is a tensor
249   // representing the shape extents, the rank is the extent of the only
250   // dimension in the tensor.
251   SmallVector<Value> ranks, rankDiffs;
252   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
253                        return lb.create<tensor::DimOp>(v, zero);
254                      }));
255 
  // Find the maximum rank.
257   Value maxRank = ranks.front();
258   for (Value v : llvm::drop_begin(ranks, 1)) {
259     Value rankIsGreater =
260         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
261     maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
262   }
263 
  // Calculate the difference between each rank and the maximum rank; these
  // differences serve as offsets into the extent tensors later.
265   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
266                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
267                      }));
268 
269   Type i1Ty = rewriter.getI1Type();
270   Value trueVal =
271       rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
272 
  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
275       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 extent for this dimension, if one exists. Note that
        // the first part of this could reuse the broadcast lowering entirely,
        // but the work is redone here to make it easier to optimize across
        // the two loops.
279         Value broadcastedDim = getBroadcastedDim(
280             ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
281 
282         Value broadcastable = iterArgs[0];
283         for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
284           Value shape, rankDiff;
285           std::tie(shape, rankDiff) = tup;
286           Value outOfBounds = b.create<arith::CmpIOp>(
287               loc, arith::CmpIPredicate::ult, iv, rankDiff);
288           broadcastable =
289               b.create<IfOp>(
290                    loc, TypeRange{i1Ty}, outOfBounds,
291                    [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
293                      b.create<scf::YieldOp>(loc, broadcastable);
294                    },
295                    [&](OpBuilder &b, Location loc) {
                     // To be broadcastable in this dimension, every extent
                     // must be either 1 or equal to the broadcasted extent.
298                      Value operandDimension =
299                          b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
300                      Value dimensionExtent = b.create<tensor::ExtractOp>(
301                          loc, shape, ValueRange{operandDimension});
302 
303                      Value equalOne = b.create<arith::CmpIOp>(
304                          loc, arith::CmpIPredicate::eq, dimensionExtent, one);
305                      Value equalBroadcasted = b.create<arith::CmpIOp>(
306                          loc, arith::CmpIPredicate::eq, dimensionExtent,
307                          broadcastedDim);
308                      Value result = b.create<arith::AndIOp>(
309                          loc, broadcastable,
310                          b.create<arith::OrIOp>(loc, equalOne,
311                                                 equalBroadcasted));
312                      b.create<scf::YieldOp>(loc, result);
313                    })
314                   .getResult(0);
315         }
316 
317         b.create<scf::YieldOp>(loc, broadcastable);
318       });
319 
320   rewriter.replaceOp(op, reduceResult.getResults().front());
321   return success();
322 }
323 
324 namespace {
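/// Converts `shape.get_extent` to either a `tensor.dim` on the original
/// tensor (when the shape operand is produced by `shape.shape_of`) or to a
/// `tensor.extract` from the extent tensor. Sketch of the two possible
/// results (names illustrative):
///
///   %extent = tensor.dim %arg, %idx : tensor<?x?xf32>
/// or
///   %extent = tensor.extract %shape[%idx] : tensor<?xindex>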
325 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
326   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
327 
328   LogicalResult
329   matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
330                   ConversionPatternRewriter &rewriter) const override;
331 };
332 } // namespace
333 
334 LogicalResult GetExtentOpConverter::matchAndRewrite(
335     GetExtentOp op, OpAdaptor adaptor,
336     ConversionPatternRewriter &rewriter) const {
337   // For now, only error-free types are supported by this lowering.
338   if (op.getType().isa<SizeType>())
339     return failure();
340 
  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape as a tensor in memory.
343   if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
344     if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
345       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
346                                                  adaptor.getDim());
347       return success();
348     }
349   }
350 
351   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
352                                                  adaptor.getShape(),
353                                                  ValueRange{adaptor.getDim()});
354   return success();
355 }
356 
357 namespace {
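/// Converts `shape.rank` on an extent tensor to the extent of its only
/// dimension, i.e. `tensor.dim %shape, %c0 : tensor<?xindex>`.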
358 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
359 public:
360   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
361 
362   LogicalResult
363   matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
364                   ConversionPatternRewriter &rewriter) const override;
365 };
366 } // namespace
367 
368 LogicalResult
369 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
370                                  ConversionPatternRewriter &rewriter) const {
371   // For now, this lowering supports only error-free types.
372   if (op.getType().isa<SizeType>())
373     return failure();
374 
375   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
376   return success();
377 }
378 
379 namespace {
380 /// Converts `shape.reduce` to `scf.for`.
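/// A sketch of the generated IR, assuming a single index-typed accumulator
/// (SSA names illustrative):
///
///   %c0 = arith.constant 0 : index
///   %c1 = arith.constant 1 : index
///   %rank = tensor.dim %shape, %c0 : tensor<?xindex>
///   %result = scf.for %i = %c0 to %rank step %c1
///       iter_args(%acc = %init) -> (index) {
///     %extent = tensor.extract %shape[%i] : tensor<?xindex>
///     // ... cloned `shape.reduce` body, mapping (%i, %extent, %acc) ...
///     scf.yield %next : index
///   }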
381 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
382 public:
383   using OpConversionPattern::OpConversionPattern;
384 
385   LogicalResult
386   matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
387                   ConversionPatternRewriter &rewriter) const final;
388 };
389 } // namespace
390 
391 LogicalResult
392 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
393                                    ConversionPatternRewriter &rewriter) const {
394   // For now, this lowering is only defined on `tensor<?xindex>` operands.
395   if (op.getShape().getType().isa<ShapeType>())
396     return failure();
397 
398   auto loc = op.getLoc();
399 
400   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
401   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
402   Type indexTy = rewriter.getIndexType();
403   Value rank =
404       rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
405 
406   auto loop = rewriter.create<scf::ForOp>(
407       loc, zero, rank, one, op.getInitVals(),
408       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
409         Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
410 
411         SmallVector<Value, 2> mappedValues{iv, extent};
412         mappedValues.append(args.begin(), args.end());
413 
414         BlockAndValueMapping mapping;
415         Block *reduceBody = op.getBody();
416         mapping.map(reduceBody->getArguments(), mappedValues);
417         for (auto &nested : reduceBody->without_terminator())
418           b.clone(nested, mapping);
419 
420         SmallVector<Value, 2> mappedResults;
421         for (auto result : reduceBody->getTerminator()->getOperands())
422           mappedResults.push_back(mapping.lookup(result));
423         b.create<scf::YieldOp>(loc, mappedResults);
424       });
425 
426   rewriter.replaceOp(op, loop.getResults());
427   return success();
428 }
429 
430 namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The equality test first
/// compares the ranks and, if they match, checks every extent for equality.
434 ///
435 /// Example:
436 ///
437 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
438 ///
439 /// becomes
440 ///
441 /// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi eq, %0, %1 : index
445 /// %result = scf.if %2 -> (i1) {
446 ///   %c1 = arith.constant 1 : index
447 ///   %true = arith.constant true
448 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
449 ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
450 ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi eq, %5, %6 : index
452 ///     %8 = arith.andi %arg3, %7 : i1
453 ///     scf.yield %8 : i1
454 ///   }
455 ///   scf.yield %4 : i1
456 /// } else {
457 ///   %false = arith.constant false
458 ///   scf.yield %false : i1
459 /// }
460 ///
461 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
462   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
463 
464   LogicalResult
465   matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
466                   ConversionPatternRewriter &rewriter) const override;
467 };
468 } // namespace
469 
470 LogicalResult
471 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
472                                     ConversionPatternRewriter &rewriter) const {
473   if (!llvm::all_of(op.getShapes(),
474                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
475     return failure();
476 
477   Type i1Ty = rewriter.getI1Type();
478   if (op.getShapes().size() <= 1) {
479     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
480                                                    rewriter.getBoolAttr(true));
481     return success();
482   }
483 
484   auto loc = op.getLoc();
485   Type indexTy = rewriter.getIndexType();
486   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
487   Value firstShape = adaptor.getShapes().front();
488   Value firstRank =
489       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
490   Value result = nullptr;
491   // Generate a linear sequence of compares, all with firstShape as lhs.
492   for (Value shape : adaptor.getShapes().drop_front(1)) {
493     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
494     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
495                                                   firstRank, rank);
496     auto same = rewriter.create<IfOp>(
497         loc, i1Ty, eqRank,
498         [&](OpBuilder &b, Location loc) {
499           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
500           Value init =
501               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
502           auto loop = b.create<scf::ForOp>(
503               loc, zero, firstRank, one, ValueRange{init},
504               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
505                 Value conj = args[0];
506                 Value lhsExtent =
507                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
508                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
509                 Value eqExtent = b.create<arith::CmpIOp>(
510                     loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
511                 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
512                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
513               });
514           b.create<scf::YieldOp>(loc, loop.getResults());
515         },
516         [&](OpBuilder &b, Location loc) {
517           Value result =
518               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
519           b.create<scf::YieldOp>(loc, result);
520         });
521     result = !result ? same.getResult(0)
522                      : rewriter.create<arith::AndIOp>(loc, result,
523                                                       same.getResult(0));
524   }
525   rewriter.replaceOp(op, result);
526   return success();
527 }
528 
529 namespace {
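/// Converts `shape.shape_of` to `tensor.from_elements` for ranked operands
/// and to `tensor.generate` for unranked ones. A sketch for a ranked operand
/// (SSA names illustrative):
///
///   %shape = shape.shape_of %arg : tensor<?x3xf32> -> tensor<2xindex>
///
/// becomes
///
///   %c0 = arith.constant 0 : index
///   %d0 = tensor.dim %arg, %c0 : tensor<?x3xf32>
///   %c3 = arith.constant 3 : index
///   %shape = tensor.from_elements %d0, %c3 : tensor<2xindex>
///
/// followed by a `tensor.cast` to the op's result type.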
530 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
531 public:
532   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
533 
534   LogicalResult
535   matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
536                   ConversionPatternRewriter &rewriter) const override;
537 };
538 } // namespace
539 
540 LogicalResult ShapeOfOpConversion::matchAndRewrite(
541     ShapeOfOp op, OpAdaptor adaptor,
542     ConversionPatternRewriter &rewriter) const {
543 
544   // For now, only error-free types are supported by this lowering.
545   if (op.getType().isa<ShapeType>())
546     return failure();
547 
548   // For ranked tensor arguments, lower to `tensor.from_elements`.
549   auto loc = op.getLoc();
550   Value tensor = adaptor.getArg();
551   Type tensorTy = tensor.getType();
552   if (tensorTy.isa<RankedTensorType>()) {
553 
554     // Build values for individual extents.
555     SmallVector<Value, 8> extentValues;
556     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
557     int64_t rank = rankedTensorTy.getRank();
558     for (int64_t i = 0; i < rank; i++) {
559       if (rankedTensorTy.isDynamicDim(i)) {
560         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
561         extentValues.push_back(extent);
562       } else {
563         Value extent = rewriter.create<arith::ConstantIndexOp>(
564             loc, rankedTensorTy.getDimSize(i));
565         extentValues.push_back(extent);
566       }
567     }
568 
569     // Materialize extent tensor.
570     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
571         loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
572         extentValues);
573     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
574                                                 staticExtentTensor);
575     return success();
576   }
577 
578   // Lower to `tensor.generate` otherwise.
579   auto *ctx = rewriter.getContext();
580   Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
581   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
582       op, getExtentTensorType(ctx), ValueRange{rank},
583       [&](OpBuilder &b, Location loc, ValueRange args) {
584         Value dim = args.front();
585         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
586         b.create<tensor::YieldOp>(loc, extent);
587       });
588 
589   return success();
590 }
591 
592 namespace {
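/// Converts `shape.split_at` on extent tensors to two `tensor.extract_slice`
/// ops. A sketch assuming a non-negative split index %idx (SSA names
/// illustrative):
///
///   %head = tensor.extract_slice %shape[%c0] [%idx] [%c1]
///       : tensor<?xindex> to tensor<?xindex>
///   %tail_size = arith.subi %rank, %idx : index
///   %tail = tensor.extract_slice %shape[%idx] [%tail_size] [%c1]
///       : tensor<?xindex> to tensor<?xindex>
///
/// A negative index is first normalized by adding the rank.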
593 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
594 public:
595   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
596 
597   LogicalResult
598   matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
599                   ConversionPatternRewriter &rewriter) const override;
600 };
601 } // namespace
602 
603 LogicalResult SplitAtOpConversion::matchAndRewrite(
604     SplitAtOp op, OpAdaptor adaptor,
605     ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
608   if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
609                    [](Value v) { return v.getType().isa<ShapeType>(); }))
610     return failure();
611 
612   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
613   Value zero = b.create<arith::ConstantIndexOp>(0);
614   Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
615 
616   // index < 0 ? index + rank : index
617   Value originalIndex = adaptor.getIndex();
618   Value add = b.create<arith::AddIOp>(originalIndex, rank);
619   Value indexIsNegative =
620       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
621   Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
622 
623   Value one = b.create<arith::ConstantIndexOp>(1);
624   Value head =
625       b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
626   Value tailSize = b.create<arith::SubIOp>(rank, index);
627   Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
628                                                 tailSize, one);
629   rewriter.replaceOp(op, {head, tail});
630   return success();
631 }
632 
633 namespace {
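/// Converts `shape.to_extent_tensor` to a `tensor.cast` when the input is
/// already materialized as a ranked tensor.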
634 class ToExtentTensorOpConversion
635     : public OpConversionPattern<ToExtentTensorOp> {
636 public:
637   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
638 
639   LogicalResult
640   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
641                   ConversionPatternRewriter &rewriter) const override {
642     if (!adaptor.getInput().getType().isa<RankedTensorType>())
643       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
644 
645     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
646                                                 adaptor.getInput());
647     return success();
648   }
649 };
650 } // namespace
651 
652 namespace {
/// Import the generated Shape-to-Standard conversion patterns.
654 #include "ShapeToStandard.cpp.inc"
655 } // namespace
656 
657 namespace {
658 /// Conversion pass.
659 class ConvertShapeToStandardPass
660     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
661 
662   void runOnOperation() override;
663 };
664 } // namespace
665 
666 void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
668   MLIRContext &ctx = getContext();
669   ConversionTarget target(ctx);
670   target.addLegalDialect<arith::ArithmeticDialect, SCFDialect,
671                          tensor::TensorDialect>();
672   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
673 
  // Set up conversion patterns.
675   RewritePatternSet patterns(&ctx);
676   populateShapeToStandardConversionPatterns(patterns);
677 
678   // Apply conversion.
679   auto module = getOperation();
680   if (failed(applyPartialConversion(module, target, std::move(patterns))))
681     signalPassFailure();
682 }
683 
684 void mlir::populateShapeToStandardConversionPatterns(
685     RewritePatternSet &patterns) {
686   // clang-format off
687   populateWithGenerated(patterns);
688   patterns.add<
689       AnyOpConversion,
690       BinaryOpConversion<AddOp, arith::AddIOp>,
691       BinaryOpConversion<MulOp, arith::MulIOp>,
692       BroadcastOpConverter,
693       ConstShapeOpConverter,
694       ConstSizeOpConversion,
695       IsBroadcastableOpConverter,
696       GetExtentOpConverter,
697       RankOpConverter,
698       ReduceOpConverter,
699       ShapeEqOpConverter,
700       ShapeOfOpConversion,
701       SplitAtOpConversion,
702       ToExtentTensorOpConversion>(patterns.getContext());
703   // clang-format on
704 }
705 
706 std::unique_ptr<OperationPass<ModuleOp>>
707 mlir::createConvertShapeToStandardPass() {
708   return std::make_unique<ConvertShapeToStandardPass>();
709 }
710