1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 using namespace mlir::scf;
25 
26 /// Conversion patterns.
27 namespace {
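/// Converts `shape.any` by forwarding its first operand; any of the operands
/// would be a valid substitution.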
28 class AnyOpConversion : public OpConversionPattern<AnyOp> {
29 public:
30   using OpConversionPattern<AnyOp>::OpConversionPattern;
31 
32   LogicalResult
33   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
34                   ConversionPatternRewriter &rewriter) const override;
35 };
36 } // namespace
37 
38 LogicalResult
39 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
40                                  ConversionPatternRewriter &rewriter) const {
41   AnyOp::Adaptor transformed(operands);
42 
43   // Replace `any` with its first operand.
44   // Any operand would be a valid substitution.
45   rewriter.replaceOp(op, {transformed.inputs().front()});
46   return success();
47 }
48 
49 namespace {
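/// Converts a binary op from the shape dialect (e.g. `shape.add`,
/// `shape.mul`) on `index` operands to the corresponding op from the standard
/// dialect.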
50 template <typename SrcOpTy, typename DstOpTy>
51 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
52 public:
53   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
54 
55   LogicalResult
56   matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
57                   ConversionPatternRewriter &rewriter) const override {
58     typename SrcOpTy::Adaptor transformed(operands);
59 
60     // For now, only error-free types are supported by this lowering.
61     if (op.getType().template isa<SizeType>())
62       return failure();
63 
64     rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
65                                          transformed.rhs());
66     return success();
67   }
68 };
69 } // namespace
70 
71 namespace {
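/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` that
/// computes the broadcasted extent for every dimension of the result.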
72 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
73   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
74 
75   LogicalResult
76   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
77                   ConversionPatternRewriter &rewriter) const override;
78 };
79 
// Gets the extent of the broadcasted result in a given output dimension. It
// is computed from an arbitrary number of extent tensors, using the
// per-operand rank differences as offsets into them.
82 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
83                         ValueRange rankDiffs, Value outputDimension) {
84   Value one = lb.create<ConstantIndexOp>(1);
85   Value broadcastedDim = one;
86   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
87     Value shape = std::get<0>(tup);
88     Value rankDiff = std::get<1>(tup);
89     Value outOfBounds =
90         lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
91     Type indexTy = lb.getIndexType();
92     broadcastedDim =
93         lb.create<IfOp>(
94               TypeRange{indexTy}, outOfBounds,
95               [&](OpBuilder &b, Location loc) {
96                 b.create<scf::YieldOp>(loc, broadcastedDim);
97               },
98               [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if this operand's extent in the current dimension is
                //   equal to 1, take the extent computed from the operands
                //   seen so far
                // - otherwise, take this operand's extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
106                 Value lesserRankOperandDimension =
107                     b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
108                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
109                     loc, shape, ValueRange{lesserRankOperandDimension});
110 
111                 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
112                                                   lesserRankOperandExtent, one);
113                 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
114                                                lesserRankOperandExtent);
115                 b.create<scf::YieldOp>(loc, dim);
116               })
117             .getResult(0);
118   }
119   return broadcastedDim;
120 }
121 } // namespace
122 
123 LogicalResult BroadcastOpConverter::matchAndRewrite(
124     BroadcastOp op, ArrayRef<Value> operands,
125     ConversionPatternRewriter &rewriter) const {
126   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
127   // on shapes.
128   if (op.getType().isa<ShapeType>())
129     return failure();
130 
131   auto loc = op.getLoc();
132   ImplicitLocOpBuilder lb(loc, rewriter);
133   BroadcastOp::Adaptor transformed(operands);
134 
135   Value zero = lb.create<ConstantIndexOp>(0);
136   Type indexTy = lb.getIndexType();
137 
138   // Save all the ranks for bounds checking. Because this is a tensor
139   // representing the shape extents, the rank is the extent of the only
140   // dimension in the tensor.
141   SmallVector<Value> ranks, rankDiffs;
142   llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
143                        return lb.create<memref::DimOp>(v, zero);
144                      }));
145 
  // Find the maximum rank.
147   Value maxRank = ranks.front();
148   for (Value v : llvm::drop_begin(ranks, 1)) {
149     Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
150     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
151   }
152 
  // Calculate the difference between each rank and the maximum rank; these
  // are used later as offsets into the extent tensors.
154   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
155                        return lb.create<SubIOp>(indexTy, maxRank, v);
156                      }));
157 
158   Value replacement = lb.create<tensor::GenerateOp>(
159       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
160       [&](OpBuilder &b, Location loc, ValueRange args) {
161         Value broadcastedDim =
162             getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
163                               transformed.shapes(), rankDiffs, args[0]);
164 
165         b.create<tensor::YieldOp>(loc, broadcastedDim);
166       });
167   if (replacement.getType() != op.getType())
168     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
169   rewriter.replaceOp(op, replacement);
170   return success();
171 }
172 
173 namespace {
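/// Converts `shape.const_shape` to a `tensor.from_elements` of constant
/// extents, followed by a `tensor.cast` to the dynamically sized extent
/// tensor type `tensor<?xindex>`.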
174 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
175 public:
176   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
177 
178   LogicalResult
179   matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
180                   ConversionPatternRewriter &rewriter) const override;
181 };
182 } // namespace
183 
184 LogicalResult ConstShapeOpConverter::matchAndRewrite(
185     ConstShapeOp op, ArrayRef<Value> operands,
186     ConversionPatternRewriter &rewriter) const {
187 
188   // For now, this lowering supports only extent tensors, not `shape.shape`
189   // types.
190   if (op.getType().isa<ShapeType>())
191     return failure();
192 
193   auto loc = op.getLoc();
194   SmallVector<Value, 4> extentOperands;
195   for (auto extent : op.shape()) {
196     extentOperands.push_back(
197         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
198   }
199   Type indexTy = rewriter.getIndexType();
200   Value tensor =
201       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
202   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
203   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
204   return success();
205 }
206 
207 namespace {
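/// Converts `shape.const_size` to a constant op of `index` type.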
208 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
209 public:
210   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
211 
212   LogicalResult
213   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
214                   ConversionPatternRewriter &rewriter) const override;
215 };
216 } // namespace
217 
218 LogicalResult ConstSizeOpConversion::matchAndRewrite(
219     ConstSizeOp op, ArrayRef<Value> operands,
220     ConversionPatternRewriter &rewriter) const {
221   rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
222   return success();
223 }
224 
225 namespace {
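/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// that checks, for every dimension of the broadcasted shape, that all
/// corresponding operand extents are either 1 or equal to the broadcasted
/// extent.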
226 struct IsBroadcastableOpConverter
227     : public OpConversionPattern<IsBroadcastableOp> {
228   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
229 
230   LogicalResult
231   matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
232                   ConversionPatternRewriter &rewriter) const override;
233 };
234 } // namespace
235 
236 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
237     IsBroadcastableOp op, ArrayRef<Value> operands,
238     ConversionPatternRewriter &rewriter) const {
239   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
240   // on shapes.
241   IsBroadcastableOp::Adaptor transformed(operands);
  if (llvm::any_of(op.shapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
244     return failure();
245 
246   auto loc = op.getLoc();
247   ImplicitLocOpBuilder lb(loc, rewriter);
248   Value zero = lb.create<ConstantIndexOp>(0);
249   Value one = lb.create<ConstantIndexOp>(1);
250   Type indexTy = lb.getIndexType();
251 
252   // Save all the ranks for bounds checking. Because this is a tensor
253   // representing the shape extents, the rank is the extent of the only
254   // dimension in the tensor.
255   SmallVector<Value> ranks, rankDiffs;
256   llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
257                        return lb.create<memref::DimOp>(v, zero);
258                      }));
259 
  // Find the maximum rank.
261   Value maxRank = ranks.front();
262   for (Value v : llvm::drop_begin(ranks, 1)) {
263     Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
264     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
265   }
266 
  // Calculate the difference between each rank and the maximum rank; these
  // are used later as offsets into the extent tensors.
268   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
269                        return lb.create<SubIOp>(indexTy, maxRank, v);
270                      }));
271 
272   Type i1Ty = rewriter.getI1Type();
273   Value trueVal =
274       rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
275 
276   auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
278       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
279         // Find a non-1 dim, if it exists. Note that the first part of this
280         // could reuse the Broadcast lowering entirely, but we redo the work
281         // here to make optimizations easier between the two loops.
282         Value broadcastedDim = getBroadcastedDim(
283             ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
284 
285         Value broadcastable = iterArgs[0];
286         for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
287           Value shape, rankDiff;
288           std::tie(shape, rankDiff) = tup;
289           Value outOfBounds =
290               b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
291           broadcastable =
292               b.create<IfOp>(
293                    loc, TypeRange{i1Ty}, outOfBounds,
294                    [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
296                      b.create<scf::YieldOp>(loc, broadcastable);
297                    },
298                    [&](OpBuilder &b, Location loc) {
299                      // Every value needs to be either 1, or the same non-1
300                      // value to be broadcastable in this dim.
301                      Value operandDimension =
302                          b.create<SubIOp>(loc, indexTy, iv, rankDiff);
303                      Value dimensionExtent = b.create<tensor::ExtractOp>(
304                          loc, shape, ValueRange{operandDimension});
305 
306                      Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
307                                                        dimensionExtent, one);
308                      Value equalBroadcasted =
309                          b.create<CmpIOp>(loc, CmpIPredicate::eq,
310                                           dimensionExtent, broadcastedDim);
311                      Value result = b.create<AndOp>(
312                          loc, broadcastable,
313                          b.create<OrOp>(loc, equalOne, equalBroadcasted));
314                      b.create<scf::YieldOp>(loc, result);
315                    })
316                   .getResult(0);
317         }
318 
319         b.create<scf::YieldOp>(loc, broadcastable);
320       });
321 
322   rewriter.replaceOp(op, reduceResult.results().front());
323   return success();
324 }
325 
326 namespace {
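/// Converts `shape.get_extent` to `tensor.extract`, or to a `memref::DimOp`
/// on the original tensor if the shape operand is produced by
/// `shape.shape_of`.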
327 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
328   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
329 
330   LogicalResult
331   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
332                   ConversionPatternRewriter &rewriter) const override;
333 };
334 } // namespace
335 
336 LogicalResult GetExtentOpConverter::matchAndRewrite(
337     GetExtentOp op, ArrayRef<Value> operands,
338     ConversionPatternRewriter &rewriter) const {
339   GetExtentOp::Adaptor transformed(operands);
340 
341   // For now, only error-free types are supported by this lowering.
342   if (op.getType().isa<SizeType>())
343     return failure();
344 
  // Derive the shape extent directly from the shape's origin where possible.
  // This avoids materializing the shape as an extent tensor first.
347   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
348     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
349       rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(),
350                                                  transformed.dim());
351       return success();
352     }
353   }
354 
355   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
356                                                  transformed.shape(),
357                                                  ValueRange{transformed.dim()});
358   return success();
359 }
360 
361 namespace {
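/// Converts `shape.rank` on an extent tensor to a `memref::DimOp` of its only
/// dimension.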
362 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
363 public:
364   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
365 
366   LogicalResult
367   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
368                   ConversionPatternRewriter &rewriter) const override;
369 };
370 } // namespace
371 
372 LogicalResult
373 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
374                                  ConversionPatternRewriter &rewriter) const {
375   // For now, this lowering supports only error-free types.
376   if (op.getType().isa<SizeType>())
377     return failure();
378 
379   shape::RankOp::Adaptor transformed(operands);
380   rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0);
381   return success();
382 }
383 
384 namespace {
385 /// Converts `shape.reduce` to `scf.for`.
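/// The reduction body is cloned into the loop body, with the induction
/// variable, the current extent, and the iteration arguments mapped to the
/// reduction body's block arguments.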
386 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
387 public:
388   using OpConversionPattern::OpConversionPattern;
389 
390   LogicalResult
391   matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
392                   ConversionPatternRewriter &rewriter) const final;
393 };
394 } // namespace
395 
396 LogicalResult
397 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
398                                    ConversionPatternRewriter &rewriter) const {
399   // For now, this lowering is only defined on `tensor<?xindex>` operands.
400   if (op.shape().getType().isa<ShapeType>())
401     return failure();
402 
403   auto loc = op.getLoc();
404   shape::ReduceOp::Adaptor transformed(operands);
405 
406   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
407   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
408   Type indexTy = rewriter.getIndexType();
409   Value rank =
410       rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);
411 
412   auto loop = rewriter.create<scf::ForOp>(
413       loc, zero, rank, one, op.initVals(),
414       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
415         Value extent =
416             b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);
417 
418         SmallVector<Value, 2> mappedValues{iv, extent};
419         mappedValues.append(args.begin(), args.end());
420 
421         BlockAndValueMapping mapping;
422         Block *reduceBody = op.getBody();
423         mapping.map(reduceBody->getArguments(), mappedValues);
424         for (auto &nested : reduceBody->without_terminator())
425           b.clone(nested, mapping);
426 
427         SmallVector<Value, 2> mappedResults;
428         for (auto result : reduceBody->getTerminator()->getOperands())
429           mappedResults.push_back(mapping.lookup(result));
430         b.create<scf::YieldOp>(loc, mappedResults);
431       });
432 
433   rewriter.replaceOp(op, loop.getResults());
434   return success();
435 }
436 
437 namespace {
438 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
439 /// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every extent for
/// equality.
441 ///
442 /// Example:
443 ///
444 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
445 ///
446 /// becomes
447 ///
448 /// %c0 = constant 0 : index
449 /// %0 = dim %arg0, %c0 : tensor<?xindex>
450 /// %1 = dim %arg1, %c0 : tensor<?xindex>
451 /// %2 = cmpi "eq", %0, %1 : index
452 /// %result = scf.if %2 -> (i1) {
453 ///   %c1 = constant 1 : index
454 ///   %true = constant true
455 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
456 ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
457 ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
458 ///     %7 = cmpi "eq", %5, %6 : index
459 ///     %8 = and %arg3, %7 : i1
460 ///     scf.yield %8 : i1
461 ///   }
462 ///   scf.yield %4 : i1
463 /// } else {
464 ///   %false = constant false
465 ///   scf.yield %false : i1
466 /// }
467 ///
468 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
469   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
470 
471   LogicalResult
472   matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
473                   ConversionPatternRewriter &rewriter) const override;
474 };
475 } // namespace
476 
477 LogicalResult
478 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
479                                     ConversionPatternRewriter &rewriter) const {
  if (llvm::any_of(op.shapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
482     return failure();
483 
484   Type i1Ty = rewriter.getI1Type();
485   if (op.shapes().size() <= 1) {
486     rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
487                                             rewriter.getBoolAttr(true));
488     return success();
489   }
490 
491   ShapeEqOp::Adaptor transformed(operands);
492   auto loc = op.getLoc();
493   Type indexTy = rewriter.getIndexType();
494   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
495   Value firstShape = transformed.shapes().front();
496   Value firstRank =
497       rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
498   Value result = nullptr;
499   // Generate a linear sequence of compares, all with firstShape as lhs.
500   for (Value shape : transformed.shapes().drop_front(1)) {
501     Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
502     Value eqRank =
503         rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
504     auto same = rewriter.create<IfOp>(
505         loc, i1Ty, eqRank,
506         [&](OpBuilder &b, Location loc) {
507           Value one = b.create<ConstantIndexOp>(loc, 1);
508           Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
509           auto loop = b.create<scf::ForOp>(
510               loc, zero, firstRank, one, ValueRange{init},
511               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
512                 Value conj = args[0];
513                 Value lhsExtent =
514                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
515                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
516                 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
517                                                   lhsExtent, rhsExtent);
518                 Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
519                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
520               });
521           b.create<scf::YieldOp>(loc, loop.getResults());
522         },
523         [&](OpBuilder &b, Location loc) {
524           Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
525           b.create<scf::YieldOp>(loc, result);
526         });
527     result = !result ? same.getResult(0)
528                      : rewriter.create<AndOp>(loc, result, same.getResult(0));
529   }
530   rewriter.replaceOp(op, result);
531   return success();
532 }
533 
534 namespace {
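/// Converts `shape.shape_of` to a `tensor.from_elements` of the individual
/// extents for ranked tensors, and to a `tensor.generate` that queries each
/// dimension individually for unranked tensors.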
535 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
536 public:
537   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
538 
539   LogicalResult
540   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
541                   ConversionPatternRewriter &rewriter) const override;
542 };
543 } // namespace
544 
545 LogicalResult ShapeOfOpConversion::matchAndRewrite(
546     ShapeOfOp op, ArrayRef<Value> operands,
547     ConversionPatternRewriter &rewriter) const {
548 
549   // For now, only error-free types are supported by this lowering.
550   if (op.getType().isa<ShapeType>())
551     return failure();
552 
553   // For ranked tensor arguments, lower to `tensor.from_elements`.
554   auto loc = op.getLoc();
555   ShapeOfOp::Adaptor transformed(operands);
556   Value tensor = transformed.arg();
557   Type tensorTy = tensor.getType();
558   if (tensorTy.isa<RankedTensorType>()) {
559 
560     // Build values for individual extents.
561     SmallVector<Value, 8> extentValues;
562     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
563     int64_t rank = rankedTensorTy.getRank();
564     for (int64_t i = 0; i < rank; i++) {
565       if (rankedTensorTy.isDynamicDim(i)) {
566         Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
567         extentValues.push_back(extent);
568       } else {
569         Value extent =
570             rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
571         extentValues.push_back(extent);
572       }
573     }
574 
575     // Materialize extent tensor.
576     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
577         loc, rewriter.getIndexType(), extentValues);
578     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
579                                                 staticExtentTensor);
580     return success();
581   }
582 
583   // Lower to `tensor.generate` otherwise.
584   auto *ctx = rewriter.getContext();
585   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
586   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
587       op, getExtentTensorType(ctx), ValueRange{rank},
588       [&](OpBuilder &b, Location loc, ValueRange args) {
589         Value dim = args.front();
590         Value extent = b.create<memref::DimOp>(loc, tensor, dim);
591         b.create<tensor::YieldOp>(loc, extent);
592       });
593 
594   return success();
595 }
596 
597 namespace {
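/// Converts `shape.split_at` on an extent tensor to two `SubTensorOp`s for
/// the head and the tail, normalizing a negative split index by adding the
/// operand's rank.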
598 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
599 public:
600   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
601 
602   LogicalResult
603   matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
604                   ConversionPatternRewriter &rewriter) const override;
605 };
606 } // namespace
607 
608 LogicalResult SplitAtOpConversion::matchAndRewrite(
609     SplitAtOp op, ArrayRef<Value> operands,
610     ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
613   if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
614                    [](Value v) { return v.getType().isa<ShapeType>(); }))
615     return failure();
616 
  SplitAtOp::Adaptor transformed(operands);
618   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
619   Value zero = b.create<ConstantIndexOp>(0);
620   Value rank = b.create<memref::DimOp>(transformed.operand(), zero);
621 
622   // index < 0 ? index + rank : index
623   Value originalIndex = transformed.index();
624   Value add = b.create<AddIOp>(originalIndex, rank);
625   Value indexIsNegative =
626       b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
627   Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
628 
629   Value one = b.create<ConstantIndexOp>(1);
630   Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
631   Value tailSize = b.create<SubIOp>(rank, index);
632   Value tail =
633       b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
634   rewriter.replaceOp(op, {head, tail});
635   return success();
636 }
637 
638 namespace {
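/// Converts `shape.to_extent_tensor` to a `tensor.cast`, provided the input
/// is already a ranked tensor of extents.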
639 class ToExtentTensorOpConversion
640     : public OpConversionPattern<ToExtentTensorOp> {
641 public:
642   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
643 
644   LogicalResult
645   matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
646                   ConversionPatternRewriter &rewriter) const override {
647     ToExtentTensorOpAdaptor adaptor(operands);
648 
649     if (!adaptor.input().getType().isa<RankedTensorType>())
650       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
651 
652     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
653                                                 adaptor.input());
654     return success();
655   }
656 };
657 } // namespace
658 
659 namespace {
/// Import the generated Shape-to-Standard conversion patterns.
661 #include "ShapeToStandard.cpp.inc"
662 } // namespace
663 
664 namespace {
/// Conversion pass lowering the Shape dialect to the Standard dialect.
666 class ConvertShapeToStandardPass
667     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
668 
669   void runOnOperation() override;
670 };
671 } // namespace
672 
673 void ConvertShapeToStandardPass::runOnOperation() {
674   // Setup target legality.
675   MLIRContext &ctx = getContext();
676   ConversionTarget target(ctx);
677   target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
678                          tensor::TensorDialect>();
679   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
680 
681   // Setup conversion patterns.
682   RewritePatternSet patterns(&ctx);
683   populateShapeToStandardConversionPatterns(patterns);
684 
685   // Apply conversion.
686   auto module = getOperation();
687   if (failed(applyPartialConversion(module, target, std::move(patterns))))
688     signalPassFailure();
689 }
690 
691 void mlir::populateShapeToStandardConversionPatterns(
692     RewritePatternSet &patterns) {
693   // clang-format off
694   populateWithGenerated(patterns);
695   patterns.add<
696       AnyOpConversion,
697       BinaryOpConversion<AddOp, AddIOp>,
698       BinaryOpConversion<MulOp, MulIOp>,
699       BroadcastOpConverter,
700       ConstShapeOpConverter,
701       ConstSizeOpConversion,
702       IsBroadcastableOpConverter,
703       GetExtentOpConverter,
704       RankOpConverter,
705       ReduceOpConverter,
706       ShapeEqOpConverter,
707       ShapeOfOpConversion,
708       SplitAtOpConversion,
709       ToExtentTensorOpConversion>(patterns.getContext());
710   // clang-format on
711 }
712 
713 std::unique_ptr<OperationPass<ModuleOp>>
714 mlir::createConvertShapeToStandardPass() {
715   return std::make_unique<ConvertShapeToStandardPass>();
716 }
717