1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/Shape/IR/Shape.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 using namespace mlir::scf;
25 
26 /// Conversion patterns.
27 namespace {
28 class AnyOpConversion : public OpConversionPattern<AnyOp> {
29 public:
30   using OpConversionPattern<AnyOp>::OpConversionPattern;
31 
32   LogicalResult
33   matchAndRewrite(AnyOp op, OpAdaptor adaptor,
34                   ConversionPatternRewriter &rewriter) const override;
35 };
36 } // namespace
37 
38 LogicalResult
39 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
40                                  ConversionPatternRewriter &rewriter) const {
41   // Replace `any` with its first operand.
42   // Any operand would be a valid substitution.
43   rewriter.replaceOp(op, {adaptor.inputs().front()});
44   return success();
45 }
46 
47 namespace {
48 template <typename SrcOpTy, typename DstOpTy>
49 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
50 public:
51   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
52 
53   LogicalResult
54   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
55                   ConversionPatternRewriter &rewriter) const override {
56     // For now, only error-free types are supported by this lowering.
57     if (op.getType().template isa<SizeType>())
58       return failure();
59 
60     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
61     return success();
62   }
63 };
64 } // namespace
65 
66 namespace {
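/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body computes each output extent via `getBroadcastedDim` below. A rough
/// sketch of the emitted IR for two operands (pseudo-IR; SSA names are
/// arbitrary and the rank-difference computation is elided):
///
///   %c0 = arith.constant 0 : index
///   %rank0 = tensor.dim %shape0, %c0 : tensor<?xindex>
///   %rank1 = tensor.dim %shape1, %c0 : tensor<?xindex>
///   %gt = arith.cmpi "ugt", %rank1, %rank0 : index
///   %maxRank = select %gt, %rank1, %rank0 : index
///   %result = tensor.generate %maxRank {
///   ^bb0(%dim : index):
///     %extent = <broadcasted extent, see getBroadcastedDim>
///     tensor.yield %extent : index
///   } : tensor<?xindex>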
67 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
68   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override;
73 };
74 
// Get the resulting extent in a given output dimension. This is computed for
// any number of extent tensors, using the per-operand rank differences as
// shifted offsets into them.
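// For example, for the extent tensors [4, 1, 3] and [5, 3] (maximum rank 3,
// rank differences 0 and 1), the per-dimension results are 4 (the second
// operand is out of bounds), 5 (the first operand contributes 1, which is
// ignored), and 3 (both agree), i.e. the broadcasted shape is [4, 5, 3].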
77 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
78                         ValueRange rankDiffs, Value outputDimension) {
79   Value one = lb.create<arith::ConstantIndexOp>(1);
80   Value broadcastedDim = one;
81   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
82     Value shape = std::get<0>(tup);
83     Value rankDiff = std::get<1>(tup);
84     Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
85                                                  outputDimension, rankDiff);
86     Type indexTy = lb.getIndexType();
87     broadcastedDim =
88         lb.create<IfOp>(
89               TypeRange{indexTy}, outOfBounds,
90               [&](OpBuilder &b, Location loc) {
91                 b.create<scf::YieldOp>(loc, broadcastedDim);
92               },
93               [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if the current operand's extent in this dimension
                //   is equal to 1, keep the extent accumulated so far
                //   from the other operands,
                // - otherwise, take the current operand's extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
101                 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
102                     loc, indexTy, outputDimension, rankDiff);
103                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
104                     loc, shape, ValueRange{lesserRankOperandDimension});
105 
106                 Value dimIsOne =
107                     b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
108                                             lesserRankOperandExtent, one);
109                 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
110                                                lesserRankOperandExtent);
111                 b.create<scf::YieldOp>(loc, dim);
112               })
113             .getResult(0);
114   }
115   return broadcastedDim;
116 }
117 } // namespace
118 
119 LogicalResult BroadcastOpConverter::matchAndRewrite(
120     BroadcastOp op, OpAdaptor adaptor,
121     ConversionPatternRewriter &rewriter) const {
122   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
123   // on shapes.
124   if (op.getType().isa<ShapeType>())
125     return failure();
126 
127   auto loc = op.getLoc();
128   ImplicitLocOpBuilder lb(loc, rewriter);
129 
130   Value zero = lb.create<arith::ConstantIndexOp>(0);
131   Type indexTy = lb.getIndexType();
132 
  // Save all the ranks for bounds checking. Because each shape is an extent
  // tensor, its rank is the extent of the tensor's single dimension.
136   SmallVector<Value> ranks, rankDiffs;
137   llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
138                        return lb.create<tensor::DimOp>(v, zero);
139                      }));
140 
  // Find the maximum rank.
142   Value maxRank = ranks.front();
143   for (Value v : llvm::drop_begin(ranks, 1)) {
144     Value rankIsGreater =
145         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
146     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
147   }
148 
149   // Calculate the difference of ranks and the maximum rank for later offsets.
150   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
151                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
152                      }));
153 
154   Value replacement = lb.create<tensor::GenerateOp>(
155       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
156       [&](OpBuilder &b, Location loc, ValueRange args) {
157         Value broadcastedDim = getBroadcastedDim(
158             ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]);
159 
160         b.create<tensor::YieldOp>(loc, broadcastedDim);
161       });
162   if (replacement.getType() != op.getType())
163     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
164   rewriter.replaceOp(op, replacement);
165   return success();
166 }
167 
168 namespace {
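/// Converts `shape.const_shape` to a `tensor.from_elements` of index
/// constants, followed by a `tensor.cast` to the static extent tensor type.
/// For instance, the constant shape [2, 3] lowers roughly to (a sketch; the
/// exact assembly may differ):
///
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = tensor.from_elements %c2, %c3 : tensor<2xindex>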
169 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
170 public:
171   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
172 
173   LogicalResult
174   matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override;
176 };
177 } // namespace
178 
179 LogicalResult ConstShapeOpConverter::matchAndRewrite(
180     ConstShapeOp op, OpAdaptor adaptor,
181     ConversionPatternRewriter &rewriter) const {
182 
183   // For now, this lowering supports only extent tensors, not `shape.shape`
184   // types.
185   if (op.getType().isa<ShapeType>())
186     return failure();
187 
188   auto loc = op.getLoc();
189   SmallVector<Value, 4> extentOperands;
190   for (auto extent : op.shape()) {
191     extentOperands.push_back(
192         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
193   }
194   Type indexTy = rewriter.getIndexType();
195   Value tensor =
196       rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
197   Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy);
198   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
199   return success();
200 }
201 
202 namespace {
203 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
204 public:
205   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
206 
207   LogicalResult
208   matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
209                   ConversionPatternRewriter &rewriter) const override;
210 };
211 } // namespace
212 
213 LogicalResult ConstSizeOpConversion::matchAndRewrite(
214     ConstSizeOp op, OpAdaptor adaptor,
215     ConversionPatternRewriter &rewriter) const {
216   rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
217       op, op.value().getSExtValue());
218   return success();
219 }
220 
221 namespace {
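/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// over the maximum rank that accumulates an `i1`: for each output dimension,
/// every in-bounds operand extent must either be 1 or equal the broadcasted
/// extent computed by `getBroadcastedDim`. A rough sketch of the loop
/// structure (pseudo-IR; the per-operand bounds checks are elided):
///
///   %true = arith.constant true
///   %ok = scf.for %i = %c0 to %maxRank step %c1
///             iter_args(%acc = %true) -> (i1) {
///     %bcast = <broadcasted extent at %i>
///     %fits = <for each shape: extent == 1 or extent == %bcast>
///     %acc2 = arith.andi %acc, %fits : i1
///     scf.yield %acc2 : i1
///   }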
222 struct IsBroadcastableOpConverter
223     : public OpConversionPattern<IsBroadcastableOp> {
224   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
225 
226   LogicalResult
227   matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
228                   ConversionPatternRewriter &rewriter) const override;
229 };
230 } // namespace
231 
232 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
233     IsBroadcastableOp op, OpAdaptor adaptor,
234     ConversionPatternRewriter &rewriter) const {
235   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
236   // on shapes.
237   if (!llvm::all_of(op.shapes(),
238                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
239     return failure();
240 
241   auto loc = op.getLoc();
242   ImplicitLocOpBuilder lb(loc, rewriter);
243   Value zero = lb.create<arith::ConstantIndexOp>(0);
244   Value one = lb.create<arith::ConstantIndexOp>(1);
245   Type indexTy = lb.getIndexType();
246 
  // Save all the ranks for bounds checking. Because each shape is an extent
  // tensor, its rank is the extent of the tensor's single dimension.
250   SmallVector<Value> ranks, rankDiffs;
251   llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
252                        return lb.create<tensor::DimOp>(v, zero);
253                      }));
254 
  // Find the maximum rank.
256   Value maxRank = ranks.front();
257   for (Value v : llvm::drop_begin(ranks, 1)) {
258     Value rankIsGreater =
259         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
260     maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
261   }
262 
263   // Calculate the difference of ranks and the maximum rank for later offsets.
264   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
265                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
266                      }));
267 
268   Type i1Ty = rewriter.getI1Type();
269   Value trueVal =
270       rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
271 
272   auto reduceResult = lb.create<ForOp>(
273       loc, zero, maxRank, one, ValueRange{trueVal},
274       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 extent for this dimension, if one exists. Note that
        // this could reuse the broadcast lowering entirely, but we redo the
        // work here to make optimizations between the two loops easier.
278         Value broadcastedDim = getBroadcastedDim(
279             ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv);
280 
281         Value broadcastable = iterArgs[0];
282         for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) {
283           Value shape, rankDiff;
284           std::tie(shape, rankDiff) = tup;
285           Value outOfBounds = b.create<arith::CmpIOp>(
286               loc, arith::CmpIPredicate::ult, iv, rankDiff);
287           broadcastable =
288               b.create<IfOp>(
289                    loc, TypeRange{i1Ty}, outOfBounds,
290                    [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
292                      b.create<scf::YieldOp>(loc, broadcastable);
293                    },
294                    [&](OpBuilder &b, Location loc) {
295                      // Every value needs to be either 1, or the same non-1
296                      // value to be broadcastable in this dim.
297                      Value operandDimension =
298                          b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
299                      Value dimensionExtent = b.create<tensor::ExtractOp>(
300                          loc, shape, ValueRange{operandDimension});
301 
302                      Value equalOne = b.create<arith::CmpIOp>(
303                          loc, arith::CmpIPredicate::eq, dimensionExtent, one);
304                      Value equalBroadcasted = b.create<arith::CmpIOp>(
305                          loc, arith::CmpIPredicate::eq, dimensionExtent,
306                          broadcastedDim);
307                      Value result = b.create<arith::AndIOp>(
308                          loc, broadcastable,
309                          b.create<arith::OrIOp>(loc, equalOne,
310                                                 equalBroadcasted));
311                      b.create<scf::YieldOp>(loc, result);
312                    })
313                   .getResult(0);
314         }
315 
316         b.create<scf::YieldOp>(loc, broadcastable);
317       });
318 
319   rewriter.replaceOp(op, reduceResult.results().front());
320   return success();
321 }
322 
323 namespace {
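/// Converts `shape.get_extent` to either `tensor.dim` or `tensor.extract`.
/// When the shape operand is produced by `shape.shape_of`, the extent is read
/// directly from the original tensor, which avoids materializing the extent
/// tensor; otherwise it is read from the extent tensor itself. A sketch of
/// the two resulting forms (pseudo-IR):
///
///   %extent = tensor.dim %arg, %dim : tensor<*xf32>         // shape_of case
///   %extent = tensor.extract %shape[%dim] : tensor<?xindex> // general case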
324 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
325   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
326 
327   LogicalResult
328   matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
329                   ConversionPatternRewriter &rewriter) const override;
330 };
331 } // namespace
332 
333 LogicalResult GetExtentOpConverter::matchAndRewrite(
334     GetExtentOp op, OpAdaptor adaptor,
335     ConversionPatternRewriter &rewriter) const {
336   // For now, only error-free types are supported by this lowering.
337   if (op.getType().isa<SizeType>())
338     return failure();
339 
  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape as an extent tensor.
342   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
343     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
344       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(),
345                                                  adaptor.dim());
346       return success();
347     }
348   }
349 
350   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
351       op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()});
352   return success();
353 }
354 
355 namespace {
356 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
357 public:
358   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
359 
360   LogicalResult
361   matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
362                   ConversionPatternRewriter &rewriter) const override;
363 };
364 } // namespace
365 
366 LogicalResult
367 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
368                                  ConversionPatternRewriter &rewriter) const {
369   // For now, this lowering supports only error-free types.
370   if (op.getType().isa<SizeType>())
371     return failure();
372 
373   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.shape(), 0);
374   return success();
375 }
376 
377 namespace {
378 /// Converts `shape.reduce` to `scf.for`.
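/// The reduction body is cloned into the loop, with the induction variable,
/// the current extent, and the iteration arguments mapped onto the body's
/// block arguments. A rough sketch (pseudo-IR; the cloned body is elided):
///
///   %rank = tensor.dim %shape, %c0 : tensor<?xindex>
///   %result:N = scf.for %i = %c0 to %rank step %c1
///                   iter_args(<init values>) -> (<result types>) {
///     %extent = tensor.extract %shape[%i] : tensor<?xindex>
///     <cloned reduction body>
///     scf.yield <mapped terminator operands>
///   }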
379 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
380 public:
381   using OpConversionPattern::OpConversionPattern;
382 
383   LogicalResult
384   matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
385                   ConversionPatternRewriter &rewriter) const final;
386 };
387 } // namespace
388 
389 LogicalResult
390 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
391                                    ConversionPatternRewriter &rewriter) const {
392   // For now, this lowering is only defined on `tensor<?xindex>` operands.
393   if (op.shape().getType().isa<ShapeType>())
394     return failure();
395 
396   auto loc = op.getLoc();
397 
398   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
399   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
400   Type indexTy = rewriter.getIndexType();
401   Value rank =
402       rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.shape(), zero);
403 
404   auto loop = rewriter.create<scf::ForOp>(
405       loc, zero, rank, one, op.initVals(),
406       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
407         Value extent = b.create<tensor::ExtractOp>(loc, adaptor.shape(), iv);
408 
409         SmallVector<Value, 2> mappedValues{iv, extent};
410         mappedValues.append(args.begin(), args.end());
411 
412         BlockAndValueMapping mapping;
413         Block *reduceBody = op.getBody();
414         mapping.map(reduceBody->getArguments(), mappedValues);
415         for (auto &nested : reduceBody->without_terminator())
416           b.clone(nested, mapping);
417 
418         SmallVector<Value, 2> mappedResults;
419         for (auto result : reduceBody->getTerminator()->getOperands())
420           mappedResults.push_back(mapping.lookup(result));
421         b.create<scf::YieldOp>(loc, mappedResults);
422       });
423 
424   rewriter.replaceOp(op, loop.getResults());
425   return success();
426 }
427 
428 namespace {
429 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
430 /// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their ranks and, if equal, checks every extent for equality.
432 ///
433 /// Example:
434 ///
435 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
436 ///
437 /// becomes
438 ///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
442 /// %2 = arith.cmpi "eq", %0, %1 : index
443 /// %result = scf.if %2 -> (i1) {
444 ///   %c1 = arith.constant 1 : index
445 ///   %true = arith.constant true
446 ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
447 ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
448 ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
449 ///     %7 = arith.cmpi "eq", %5, %6 : index
450 ///     %8 = arith.andi %arg3, %7 : i1
451 ///     scf.yield %8 : i1
452 ///   }
453 ///   scf.yield %4 : i1
454 /// } else {
455 ///   %false = arith.constant false
456 ///   scf.yield %false : i1
457 /// }
458 ///
459 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
460   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
461 
462   LogicalResult
463   matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override;
465 };
466 } // namespace
467 
468 LogicalResult
469 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
470                                     ConversionPatternRewriter &rewriter) const {
471   if (!llvm::all_of(op.shapes(),
472                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
473     return failure();
474 
475   Type i1Ty = rewriter.getI1Type();
476   if (op.shapes().size() <= 1) {
477     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
478                                                    rewriter.getBoolAttr(true));
479     return success();
480   }
481 
482   auto loc = op.getLoc();
483   Type indexTy = rewriter.getIndexType();
484   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
485   Value firstShape = adaptor.shapes().front();
486   Value firstRank =
487       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
488   Value result = nullptr;
489   // Generate a linear sequence of compares, all with firstShape as lhs.
490   for (Value shape : adaptor.shapes().drop_front(1)) {
491     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
492     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
493                                                   firstRank, rank);
494     auto same = rewriter.create<IfOp>(
495         loc, i1Ty, eqRank,
496         [&](OpBuilder &b, Location loc) {
497           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
498           Value init =
499               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
500           auto loop = b.create<scf::ForOp>(
501               loc, zero, firstRank, one, ValueRange{init},
502               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
503                 Value conj = args[0];
504                 Value lhsExtent =
505                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
506                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
507                 Value eqExtent = b.create<arith::CmpIOp>(
508                     loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
509                 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
510                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
511               });
512           b.create<scf::YieldOp>(loc, loop.getResults());
513         },
514         [&](OpBuilder &b, Location loc) {
515           Value result =
516               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
517           b.create<scf::YieldOp>(loc, result);
518         });
519     result = !result ? same.getResult(0)
520                      : rewriter.create<arith::AndIOp>(loc, result,
521                                                       same.getResult(0));
522   }
523   rewriter.replaceOp(op, result);
524   return success();
525 }
526 
527 namespace {
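/// Converts `shape.shape_of`. For ranked operands the extent tensor is built
/// with `tensor.from_elements` from per-dimension `tensor.dim` ops (or
/// constants for static dimensions); for unranked operands it is built with
/// `tensor.generate` over the operand's rank. A sketch of the unranked case
/// (pseudo-IR):
///
///   %rank = rank %arg : tensor<*xf32>
///   %shape = tensor.generate %rank {
///   ^bb0(%i : index):
///     %extent = tensor.dim %arg, %i : tensor<*xf32>
///     tensor.yield %extent : index
///   } : tensor<?xindex>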
528 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
529 public:
530   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
531 
532   LogicalResult
533   matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
534                   ConversionPatternRewriter &rewriter) const override;
535 };
536 } // namespace
537 
538 LogicalResult ShapeOfOpConversion::matchAndRewrite(
539     ShapeOfOp op, OpAdaptor adaptor,
540     ConversionPatternRewriter &rewriter) const {
541 
542   // For now, only error-free types are supported by this lowering.
543   if (op.getType().isa<ShapeType>())
544     return failure();
545 
546   // For ranked tensor arguments, lower to `tensor.from_elements`.
547   auto loc = op.getLoc();
548   Value tensor = adaptor.arg();
549   Type tensorTy = tensor.getType();
550   if (tensorTy.isa<RankedTensorType>()) {
551 
552     // Build values for individual extents.
553     SmallVector<Value, 8> extentValues;
554     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
555     int64_t rank = rankedTensorTy.getRank();
556     for (int64_t i = 0; i < rank; i++) {
557       if (rankedTensorTy.isDynamicDim(i)) {
558         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
559         extentValues.push_back(extent);
560       } else {
561         Value extent = rewriter.create<arith::ConstantIndexOp>(
562             loc, rankedTensorTy.getDimSize(i));
563         extentValues.push_back(extent);
564       }
565     }
566 
567     // Materialize extent tensor.
568     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
569         loc, rewriter.getIndexType(), extentValues);
570     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
571                                                 staticExtentTensor);
572     return success();
573   }
574 
575   // Lower to `tensor.generate` otherwise.
576   auto *ctx = rewriter.getContext();
577   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
578   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
579       op, getExtentTensorType(ctx), ValueRange{rank},
580       [&](OpBuilder &b, Location loc, ValueRange args) {
581         Value dim = args.front();
582         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
583         b.create<tensor::YieldOp>(loc, extent);
584       });
585 
586   return success();
587 }
588 
589 namespace {
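/// Converts `shape.split_at` on extent tensors to two `tensor.extract_slice`
/// ops. The split index is first normalized so that a negative index counts
/// from the back (index < 0 ? index + rank : index). A rough sketch
/// (pseudo-IR; the index normalization is elided):
///
///   %head = tensor.extract_slice %shape[0] [%index] [1]
///       : tensor<?xindex> to tensor<?xindex>
///   %tailSize = arith.subi %rank, %index : index
///   %tail = tensor.extract_slice %shape[%index] [%tailSize] [1]
///       : tensor<?xindex> to tensor<?xindex>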
590 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
591 public:
592   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
593 
594   LogicalResult
595   matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
596                   ConversionPatternRewriter &rewriter) const override;
597 };
598 } // namespace
599 
600 LogicalResult SplitAtOpConversion::matchAndRewrite(
601     SplitAtOp op, OpAdaptor adaptor,
602     ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
605   if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
606                    [](Value v) { return v.getType().isa<ShapeType>(); }))
607     return failure();
608 
609   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
610   Value zero = b.create<arith::ConstantIndexOp>(0);
611   Value rank = b.create<tensor::DimOp>(adaptor.operand(), zero);
612 
613   // index < 0 ? index + rank : index
614   Value originalIndex = adaptor.index();
615   Value add = b.create<arith::AddIOp>(originalIndex, rank);
616   Value indexIsNegative =
617       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
618   Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
619 
620   Value one = b.create<arith::ConstantIndexOp>(1);
621   Value head =
622       b.create<tensor::ExtractSliceOp>(adaptor.operand(), zero, index, one);
623   Value tailSize = b.create<arith::SubIOp>(rank, index);
624   Value tail =
625       b.create<tensor::ExtractSliceOp>(adaptor.operand(), index, tailSize, one);
626   rewriter.replaceOp(op, {head, tail});
627   return success();
628 }
629 
630 namespace {
631 class ToExtentTensorOpConversion
632     : public OpConversionPattern<ToExtentTensorOp> {
633 public:
634   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
635 
636   LogicalResult
637   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
638                   ConversionPatternRewriter &rewriter) const override {
639     if (!adaptor.input().getType().isa<RankedTensorType>())
640       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
641 
642     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
643                                                 adaptor.input());
644     return success();
645   }
646 };
647 } // namespace
648 
649 namespace {
/// Import the generated Shape-to-Standard lowering patterns.
651 #include "ShapeToStandard.cpp.inc"
652 } // namespace
653 
654 namespace {
655 /// Conversion pass.
656 class ConvertShapeToStandardPass
657     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
658 
659   void runOnOperation() override;
660 };
661 } // namespace
662 
663 void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
665   MLIRContext &ctx = getContext();
666   ConversionTarget target(ctx);
667   target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
668                          SCFDialect, tensor::TensorDialect>();
669   target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
670 
  // Set up the conversion patterns.
672   RewritePatternSet patterns(&ctx);
673   populateShapeToStandardConversionPatterns(patterns);
674 
675   // Apply conversion.
676   auto module = getOperation();
677   if (failed(applyPartialConversion(module, target, std::move(patterns))))
678     signalPassFailure();
679 }
680 
681 void mlir::populateShapeToStandardConversionPatterns(
682     RewritePatternSet &patterns) {
683   // clang-format off
684   populateWithGenerated(patterns);
685   patterns.add<
686       AnyOpConversion,
687       BinaryOpConversion<AddOp, arith::AddIOp>,
688       BinaryOpConversion<MulOp, arith::MulIOp>,
689       BroadcastOpConverter,
690       ConstShapeOpConverter,
691       ConstSizeOpConversion,
692       IsBroadcastableOpConverter,
693       GetExtentOpConverter,
694       RankOpConverter,
695       ReduceOpConverter,
696       ShapeEqOpConverter,
697       ShapeOfOpConversion,
698       SplitAtOpConversion,
699       ToExtentTensorOpConversion>(patterns.getContext());
700   // clang-format on
701 }
702 
703 std::unique_ptr<OperationPass<ModuleOp>>
704 mlir::createConvertShapeToStandardPass() {
705   return std::make_unique<ConvertShapeToStandardPass>();
706 }
707