//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
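/// Converts a binary shape op to the corresponding standard arithmetic op on
/// `index` values. As a rough sketch (SSA names illustrative, assuming both
/// operands are `index` rather than `!shape.size`), `%r = shape.add %a, %b`
/// lowers to `%r = addi %a, %b : index`, and `shape.mul` lowers to `muli`
/// analogously.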
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

// Computes the broadcasted extent for a given output dimension. This works
// for any number of extent tensors, using the per-operand rank differences
// as shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds =
        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if the extent of the current operand in this dimension
                //   is equal to 1, keep the extent accumulated from the
                //   operands processed so far,
                // - otherwise, take the current operand's extent as-is.
                // Note that this logic remains correct in the presence of
                // dimensions of zero extent.
                Value lesserRankOperandDimension =
                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

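// A rough sketch of the generated code (SSA names illustrative): for
//
//   %bcast = shape.broadcast %a, %b
//       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
//
// the lowering materializes the rank of every operand, folds them into the
// maximum rank and per-operand rank differences, and emits a
// `tensor.generate` of that size whose body yields the result of
// `getBroadcastedDim` for each output dimension.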
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  BroadcastOp::Adaptor transformed(operands);

  Value zero = lb.create<ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because each shape is a 1-D
  // tensor of extents, its rank is the extent of the tensor's only dimension.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of each rank from the maximum rank; these are
  // used as offsets into the extent tensors later.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  rewriter.replaceOp(
      op, lb.create<tensor::GenerateOp>(
                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value broadcastedDim = getBroadcastedDim(
                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
                      rankDiffs, args[0]);

                  b.create<tensor::YieldOp>(loc, broadcastedDim);
                })
              ->getResults());
  return success();
}

namespace {
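/// Converts `shape.const_shape` to a materialized extent tensor. A rough
/// sketch (SSA names illustrative): `shape.const_shape [1, 2]` becomes
///
/// %c1 = constant 1 : index
/// %c2 = constant 2 : index
/// %0 = tensor.from_elements %c1, %c2 : tensor<2xindex>
/// %result = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>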
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  IsBroadcastableOp::Adaptor transformed(operands);
  if (llvm::any_of(op.shapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<ConstantIndexOp>(0);
  Value one = lb.create<ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because each shape is a 1-D
  // tensor of extents, its rank is the extent of the tensor's only dimension.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of each rank from the maximum rank; these are
  // used as offsets into the extent tensors later.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 extent for this dimension, if one exists. Note that
        // the first part of this could reuse the broadcast lowering entirely,
        // but we redo the work here to make optimizations between the two
        // loops easier.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds =
              b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // To be broadcastable in this dimension, every extent
                     // must be either 1 or equal to the broadcasted extent.
                     Value operandDimension =
                         b.create<SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                       dimensionExtent, one);
                     Value equalBroadcasted =
                         b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                          dimensionExtent, broadcastedDim);
                     Value result = b.create<AndOp>(
                         loc, broadcastable,
                         b.create<OrOp>(loc, equalOne, equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
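/// Lowers `shape.get_extent` to either a `dim` on the original tensor or a
/// `tensor.extract` from the extent tensor. When the shape comes from
/// `shape.shape_of` on a ranked tensor, the extent is read directly from the
/// tensor; schematically, `shape.get_extent (shape.shape_of %t), %dim`
/// becomes `dim %t, %dim`.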
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This circumvents the need to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 transformed.shape(),
                                                 ValueRange{transformed.dim()});
  return success();
}

namespace {
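/// Lowers `shape.rank` on an extent tensor to a `dim` on its single
/// dimension; schematically, `shape.rank %shape` on a `tensor<?xindex>`
/// becomes `dim %shape, %c0`.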
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
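/// Schematically (SSA names illustrative), the `shape.reduce` body is cloned
/// into a loop of the form
///
/// %result = scf.for %i = %c0 to %rank step %c1 iter_args(%acc = %init) {
///   %extent = tensor.extract %shape[%i] : tensor<?xindex>
///   ... cloned reduction body on (%i, %extent, %acc) ...
///   scf.yield %next : index
/// }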
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every extent for
/// equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (llvm::any_of(op.shapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.shapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
                                            rewriter.getBoolAttr(true));
    return success();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value firstShape = transformed.shapes().front();
  Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as the lhs.
  for (Value shape : transformed.shapes().drop_front(1)) {
    Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
    Value eqRank =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<ConstantIndexOp>(loc, 1);
          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lhsExtent, rhsExtent);
                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
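/// Lowers `shape.shape_of` to an extent tensor. For ranked operands the
/// extents are materialized eagerly; a rough sketch (SSA names illustrative):
///
/// %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
///
/// becomes
///
/// %c2 = constant 2 : index
/// %c1 = constant 1 : index
/// %d1 = dim %arg, %c1 : tensor<2x?xf32>
/// %0 = tensor.from_elements %c2, %d1 : tensor<2xindex>
/// %shape = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
///
/// Unranked operands lower to a `tensor.generate` over the operand's `rank`.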
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
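/// Lowers `shape.split_at` to a pair of `subtensor` ops. The split index may
/// be negative, in which case it is interpreted relative to the rank; a rough
/// sketch of the normalization and slicing (SSA names illustrative):
///
/// %index = select %isNeg, %indexPlusRank, %origIndex : index
/// %head = subtensor %shape[0] [%index] [1]
///     : tensor<?xindex> to tensor<?xindex>
/// %tail = subtensor %shape[%index] [%tailSize] [1]
///     : tensor<?xindex> to tensor<?xindex>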
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  // Use the converted operands, as the other patterns in this file do.
  SplitAtOp::Adaptor transformed(operands);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<ConstantIndexOp>(0);
  Value rank = b.create<DimOp>(transformed.operand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = transformed.index();
  Value add = b.create<AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<ConstantIndexOp>(1);
  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
  Value tailSize = b.create<SubIOp>(rank, index);
  Value tail =
      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
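/// Lowers `shape.to_extent_tensor` on an already-materialized ranked tensor
/// to a plain `tensor.cast` into the result type. Inputs that are not ranked
/// tensors are rejected and left to other patterns.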
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.input());
    return success();
  }
};
} // namespace

namespace {
/// Import the generated Shape-to-Std conversion patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up the target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target
      .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up the conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply the conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  populateWithGenerated(ctx, patterns);
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}