1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 using namespace mlir;
22 using namespace mlir::shape;
23 using namespace mlir::scf;
24 
25 /// Conversion patterns.
namespace {
/// Lowers `shape.any` by forwarding one of its operands. All operands of
/// `shape.any` are assumed interchangeable, so any one is a valid result.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
36 
37 LogicalResult
38 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
39                                  ConversionPatternRewriter &rewriter) const {
40   // Replace `any` with its first operand.
41   // Any operand would be a valid substitution.
42   rewriter.replaceOp(op, {adaptor.inputs().front()});
43   return success();
44 }
45 
namespace {
/// Lowers an error-free binary shape op `SrcOpTy` (e.g. `shape.add`,
/// `shape.mul`) on `index` operands to the corresponding standard arithmetic
/// op `DstOpTy`.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering. A
    // `!shape.size` result may carry an error and has no standard equivalent.
    if (op.getType().template isa<SizeType>())
      return failure();

    // Operands have already been converted to `index` values by the adaptor.
    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
    return success();
  }
};
} // namespace
64 
namespace {
/// Lowers `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body computes each output extent with the numpy-style broadcasting rule.
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
73 
// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
//
// `rankDiffs[i]` holds `maxRank - rank(extentTensors[i])`, so output
// dimension `d` corresponds to dimension `d - rankDiffs[i]` of operand `i`;
// dimensions with `d < rankDiffs[i]` do not exist in that operand.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<ConstantIndexOp>(1);
  // Running result; starts at 1, the neutral extent for broadcasting.
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    // Guard out-of-range accesses: if this operand has no extent at the
    // requested output dimension, it contributes nothing.
    Value outOfBounds =
        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // Dimension is absent in this operand; keep the result
                // accumulated so far.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                // extent from the greater-rank operand) is equal to 1,
                // then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension =
                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace
116 
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  // Generate one broadcasted extent per output dimension.
  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  // The generated tensor is `tensor<?xindex>`; insert a cast if the op's
  // declared result type is more specific (e.g. statically shaped).
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}
164 
namespace {
/// Lowers `shape.const_shape` to a `tensor.from_elements` of index constants,
/// cast to the dynamically shaped extent tensor type.
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
175 
176 LogicalResult ConstShapeOpConverter::matchAndRewrite(
177     ConstShapeOp op, OpAdaptor adaptor,
178     ConversionPatternRewriter &rewriter) const {
179 
180   // For now, this lowering supports only extent tensors, not `shape.shape`
181   // types.
182   if (op.getType().isa<ShapeType>())
183     return failure();
184 
185   auto loc = op.getLoc();
186   SmallVector<Value, 4> extentOperands;
187   for (auto extent : op.shape()) {
188     extentOperands.push_back(
189         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
190   }
191   Type indexTy = rewriter.getIndexType();
192   Value tensor =
193       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
194   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
195   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
196   return success();
197 }
198 
namespace {
/// Lowers `shape.const_size` to a `constant` op of `index` type.
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
209 
210 LogicalResult ConstSizeOpConversion::matchAndRewrite(
211     ConstSizeOp op, OpAdaptor adaptor,
212     ConversionPatternRewriter &rewriter) const {
213   rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
214   return success();
215 }
216 
namespace {
/// Lowers `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// that checks, per output dimension, that all present extents are either 1
/// or equal to the broadcasted extent.
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
227 
228 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
229     IsBroadcastableOp op, OpAdaptor adaptor,
230     ConversionPatternRewriter &rewriter) const {
231   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
232   // on shapes.
233   if (!llvm::all_of(op.shapes(),
234                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
235     return failure();
236 
237   auto loc = op.getLoc();
238   ImplicitLocOpBuilder lb(loc, rewriter);
239   Value zero = lb.create<ConstantIndexOp>(0);
240   Value one = lb.create<ConstantIndexOp>(1);
241   Type indexTy = lb.getIndexType();
242 
243   // Save all the ranks for bounds checking. Because this is a tensor
244   // representing the shape extents, the rank is the extent of the only
245   // dimension in the tensor.
246   SmallVector<Value> ranks, rankDiffs;
247   llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
248                        return lb.create<tensor::DimOp>(v, zero);
249                      }));
250 
251   // Find the maximum rank
252   Value maxRank = ranks.front();
253   for (Value v : llvm::drop_begin(ranks, 1)) {
254     Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
255     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
256   }
257 
258   // Calculate the difference of ranks and the maximum rank for later offsets.
259   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
260                        return lb.create<SubIOp>(indexTy, maxRank, v);
261                      }));
262 
263   Type i1Ty = rewriter.getI1Type();
264   Value trueVal =
265       rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
266 
267   auto reduceResult = lb.create<ForOp>(
268       loc, zero, maxRank, one, ValueRange{trueVal},
269       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
270         // Find a non-1 dim, if it exists. Note that the first part of this
271         // could reuse the Broadcast lowering entirely, but we redo the work
272         // here to make optimizations easier between the two loops.
273         Value broadcastedDim = getBroadcastedDim(
274             ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv);
275 
276         Value broadcastable = iterArgs[0];
277         for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) {
278           Value shape, rankDiff;
279           std::tie(shape, rankDiff) = tup;
280           Value outOfBounds =
281               b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
282           broadcastable =
283               b.create<IfOp>(
284                    loc, TypeRange{i1Ty}, outOfBounds,
285                    [&](OpBuilder &b, Location loc) {
286                      // Non existent dimensions are always broadcastable
287                      b.create<scf::YieldOp>(loc, broadcastable);
288                    },
289                    [&](OpBuilder &b, Location loc) {
290                      // Every value needs to be either 1, or the same non-1
291                      // value to be broadcastable in this dim.
292                      Value operandDimension =
293                          b.create<SubIOp>(loc, indexTy, iv, rankDiff);
294                      Value dimensionExtent = b.create<tensor::ExtractOp>(
295                          loc, shape, ValueRange{operandDimension});
296 
297                      Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
298                                                        dimensionExtent, one);
299                      Value equalBroadcasted =
300                          b.create<CmpIOp>(loc, CmpIPredicate::eq,
301                                           dimensionExtent, broadcastedDim);
302                      Value result = b.create<AndOp>(
303                          loc, broadcastable,
304                          b.create<OrOp>(loc, equalOne, equalBroadcasted));
305                      b.create<scf::YieldOp>(loc, result);
306                    })
307                   .getResult(0);
308         }
309 
310         b.create<scf::YieldOp>(loc, broadcastable);
311       });
312 
313   rewriter.replaceOp(op, reduceResult.results().front());
314   return success();
315 }
316 
namespace {
/// Lowers `shape.get_extent` either to a `tensor.dim` on the shape's tensor
/// of origin (when the shape comes straight from `shape.shape_of`) or to a
/// `tensor.extract` from the materialized extent tensor.
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  // Inherited constructors keep the base class's (public) access, so the
  // pattern remains constructible despite the private section.
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
326 
327 LogicalResult GetExtentOpConverter::matchAndRewrite(
328     GetExtentOp op, OpAdaptor adaptor,
329     ConversionPatternRewriter &rewriter) const {
330   // For now, only error-free types are supported by this lowering.
331   if (op.getType().isa<SizeType>())
332     return failure();
333 
334   // Derive shape extent directly from shape origin if possible. This
335   // circumvents the necessity to materialize the shape in memory.
336   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
337     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
338       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(),
339                                                  adaptor.dim());
340       return success();
341     }
342   }
343 
344   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
345       op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()});
346   return success();
347 }
348 
namespace {
/// Lowers `shape.rank` on an extent tensor to a `tensor.dim` of its single
/// dimension.
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
359 
360 LogicalResult
361 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
362                                  ConversionPatternRewriter &rewriter) const {
363   // For now, this lowering supports only error-free types.
364   if (op.getType().isa<SizeType>())
365     return failure();
366 
367   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.shape(), 0);
368   return success();
369 }
370 
namespace {
/// Converts `shape.reduce` to `scf.for`.
/// The reduction body is cloned into the loop with its block arguments
/// remapped to the induction variable, the current extent, and the
/// loop-carried accumulators.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace
382 
LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  // The iteration count is the rank, i.e. the extent of the shape tensor's
  // only dimension.
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.shape(), iv);

        // The reduce body's block arguments are (dim index, extent,
        // accumulators...); pair them with the loop-carried values.
        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone the body of `shape.reduce` into the loop, remapping block
        // arguments to the values computed above.
        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        // Yield the (remapped) operands of the original terminator as the
        // next iteration's loop-carried values.
        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}
421 
422 namespace {
423 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
424 /// only defined on `tensor<?xindex>` operands. The test for equality first
425 /// compares their size and, if equal, checks every extent for equality.
426 ///
427 /// Example:
428 ///
429 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
430 ///
431 /// becomes
432 ///
433 /// %c0 = constant 0 : index
434 /// %0 = dim %arg0, %c0 : tensor<?xindex>
435 /// %1 = dim %arg1, %c0 : tensor<?xindex>
436 /// %2 = cmpi "eq", %0, %1 : index
437 /// %result = scf.if %2 -> (i1) {
438 ///   %c1 = constant 1 : index
439 ///   %true = constant true
440 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
441 ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
442 ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
443 ///     %7 = cmpi "eq", %5, %6 : index
444 ///     %8 = and %arg3, %7 : i1
445 ///     scf.yield %8 : i1
446 ///   }
447 ///   scf.yield %4 : i1
448 /// } else {
449 ///   %false = constant false
450 ///   scf.yield %false : i1
451 /// }
452 ///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  // See the worked example above for the shape of the generated code.
  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
461 
LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  // This lowering is only defined on `tensor<?xindex>` operands, not on
  // `!shape.shape` values.
  if (!llvm::all_of(op.shapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  // Zero or one operand is trivially equal.
  Type i1Ty = rewriter.getI1Type();
  if (op.shapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
                                            rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.shapes().front();
  // Rank of the first operand, i.e. the extent of its single dimension.
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.shapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Equal ranks: compare the extents pairwise and reduce with `and`.
          Value one = b.create<ConstantIndexOp>(loc, 1);
          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                // NOTE(review): ops below are created at the outer `loc`
                // rather than `nestedLoc`; harmless for semantics, but
                // confirm the location choice is intentional.
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lhsExtent, rhsExtent);
                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Unequal ranks: the shapes cannot be equal.
          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}
517 
namespace {
/// Lowers `shape.shape_of` to `tensor.from_elements` for ranked inputs and
/// to `tensor.generate` over the dynamic rank otherwise.
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
528 
529 LogicalResult ShapeOfOpConversion::matchAndRewrite(
530     ShapeOfOp op, OpAdaptor adaptor,
531     ConversionPatternRewriter &rewriter) const {
532 
533   // For now, only error-free types are supported by this lowering.
534   if (op.getType().isa<ShapeType>())
535     return failure();
536 
537   // For ranked tensor arguments, lower to `tensor.from_elements`.
538   auto loc = op.getLoc();
539   Value tensor = adaptor.arg();
540   Type tensorTy = tensor.getType();
541   if (tensorTy.isa<RankedTensorType>()) {
542 
543     // Build values for individual extents.
544     SmallVector<Value, 8> extentValues;
545     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
546     int64_t rank = rankedTensorTy.getRank();
547     for (int64_t i = 0; i < rank; i++) {
548       if (rankedTensorTy.isDynamicDim(i)) {
549         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
550         extentValues.push_back(extent);
551       } else {
552         Value extent =
553             rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
554         extentValues.push_back(extent);
555       }
556     }
557 
558     // Materialize extent tensor.
559     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
560         loc, rewriter.getIndexType(), extentValues);
561     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
562                                                 staticExtentTensor);
563     return success();
564   }
565 
566   // Lower to `tensor.generate` otherwise.
567   auto *ctx = rewriter.getContext();
568   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
569   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
570       op, getExtentTensorType(ctx), ValueRange{rank},
571       [&](OpBuilder &b, Location loc, ValueRange args) {
572         Value dim = args.front();
573         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
574         b.create<tensor::YieldOp>(loc, extent);
575       });
576 
577   return success();
578 }
579 
namespace {
/// Lowers `shape.split_at` on extent tensors to a pair of
/// `tensor.extract_slice` ops around the (possibly negative) split index.
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
590 
591 LogicalResult SplitAtOpConversion::matchAndRewrite(
592     SplitAtOp op, OpAdaptor adaptor,
593     ConversionPatternRewriter &rewriter) const {
594   // Error conditions are not implemented, only lower if all operands and
595   // results are extent tensors.
596   if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
597                    [](Value v) { return v.getType().isa<ShapeType>(); }))
598     return failure();
599 
600   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
601   Value zero = b.create<ConstantIndexOp>(0);
602   Value rank = b.create<tensor::DimOp>(adaptor.operand(), zero);
603 
604   // index < 0 ? index + rank : index
605   Value originalIndex = adaptor.index();
606   Value add = b.create<AddIOp>(originalIndex, rank);
607   Value indexIsNegative =
608       b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
609   Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
610 
611   Value one = b.create<ConstantIndexOp>(1);
612   Value head =
613       b.create<tensor::ExtractSliceOp>(adaptor.operand(), zero, index, one);
614   Value tailSize = b.create<SubIOp>(rank, index);
615   Value tail =
616       b.create<tensor::ExtractSliceOp>(adaptor.operand(), index, tailSize, one);
617   rewriter.replaceOp(op, {head, tail});
618   return success();
619 }
620 
namespace {
/// Lowers `shape.to_extent_tensor` to a `tensor.cast` once the input has been
/// converted to a ranked tensor; otherwise the match fails.
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ranked tensor inputs can be cast to the extent tensor type.
    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.input());
    return success();
  }
};
} // namespace
639 
640 namespace {
641 /// Import the Shape Ops to Std Patterns.
642 #include "ShapeToStandard.cpp.inc"
643 } // namespace
644 
namespace {
/// Conversion pass.
/// Lowers the shape dialect to the standard, SCF, and tensor dialects using
/// the patterns registered in `populateShapeToStandardConversionPatterns`.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace
653 
654 void ConvertShapeToStandardPass::runOnOperation() {
655   // Setup target legality.
656   MLIRContext &ctx = getContext();
657   ConversionTarget target(ctx);
658   target
659       .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
660   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
661 
662   // Setup conversion patterns.
663   RewritePatternSet patterns(&ctx);
664   populateShapeToStandardConversionPatterns(patterns);
665 
666   // Apply conversion.
667   auto module = getOperation();
668   if (failed(applyPartialConversion(module, target, std::move(patterns))))
669     signalPassFailure();
670 }
671 
/// Registers all Shape-to-Standard lowering patterns, including the
/// declaratively generated ones from ShapeToStandard.cpp.inc.
void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}
693 
/// Creates an instance of the Shape-to-Standard lowering pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}
698