1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 using namespace mlir::scf;
25 
26 /// Conversion patterns.
27 namespace {
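/// Lowers `shape.any` by forwarding one of its operands; any operand is a
/// valid result. Illustrative sketch (schematic, not verified output):
///
/// %result = shape.any %a, %b : tensor<?xindex>
///
/// is replaced by a direct use of %a.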
28 class AnyOpConversion : public OpConversionPattern<AnyOp> {
29 public:
30   using OpConversionPattern<AnyOp>::OpConversionPattern;
31 
32   LogicalResult
33   matchAndRewrite(AnyOp op, OpAdaptor adaptor,
34                   ConversionPatternRewriter &rewriter) const override;
35 };
36 } // namespace
37 
38 LogicalResult
39 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
40                                  ConversionPatternRewriter &rewriter) const {
41   // Replace `any` with its first operand.
42   // Any operand would be a valid substitution.
43   rewriter.replaceOp(op, {adaptor.getInputs().front()});
44   return success();
45 }
46 
47 namespace {
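/// Lowers error-free binary shape ops on `index` operands to their arith
/// counterparts. Illustrative sketch (schematic, not verified output):
///
/// %sum = shape.add %a, %b : index, index -> index
///
/// becomes
///
/// %sum = arith.addi %a, %b : index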
48 template <typename SrcOpTy, typename DstOpTy>
49 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
50 public:
51   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
52 
53   LogicalResult
54   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
55                   ConversionPatternRewriter &rewriter) const override {
56     // For now, only error-free types are supported by this lowering.
57     if (op.getType().template isa<SizeType>())
58       return failure();
59 
60     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
61                                          adaptor.getRhs());
62     return success();
63   }
64 };
65 } // namespace
66 
67 namespace {
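/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body computes each output extent via getBroadcastedDim below. Illustrative
/// sketch (schematic, not verified output):
///
/// %r = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
///          -> tensor<?xindex>
///
/// becomes, roughly:
///
/// %c0 = arith.constant 0 : index
/// %rankA = tensor.dim %a, %c0 : tensor<?xindex>
/// %rankB = tensor.dim %b, %c0 : tensor<?xindex>
/// // %maxRank = max(%rankA, %rankB) via arith.cmpi + select
/// %r = tensor.generate %maxRank {
/// ^bb0(%dim : index):
///   // Fold getBroadcastedDim over both operands: skip out-of-bounds
///   // dimensions; otherwise replace a running extent of 1 by the
///   // operand's extent.
///   tensor.yield %extent : index
/// } : tensor<?xindex>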
68 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
70 
71   LogicalResult
72   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
73                   ConversionPatternRewriter &rewriter) const override;
74 };
75 
// Get the resulting extent in a given dimension. The result is computed from
// any number of extent tensors, using the per-tensor rank differences as
// offsets into them.
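// Worked example (illustrative): for shapes [4, 1] and [3], the rank
// differences are 0 and 1. For output dimension 0 the second shape is out of
// bounds, so the extent is 4; for output dimension 1 the extents 1 and 3
// broadcast to 3.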
78 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
79                         ValueRange rankDiffs, Value outputDimension) {
80   Value one = lb.create<arith::ConstantIndexOp>(1);
81   Value broadcastedDim = one;
82   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
83     Value shape = std::get<0>(tup);
84     Value rankDiff = std::get<1>(tup);
85     Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
86                                                  outputDimension, rankDiff);
87     Type indexTy = lb.getIndexType();
88     broadcastedDim =
89         lb.create<IfOp>(
90               TypeRange{indexTy}, outOfBounds,
91               [&](OpBuilder &b, Location loc) {
92                 b.create<scf::YieldOp>(loc, broadcastedDim);
93               },
94               [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if the current operand's extent is 1, keep the running
                //   (broadcasted) extent;
                // - otherwise, take the current operand's extent.
                // Note that this logic remains correct in the presence of
                // dimensions of zero extent.
102                 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
103                     loc, indexTy, outputDimension, rankDiff);
104                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
105                     loc, shape, ValueRange{lesserRankOperandDimension});
106 
107                 Value dimIsOne =
108                     b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
109                                             lesserRankOperandExtent, one);
110                 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
111                                                lesserRankOperandExtent);
112                 b.create<scf::YieldOp>(loc, dim);
113               })
114             .getResult(0);
115   }
116   return broadcastedDim;
117 }
118 } // namespace
119 
120 LogicalResult BroadcastOpConverter::matchAndRewrite(
121     BroadcastOp op, OpAdaptor adaptor,
122     ConversionPatternRewriter &rewriter) const {
123   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
124   // on shapes.
125   if (op.getType().isa<ShapeType>())
126     return failure();
127 
128   auto loc = op.getLoc();
129   ImplicitLocOpBuilder lb(loc, rewriter);
130 
131   Value zero = lb.create<arith::ConstantIndexOp>(0);
132   Type indexTy = lb.getIndexType();
133 
  // Save all the ranks for bounds checking. Each shape is represented as a
  // rank-1 tensor of extents, so its rank equals the extent of that tensor's
  // only dimension.
137   SmallVector<Value> ranks, rankDiffs;
138   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
139                        return lb.create<tensor::DimOp>(v, zero);
140                      }));
141 
  // Find the maximum rank.
143   Value maxRank = ranks.front();
144   for (Value v : llvm::drop_begin(ranks, 1)) {
145     Value rankIsGreater =
146         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
147     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
148   }
149 
  // Compute the difference between each rank and the maximum rank; these are
  // used as offsets later.
151   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
152                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
153                      }));
154 
155   Value replacement = lb.create<tensor::GenerateOp>(
156       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
157       [&](OpBuilder &b, Location loc, ValueRange args) {
158         Value broadcastedDim =
159             getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
160                               rankDiffs, args[0]);
161 
162         b.create<tensor::YieldOp>(loc, broadcastedDim);
163       });
164   if (replacement.getType() != op.getType())
165     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
166   rewriter.replaceOp(op, replacement);
167   return success();
168 }
169 
170 namespace {
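/// Lowers `shape.const_shape` to a constant extent tensor. Illustrative
/// sketch (schematic, not verified output):
///
/// %shape = shape.const_shape [2, 3] : tensor<?xindex>
///
/// becomes
///
/// %c2 = arith.constant 2 : index
/// %c3 = arith.constant 3 : index
/// %shape = tensor.from_elements %c2, %c3 : tensor<2xindex>
///
/// plus a `tensor.cast` that reconciles the static result type.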
171 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
172 public:
173   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
174 
175   LogicalResult
176   matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
177                   ConversionPatternRewriter &rewriter) const override;
178 };
179 } // namespace
180 
181 LogicalResult ConstShapeOpConverter::matchAndRewrite(
182     ConstShapeOp op, OpAdaptor adaptor,
183     ConversionPatternRewriter &rewriter) const {
184 
185   // For now, this lowering supports only extent tensors, not `shape.shape`
186   // types.
187   if (op.getType().isa<ShapeType>())
188     return failure();
189 
190   auto loc = op.getLoc();
191   SmallVector<Value, 4> extentOperands;
192   for (auto extent : op.getShape()) {
193     extentOperands.push_back(
194         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
195   }
196   Type indexTy = rewriter.getIndexType();
197   Value tensor =
198       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
199   Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy);
200   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
201   return success();
202 }
203 
204 namespace {
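/// Lowers `shape.const_size` to `arith.constant`, e.g. (illustrative):
/// `%s = shape.const_size 42` becomes `%s = arith.constant 42 : index`.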
205 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
206 public:
207   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
208 
209   LogicalResult
210   matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
211                   ConversionPatternRewriter &rewriter) const override;
212 };
213 } // namespace
214 
215 LogicalResult ConstSizeOpConversion::matchAndRewrite(
216     ConstSizeOp op, OpAdaptor adaptor,
217     ConversionPatternRewriter &rewriter) const {
218   rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
219       op, op.getValue().getSExtValue());
220   return success();
221 }
222 
223 namespace {
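/// Lowers `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// that and-reduces a per-dimension compatibility check. Illustrative sketch
/// (schematic, not verified output):
///
/// %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes, roughly:
///
/// %true = arith.constant true
/// %ok = scf.for %i = %c0 to %maxRank step %c1
///     iter_args(%acc = %true) -> (i1) {
///   // %bcast = broadcasted extent at %i (see getBroadcastedDim); every
///   // in-bounds extent must be 1 or equal to %bcast.
///   %next = arith.andi %acc, %dimOk : i1
///   scf.yield %next : i1
/// }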
224 struct IsBroadcastableOpConverter
225     : public OpConversionPattern<IsBroadcastableOp> {
226   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
227 
228   LogicalResult
229   matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
230                   ConversionPatternRewriter &rewriter) const override;
231 };
232 } // namespace
233 
234 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
235     IsBroadcastableOp op, OpAdaptor adaptor,
236     ConversionPatternRewriter &rewriter) const {
237   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
238   // on shapes.
  if (llvm::any_of(op.getShapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();
242 
243   auto loc = op.getLoc();
244   ImplicitLocOpBuilder lb(loc, rewriter);
245   Value zero = lb.create<arith::ConstantIndexOp>(0);
246   Value one = lb.create<arith::ConstantIndexOp>(1);
247   Type indexTy = lb.getIndexType();
248 
  // Save all the ranks for bounds checking. Each shape is represented as a
  // rank-1 tensor of extents, so its rank equals the extent of that tensor's
  // only dimension.
252   SmallVector<Value> ranks, rankDiffs;
253   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
254                        return lb.create<tensor::DimOp>(v, zero);
255                      }));
256 
  // Find the maximum rank.
258   Value maxRank = ranks.front();
259   for (Value v : llvm::drop_begin(ranks, 1)) {
260     Value rankIsGreater =
261         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
262     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
263   }
264 
  // Compute the difference between each rank and the maximum rank; these are
  // used as offsets later.
266   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
267                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
268                      }));
269 
270   Type i1Ty = rewriter.getI1Type();
271   Value trueVal =
272       rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
273 
  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
276       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the broadcast lowering entirely, but we redo the work
        // here so the two loops remain easy to optimize together.
280         Value broadcastedDim = getBroadcastedDim(
281             ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
282 
283         Value broadcastable = iterArgs[0];
284         for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
285           Value shape, rankDiff;
286           std::tie(shape, rankDiff) = tup;
287           Value outOfBounds = b.create<arith::CmpIOp>(
288               loc, arith::CmpIPredicate::ult, iv, rankDiff);
289           broadcastable =
290               b.create<IfOp>(
291                    loc, TypeRange{i1Ty}, outOfBounds,
292                    [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
294                      b.create<scf::YieldOp>(loc, broadcastable);
295                    },
296                    [&](OpBuilder &b, Location loc) {
                     // To be broadcastable in this dimension, every extent
                     // must be either 1 or equal to the same non-1 value.
299                      Value operandDimension =
300                          b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
301                      Value dimensionExtent = b.create<tensor::ExtractOp>(
302                          loc, shape, ValueRange{operandDimension});
303 
304                      Value equalOne = b.create<arith::CmpIOp>(
305                          loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306                      Value equalBroadcasted = b.create<arith::CmpIOp>(
307                          loc, arith::CmpIPredicate::eq, dimensionExtent,
308                          broadcastedDim);
309                      Value result = b.create<arith::AndIOp>(
310                          loc, broadcastable,
311                          b.create<arith::OrIOp>(loc, equalOne,
312                                                 equalBroadcasted));
313                      b.create<scf::YieldOp>(loc, result);
314                    })
315                   .getResult(0);
316         }
317 
318         b.create<scf::YieldOp>(loc, broadcastable);
319       });
320 
  rewriter.replaceOp(op, reduceResult.getResults().front());
322   return success();
323 }
324 
325 namespace {
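/// Lowers `shape.get_extent`. If the shape is produced by `shape.shape_of`,
/// the extent is taken directly from the original tensor, e.g.
/// (illustrative):
///
/// %shape = shape.shape_of %t : tensor<*xf32> -> tensor<?xindex>
/// %e = shape.get_extent %shape, %i : tensor<?xindex>, index -> index
///
/// becomes `%e = tensor.dim %t, %i`; otherwise it becomes
/// `%e = tensor.extract %shape[%i]`.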
326 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
327   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override;
332 };
333 } // namespace
334 
335 LogicalResult GetExtentOpConverter::matchAndRewrite(
336     GetExtentOp op, OpAdaptor adaptor,
337     ConversionPatternRewriter &rewriter) const {
338   // For now, only error-free types are supported by this lowering.
339   if (op.getType().isa<SizeType>())
340     return failure();
341 
  // Derive the shape extent directly from the shape origin if possible. This
  // avoids materializing the shape in memory.
344   if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
345     if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
346       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
347                                                  adaptor.getDim());
348       return success();
349     }
350   }
351 
352   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
353                                                  adaptor.getShape(),
354                                                  ValueRange{adaptor.getDim()});
355   return success();
356 }
357 
358 namespace {
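/// Lowers `shape.rank` on an extent tensor to the extent of its single
/// dimension, e.g. (illustrative):
/// `%r = shape.rank %shape : tensor<?xindex> -> index` becomes
/// `%r = tensor.dim %shape, %c0 : tensor<?xindex>`.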
359 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
360 public:
361   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
362 
363   LogicalResult
364   matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
365                   ConversionPatternRewriter &rewriter) const override;
366 };
367 } // namespace
368 
369 LogicalResult
370 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
371                                  ConversionPatternRewriter &rewriter) const {
372   // For now, this lowering supports only error-free types.
373   if (op.getType().isa<SizeType>())
374     return failure();
375 
376   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
377   return success();
378 }
379 
380 namespace {
381 /// Converts `shape.reduce` to `scf.for`.
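/// Illustrative sketch (schematic, not verified output):
///
/// %n = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///   ^bb0(%i : index, %extent : index, %acc : index):
///     %next = shape.mul %acc, %extent : index, index -> index
///     shape.yield %next : index
/// }
///
/// becomes an `scf.for` from 0 to the rank whose body extracts each extent
/// with `tensor.extract` and contains a clone of the reduce body.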
382 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
383 public:
384   using OpConversionPattern::OpConversionPattern;
385 
386   LogicalResult
387   matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
388                   ConversionPatternRewriter &rewriter) const final;
389 };
390 } // namespace
391 
392 LogicalResult
393 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
394                                    ConversionPatternRewriter &rewriter) const {
395   // For now, this lowering is only defined on `tensor<?xindex>` operands.
396   if (op.getShape().getType().isa<ShapeType>())
397     return failure();
398 
399   auto loc = op.getLoc();
400 
401   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
402   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
403   Type indexTy = rewriter.getIndexType();
404   Value rank =
405       rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
406 
407   auto loop = rewriter.create<scf::ForOp>(
408       loc, zero, rank, one, op.getInitVals(),
409       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
410         Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
411 
412         SmallVector<Value, 2> mappedValues{iv, extent};
413         mappedValues.append(args.begin(), args.end());
414 
415         BlockAndValueMapping mapping;
416         Block *reduceBody = op.getBody();
417         mapping.map(reduceBody->getArguments(), mappedValues);
418         for (auto &nested : reduceBody->without_terminator())
419           b.clone(nested, mapping);
420 
421         SmallVector<Value, 2> mappedResults;
422         for (auto result : reduceBody->getTerminator()->getOperands())
423           mappedResults.push_back(mapping.lookup(result));
424         b.create<scf::YieldOp>(loc, mappedResults);
425       });
426 
427   rewriter.replaceOp(op, loop.getResults());
428   return success();
429 }
430 
431 namespace {
/// Converts `shape.shape_eq` to a rank check and an `scf.for` loop over the
/// extents. For now, the lowering is only defined on `tensor<?xindex>`
/// operands. The test for equality first compares the ranks and, if they are
/// equal, checks every extent for equality.
435 ///
436 /// Example:
437 ///
438 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
439 ///
440 /// becomes
441 ///
442 /// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
445 /// %2 = arith.cmpi "eq", %0, %1 : index
446 /// %result = scf.if %2 -> (i1) {
447 ///   %c1 = arith.constant 1 : index
448 ///   %true = arith.constant true
449 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
450 ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
451 ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
452 ///     %7 = arith.cmpi "eq", %5, %6 : index
453 ///     %8 = arith.andi %arg3, %7 : i1
454 ///     scf.yield %8 : i1
455 ///   }
456 ///   scf.yield %4 : i1
457 /// } else {
458 ///   %false = arith.constant false
459 ///   scf.yield %false : i1
460 /// }
461 ///
462 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
463   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
464 
465   LogicalResult
466   matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
467                   ConversionPatternRewriter &rewriter) const override;
468 };
469 } // namespace
470 
471 LogicalResult
472 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
473                                     ConversionPatternRewriter &rewriter) const {
  if (llvm::any_of(op.getShapes(),
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();
477 
478   Type i1Ty = rewriter.getI1Type();
479   if (op.getShapes().size() <= 1) {
480     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
481                                                    rewriter.getBoolAttr(true));
482     return success();
483   }
484 
485   auto loc = op.getLoc();
486   Type indexTy = rewriter.getIndexType();
487   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
488   Value firstShape = adaptor.getShapes().front();
489   Value firstRank =
490       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
491   Value result = nullptr;
492   // Generate a linear sequence of compares, all with firstShape as lhs.
493   for (Value shape : adaptor.getShapes().drop_front(1)) {
494     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
495     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
496                                                   firstRank, rank);
497     auto same = rewriter.create<IfOp>(
498         loc, i1Ty, eqRank,
499         [&](OpBuilder &b, Location loc) {
500           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
501           Value init =
502               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
503           auto loop = b.create<scf::ForOp>(
504               loc, zero, firstRank, one, ValueRange{init},
505               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
506                 Value conj = args[0];
507                 Value lhsExtent =
508                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
509                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
510                 Value eqExtent = b.create<arith::CmpIOp>(
511                     loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
512                 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
513                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
514               });
515           b.create<scf::YieldOp>(loc, loop.getResults());
516         },
517         [&](OpBuilder &b, Location loc) {
518           Value result =
519               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
520           b.create<scf::YieldOp>(loc, result);
521         });
522     result = !result ? same.getResult(0)
523                      : rewriter.create<arith::AndIOp>(loc, result,
524                                                       same.getResult(0));
525   }
526   rewriter.replaceOp(op, result);
527   return success();
528 }
529 
530 namespace {
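/// Lowers `shape.shape_of`. For ranked operands, the extents are
/// materialized with `tensor.from_elements`; for unranked operands, a
/// `tensor.generate` queries each dimension. Illustrative sketch for the
/// ranked case (schematic, not verified output):
///
/// %shape = shape.shape_of %t : tensor<2x?xf32> -> tensor<?xindex>
///
/// becomes
///
/// %c2 = arith.constant 2 : index
/// %c1 = arith.constant 1 : index
/// %d1 = tensor.dim %t, %c1 : tensor<2x?xf32>
/// %0 = tensor.from_elements %c2, %d1 : tensor<2xindex>
/// %shape = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>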
531 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
532 public:
533   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
534 
535   LogicalResult
536   matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
537                   ConversionPatternRewriter &rewriter) const override;
538 };
539 } // namespace
540 
541 LogicalResult ShapeOfOpConversion::matchAndRewrite(
542     ShapeOfOp op, OpAdaptor adaptor,
543     ConversionPatternRewriter &rewriter) const {
544 
545   // For now, only error-free types are supported by this lowering.
546   if (op.getType().isa<ShapeType>())
547     return failure();
548 
549   // For ranked tensor arguments, lower to `tensor.from_elements`.
550   auto loc = op.getLoc();
551   Value tensor = adaptor.getArg();
552   Type tensorTy = tensor.getType();
553   if (tensorTy.isa<RankedTensorType>()) {
554 
555     // Build values for individual extents.
556     SmallVector<Value, 8> extentValues;
557     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
558     int64_t rank = rankedTensorTy.getRank();
559     for (int64_t i = 0; i < rank; i++) {
560       if (rankedTensorTy.isDynamicDim(i)) {
561         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
562         extentValues.push_back(extent);
563       } else {
564         Value extent = rewriter.create<arith::ConstantIndexOp>(
565             loc, rankedTensorTy.getDimSize(i));
566         extentValues.push_back(extent);
567       }
568     }
569 
570     // Materialize extent tensor.
571     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
572         loc, rewriter.getIndexType(), extentValues);
573     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
574                                                 staticExtentTensor);
575     return success();
576   }
577 
578   // Lower to `tensor.generate` otherwise.
579   auto *ctx = rewriter.getContext();
580   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
581   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
582       op, getExtentTensorType(ctx), ValueRange{rank},
583       [&](OpBuilder &b, Location loc, ValueRange args) {
584         Value dim = args.front();
585         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
586         b.create<tensor::YieldOp>(loc, extent);
587       });
588 
589   return success();
590 }
591 
592 namespace {
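/// Lowers `shape.split_at` on extent tensors to two `tensor.extract_slice`
/// ops, after wrapping a negative split index around the rank. Illustrative
/// sketch (schematic, not verified output):
///
/// %head, %tail = shape.split_at %shape, %idx
///     : tensor<?xindex>, index -> tensor<?xindex>, tensor<?xindex>
///
/// becomes, roughly:
///
/// // %index = %idx < 0 ? %idx + %rank : %idx
/// %head = tensor.extract_slice %shape[0] [%index] [1]
///     : tensor<?xindex> to tensor<?xindex>
/// %tail = tensor.extract_slice %shape[%index] [%tailSize] [1]
///     : tensor<?xindex> to tensor<?xindex>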
593 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
594 public:
595   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
596 
597   LogicalResult
598   matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
599                   ConversionPatternRewriter &rewriter) const override;
600 };
601 } // namespace
602 
603 LogicalResult SplitAtOpConversion::matchAndRewrite(
604     SplitAtOp op, OpAdaptor adaptor,
605     ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower when all operands and
  // results are extent tensors.
608   if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
609                    [](Value v) { return v.getType().isa<ShapeType>(); }))
610     return failure();
611 
612   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
613   Value zero = b.create<arith::ConstantIndexOp>(0);
614   Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
615 
  // Wrap a negative index around the rank: index < 0 ? index + rank : index
617   Value originalIndex = adaptor.getIndex();
618   Value add = b.create<arith::AddIOp>(originalIndex, rank);
619   Value indexIsNegative =
620       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
621   Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
622 
623   Value one = b.create<arith::ConstantIndexOp>(1);
624   Value head =
625       b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
626   Value tailSize = b.create<arith::SubIOp>(rank, index);
627   Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
628                                                 tailSize, one);
629   rewriter.replaceOp(op, {head, tail});
630   return success();
631 }
632 
633 namespace {
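/// Lowers `shape.to_extent_tensor` to a `tensor.cast` when the input is
/// already a ranked tensor, e.g. (illustrative):
/// `%r = tensor.cast %in : tensor<?xindex> to tensor<3xindex>`.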
634 class ToExtentTensorOpConversion
635     : public OpConversionPattern<ToExtentTensorOp> {
636 public:
637   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
638 
639   LogicalResult
640   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
641                   ConversionPatternRewriter &rewriter) const override {
642     if (!adaptor.getInput().getType().isa<RankedTensorType>())
643       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
644 
645     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
646                                                 adaptor.getInput());
647     return success();
648   }
649 };
650 } // namespace
651 
652 namespace {
/// Import the generated Shape-to-Standard conversion patterns.
654 #include "ShapeToStandard.cpp.inc"
655 } // namespace
656 
657 namespace {
658 /// Conversion pass.
659 class ConvertShapeToStandardPass
660     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
661 
662   void runOnOperation() override;
663 };
664 } // namespace
665 
666 void ConvertShapeToStandardPass::runOnOperation() {
667   // Setup target legality.
668   MLIRContext &ctx = getContext();
669   ConversionTarget target(ctx);
670   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
671                          SCFDialect, tensor::TensorDialect>();
672   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
673 
674   // Setup conversion patterns.
675   RewritePatternSet patterns(&ctx);
676   populateShapeToStandardConversionPatterns(patterns);
677 
678   // Apply conversion.
679   auto module = getOperation();
680   if (failed(applyPartialConversion(module, target, std::move(patterns))))
681     signalPassFailure();
682 }
683 
684 void mlir::populateShapeToStandardConversionPatterns(
685     RewritePatternSet &patterns) {
686   // clang-format off
687   populateWithGenerated(patterns);
688   patterns.add<
689       AnyOpConversion,
690       BinaryOpConversion<AddOp, arith::AddIOp>,
691       BinaryOpConversion<MulOp, arith::MulIOp>,
692       BroadcastOpConverter,
693       ConstShapeOpConverter,
694       ConstSizeOpConversion,
695       IsBroadcastableOpConverter,
696       GetExtentOpConverter,
697       RankOpConverter,
698       ReduceOpConverter,
699       ShapeEqOpConverter,
700       ShapeOfOpConversion,
701       SplitAtOpConversion,
702       ToExtentTensorOpConversion>(patterns.getContext());
703   // clang-format on
704 }
705 
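// Typical driver invocation (illustrative; assumes the pass is registered
// under the -convert-shape-to-std flag):
//
//   mlir-opt input.mlir -convert-shape-to-std
//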
706 std::unique_ptr<OperationPass<ModuleOp>>
707 mlir::createConvertShapeToStandardPass() {
708   return std::make_unique<ConvertShapeToStandardPass>();
709 }
710