//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
/// Conversion pattern for `AnyOp`. The rewrite itself is defined out of line.
struct AnyOpConversion : OpConversionPattern<AnyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Replaces `AnyOp` with its first (converted) operand.
///
/// Reports a match failure when the op carries no operands, because there is
/// then no value that could stand in for the result.
LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  ValueRange inputs = adaptor.getInputs();
  if (inputs.empty())
    return rewriter.notifyMatchFailure(op, "expected at least one operand");

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {inputs.front()});
  return success();
}

namespace {
/// Rewrites the binary op `SrcOpTy` into `DstOpTy`, built from the converted
/// left-hand and right-hand operands.
template <typename SrcOpTy, typename DstOpTy>
struct BinaryOpConversion : OpConversionPattern<SrcOpTy> {
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out on `SizeType` results: only error-free types are supported by
    // this lowering for now.
    const bool hasSizeResult = op.getType().template isa<SizeType>();
    if (hasSizeResult)
      return failure();

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    rewriter.replaceOpWithNewOp<DstOpTy>(op, lhs, rhs);
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern for `BroadcastOp`. The rewrite is defined out of line.
class BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Emits IR computing the broadcasted extent for `outputDimension`.
// `extentTensors` and `rankDiffs` are walked in lockstep: every extent tensor
// is paired with the offset that is subtracted from the output dimension
// before indexing into it. The running result starts out as the constant 1
// and is threaded through one `scf.if` per extent tensor.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  Value result = one;
  for (auto pair : llvm::zip(extentTensors, rankDiffs)) {
    Value extentTensor, rankDiff;
    std::tie(extentTensor, rankDiff) = pair;

    Value isOutOfBounds = lb.create<arith::CmpIOp>(
        arith::CmpIPredicate::ult, outputDimension, rankDiff);

    // The output dimension lies below this tensor's offset: forward the
    // extent accumulated so far unchanged.
    auto buildOutOfBounds = [&](OpBuilder &b, Location loc) {
      b.create<scf::YieldOp>(loc, result);
    };

    // Otherwise read this tensor's extent at the shifted position:
    // - if that extent is equal to 1, keep the extent accumulated so far;
    // - otherwise, take the extent as-is.
    // Note that this logic remains correct in the presence of dimensions of
    // zero extent.
    auto buildInBounds = [&](OpBuilder &b, Location loc) {
      Value operandDim =
          b.create<arith::SubIOp>(loc, indexTy, outputDimension, rankDiff);
      Value operandExtent = b.create<tensor::ExtractOp>(
          loc, extentTensor, ValueRange{operandDim});
      Value extentIsOne = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, operandExtent, one);
      Value selected =
          b.create<SelectOp>(loc, extentIsOne, result, operandExtent);
      b.create<scf::YieldOp>(loc, selected);
    };

    IfOp ifOp = lb.create<IfOp>(TypeRange{indexTy}, isOutOfBounds,
                                buildOutOfBounds, buildInBounds);
    result = ifOp.getResult(0);
  }
  return result;
}
} // namespace

/// Lowers `BroadcastOp` to a `tensor::GenerateOp` whose body computes every
/// broadcasted extent through `getBroadcastedDim`. A `tensor::CastOp` is
/// appended when the generated tensor type differs from the op's result type.
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Results of `ShapeType` are not handled; for now this lowering is only
  // defined on `tensor<?xindex>`.
  if (op.getType().isa<ShapeType>())
    return failure();

  ImplicitLocOpBuilder lb(op.getLoc(), rewriter);
  ValueRange shapes = adaptor.getShapes();

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Collect the rank of every operand for bounds checking. An operand is a
  // tensor of shape extents, so its rank is the extent of that tensor's only
  // dimension.
  SmallVector<Value> ranks;
  for (Value shape : shapes) {
    Value rank = lb.create<tensor::DimOp>(shape, zero);
    ranks.push_back(rank);
  }

  // Fold the ranks into their (unsigned) maximum.
  Value maxRank = ranks.front();
  for (unsigned i = 1, e = ranks.size(); i < e; ++i) {
    Value candidate = ranks[i];
    Value isGreater = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
                                               candidate, maxRank);
    maxRank = lb.create<SelectOp>(isGreater, candidate, maxRank);
  }

  // For every operand, remember how far its rank falls short of the maximum;
  // these differences serve as offsets later on.
  SmallVector<Value> rankDiffs;
  for (Value rank : ranks) {
    Value diff = lb.create<arith::SubIOp>(indexTy, maxRank, rank);
    rankDiffs.push_back(diff);
  }

  // Body of the generated tensor: one broadcasted extent per output dimension.
  auto buildExtent = [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
    Value extent = getBroadcastedDim(ImplicitLocOpBuilder(nestedLoc, b),
                                     shapes, rankDiffs, args[0]);
    b.create<tensor::YieldOp>(nestedLoc, extent);
  };
  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank}, buildExtent);

  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
/// Conversion pattern for `ConstShapeOp`. The rewrite is defined out of line.
struct ConstShapeOpConverter : OpConversionPattern<ConstShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `ConstShapeOp` with an extent tensor result: every extent becomes an
/// index constant, the constants are packed with `tensor.from_elements`, and
/// the packed tensor is cast to the op's declared result type.
///
/// Fails to match when the result is a `ShapeType`.
LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }

  // `tensor.from_elements` yields a statically sized tensor with one element
  // per extent.
  Type staticTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, staticTy, extentOperands);

  // Cast to the type the op actually declares rather than to `staticTy`, so
  // that the replacement value always has the type users of `op` expect (the
  // same way the `ShapeOfOp` lowering casts to `op.getType()`).
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), tensor);
  return success();
}

namespace {
/// Conversion pattern for `ConstSizeOp`. The rewrite is defined out of line.
struct ConstSizeOpConversion : OpConversionPattern<ConstSizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Replaces `ConstSizeOp` with an index constant that holds the op's value,
/// read as a sign-extended 64-bit integer.
LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  int64_t size = op.getValue().getSExtValue();
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, size);
  return success();
}

namespace {
/// Conversion pattern for `IsBroadcastableOp`. The rewrite is defined out of
/// line.
class IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `IsBroadcastableOp` to an `scf.for` loop that runs from 0 to the
/// maximum operand rank and carries a single `i1` value, initialised to true.
/// In each iteration the carried value is AND-ed, once per operand that has an
/// extent at the current position, with "extent equals 1 or extent equals the
/// broadcasted extent". The loop result replaces the op.
///
/// Fails to match if any operand of the original op has `ShapeType`.
LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  // Initial value of the loop-carried "still broadcastable" flag.
  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // One loop iteration per output dimension, `iv` in [0, maxRank).
  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        // `broadcastable` is reassigned after every operand, so each `scf.if`
        // below consumes the result of the previous one.
        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          // True when `iv` is below this operand's rank difference, i.e. the
          // operand has no extent at this output dimension.
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non existent dimensions are always broadcastable
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        // Carry the accumulated flag into the next iteration.
        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
/// Conversion pattern for `GetExtentOp`. The rewrite is defined out of line.
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `GetExtentOp` either to a `tensor::DimOp` on the value whose shape
/// is being queried, or to a `tensor::ExtractOp` on the extent tensor.
LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Bail out on `SizeType` results: only error-free types are supported by
  // this lowering for now.
  if (op.getType().isa<SizeType>())
    return failure();

  // If the shape is produced by a `ShapeOfOp` applied to a shaped value, the
  // extent can be read off that value directly. This circumvents the
  // necessity to materialize the shape in memory.
  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
  const bool queriesShapedValue =
      shapeOfOp && shapeOfOp.getArg().getType().isa<ShapedType>();
  if (queriesShapedValue) {
    rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                               adaptor.getDim());
    return success();
  }

  // Otherwise extract the requested element from the extent tensor.
  Value extentTensor = adaptor.getShape();
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
      op, rewriter.getIndexType(), extentTensor, ValueRange{adaptor.getDim()});
  return success();
}

namespace {
/// Conversion pattern for `shape::RankOp`. The rewrite is defined out of line.
struct RankOpConverter : OpConversionPattern<shape::RankOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Replaces `shape::RankOp` with a `tensor::DimOp` that queries dimension 0 of
/// the converted shape operand.
LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Bail out on `SizeType` results: this lowering supports only error-free
  // types for now.
  const bool hasSizeResult = op.getType().isa<SizeType>();
  if (hasSizeResult)
    return failure();

  Value extentTensor = adaptor.getShape();
  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, extentTensor, 0);
  return success();
}

namespace {
/// Conversion pattern turning `shape.reduce` into `scf.for`. The rewrite is
/// defined out of line.
class ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Lowers `shape::ReduceOp` to an `scf::ForOp` that iterates over the elements
/// of the extent tensor operand.
///
/// The loop runs from 0 to the size of the extent tensor with step 1 and
/// carries the reduction's initial values as iteration arguments. The body of
/// the reduce op is cloned into the loop with its block arguments mapped to
/// the induction variable, the extracted extent and the carried values, in
/// that order. The operands of the reduce body's terminator become the
/// operands of the loop's `scf.yield`.
///
/// Fails to match when the shape operand is a `ShapeType`.
LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.getShape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  // Take the initial values from the adaptor, like the shape operand, so that
  // every operand of the new loop is one handed to this pattern by the
  // conversion framework.
  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, adaptor.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        // Values standing in for the reduce body's block arguments:
        // (index, extent, carried values...).
        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone everything but the terminator; the mapping records the clone
        // of every result so later operations see the remapped values.
        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        // Yield the remapped terminator operands from the loop body.
        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Conversion pattern turning `shape.shape_eq` into an `scf.for` loop. The
/// lowering is currently restricted to `tensor<?xindex>` operands. Two shapes
/// are equal when their sizes agree and, in that case, every pair of extents
/// agrees as well. With more than two operands, each one is compared against
/// the first and the individual results are combined with `arith.andi`.
///
/// For instance,
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// is rewritten to
///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
class ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `ShapeEqOp`: every operand after the first is compared against the
/// first one. A comparison is an `scf.if` on rank equality whose "then" region
/// holds an `scf.for` that ANDs the element-wise extent comparisons and whose
/// "else" region yields false. The per-operand results are combined with
/// `arith.andi`.
///
/// Fails to match if any operand of the original op has `ShapeType`.
LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  // With at most one operand there is nothing to compare: the result is the
  // constant true.
  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  // Conjunction of the comparisons emitted so far; null until the first one.
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Ranks match: loop over all extents, starting from true.
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              // NOTE: `nestedLoc` is unused; the operations in the loop body
              // are created at the `loc` of the enclosing region builder.
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Ranks differ: yield false without looking at any extent.
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    // The first comparison seeds `result`; later ones are AND-ed onto it.
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
/// Conversion pattern for `ShapeOfOp`. The rewrite is defined out of line.
struct ShapeOfOpConversion : OpConversionPattern<ShapeOfOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `ShapeOfOp`. An argument of ranked tensor type is lowered to
/// `tensor.from_elements` followed by a cast to the op's result type; any
/// other argument is lowered to `tensor.generate`.
LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Bail out on `ShapeType` results: only error-free types are supported by
  // this lowering for now.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();

  if (auto rankedTy = tensor.getType().dyn_cast<RankedTensorType>()) {
    // One value per dimension: a `tensor.dim` for dynamic dimensions and an
    // index constant for static ones.
    int64_t rank = rankedTy.getRank();
    SmallVector<Value, 8> extentValues;
    for (int64_t i = 0; i < rank; i++) {
      Value extent;
      if (rankedTy.isDynamicDim(i))
        extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
      else
        extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTy.getDimSize(i));
      extentValues.push_back(extent);
    }

    // Pack the extents into a statically sized tensor, then cast it to the
    // result type of the op.
    Type staticTy = RankedTensorType::get({rank}, rewriter.getIndexType());
    Value staticExtentTensor =
        rewriter.create<tensor::FromElementsOp>(loc, staticTy, extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Not a ranked tensor: generate an extent tensor with one element per
  // dimension, each computed by a `tensor.dim`.
  auto *ctx = rewriter.getContext();
  Value dynamicRank = rewriter.create<tensor::RankOp>(loc, tensor);
  auto buildExtent = [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
    Value extent = b.create<tensor::DimOp>(nestedLoc, tensor, args.front());
    b.create<tensor::YieldOp>(nestedLoc, extent);
  };
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{dynamicRank}, buildExtent);

  return success();
}

namespace {
/// Conversion pattern for `SplitAtOp`. The rewrite is defined out of line.
struct SplitAtOpConversion : OpConversionPattern<SplitAtOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `SplitAtOp` to two `tensor::ExtractSliceOp`s on the extent tensor:
/// the head covers `[0, index)` and the tail covers `[index, rank)`. A
/// negative split index is first shifted by the rank.
LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented: refuse to lower as soon as the
  // operand or one of the results is a `ShapeType`.
  auto hasShapeType = [](Value v) { return v.getType().isa<ShapeType>(); };
  if (hasShapeType(op.getOperand()) || hasShapeType(op.getHead()) ||
      hasShapeType(op.getTail()))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value extentTensor = adaptor.getOperand();
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(extentTensor, zero);

  // splitPoint = index < 0 ? index + rank : index
  Value rawIndex = adaptor.getIndex();
  Value shiftedIndex = b.create<arith::AddIOp>(rawIndex, rank);
  Value isNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, rawIndex, zero);
  Value splitPoint = b.create<SelectOp>(isNegative, shiftedIndex, rawIndex);

  // Both slices use a stride of one.
  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(extentTensor, zero, splitPoint, one);
  Value tailSize = b.create<arith::SubIOp>(rank, splitPoint);
  Value tail = b.create<tensor::ExtractSliceOp>(extentTensor, splitPoint,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
/// Rewrites `ToExtentTensorOp` into a `tensor::CastOp` from the converted
/// input to the op's result type. Only inputs of ranked tensor type match.
struct ToExtentTensorOpConversion : OpConversionPattern<ToExtentTensorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getInput();
    const bool inputIsRankedTensor = input.getType().isa<RankedTensorType>();
    if (!inputIsRankedTensor)
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), input);
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
/// Pass that applies the Shape-to-Standard conversion patterns to the
/// operation it runs on.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
private:
  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
                         SCFDialect, tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

/// Adds the generated patterns first, followed by the hand-written conversion
/// patterns of this file, to `patterns`.
void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  populateWithGenerated(patterns);

  MLIRContext *ctx = patterns.getContext();
  // clang-format off
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

/// Creates a new instance of the Shape-to-Standard conversion pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertShapeToStandardPass>();
  return pass;
}
