171c10803SAlexander Belyaev //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
23713314bSFrederik Gossen //
33713314bSFrederik Gossen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43713314bSFrederik Gossen // See https://llvm.org/LICENSE.txt for license information.
53713314bSFrederik Gossen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63713314bSFrederik Gossen //
73713314bSFrederik Gossen //===----------------------------------------------------------------------===//
83713314bSFrederik Gossen 
93713314bSFrederik Gossen #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
103713314bSFrederik Gossen 
113713314bSFrederik Gossen #include "../PassDetail.h"
12a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1336550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
14*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
153713314bSFrederik Gossen #include "mlir/Dialect/Shape/IR/Shape.h"
16444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
17a70f2eb3SFrederik Gossen #include "mlir/IR/BlockAndValueMapping.h"
18f30f347dSTres Popp #include "mlir/IR/ImplicitLocOpBuilder.h"
193713314bSFrederik Gossen #include "mlir/Transforms/DialectConversion.h"
20f30f347dSTres Popp #include "llvm/ADT/STLExtras.h"
213713314bSFrederik Gossen 
2224edbdf9SFrederik Gossen using namespace mlir;
2380be54c0SAlexander Belyaev using namespace mlir::shape;
24a70f2eb3SFrederik Gossen using namespace mlir::scf;
2524edbdf9SFrederik Gossen 
263713314bSFrederik Gossen /// Conversion patterns.
274baf18dbSFrederik Gossen namespace {
289df6afbbSFrederik Gossen class AnyOpConversion : public OpConversionPattern<AnyOp> {
299df6afbbSFrederik Gossen public:
309df6afbbSFrederik Gossen   using OpConversionPattern<AnyOp>::OpConversionPattern;
319df6afbbSFrederik Gossen 
329df6afbbSFrederik Gossen   LogicalResult
33b54c724bSRiver Riddle   matchAndRewrite(AnyOp op, OpAdaptor adaptor,
344baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
354baf18dbSFrederik Gossen };
364baf18dbSFrederik Gossen } // namespace
374baf18dbSFrederik Gossen 
384baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(AnyOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const39b54c724bSRiver Riddle AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
404baf18dbSFrederik Gossen                                  ConversionPatternRewriter &rewriter) const {
419df6afbbSFrederik Gossen   // Replace `any` with its first operand.
429df6afbbSFrederik Gossen   // Any operand would be a valid substitution.
43cfb72fd3SJacques Pienaar   rewriter.replaceOp(op, {adaptor.getInputs().front()});
449df6afbbSFrederik Gossen   return success();
459df6afbbSFrederik Gossen }
469df6afbbSFrederik Gossen 
474baf18dbSFrederik Gossen namespace {
4880be54c0SAlexander Belyaev template <typename SrcOpTy, typename DstOpTy>
4980be54c0SAlexander Belyaev class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
5080be54c0SAlexander Belyaev public:
5180be54c0SAlexander Belyaev   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
5280be54c0SAlexander Belyaev 
5380be54c0SAlexander Belyaev   LogicalResult
matchAndRewrite(SrcOpTy op,typename SrcOpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const54b54c724bSRiver Riddle   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
5580be54c0SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
566673c6cdSFrederik Gossen     // For now, only error-free types are supported by this lowering.
576673c6cdSFrederik Gossen     if (op.getType().template isa<SizeType>())
586673c6cdSFrederik Gossen       return failure();
596673c6cdSFrederik Gossen 
60cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
61cfb72fd3SJacques Pienaar                                          adaptor.getRhs());
6280be54c0SAlexander Belyaev     return success();
6380be54c0SAlexander Belyaev   }
6480be54c0SAlexander Belyaev };
654baf18dbSFrederik Gossen } // namespace
6680be54c0SAlexander Belyaev 
674baf18dbSFrederik Gossen namespace {
68a70f2eb3SFrederik Gossen struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69a70f2eb3SFrederik Gossen   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
705d9f33aaSStephan Herhut 
715d9f33aaSStephan Herhut   LogicalResult
72b54c724bSRiver Riddle   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
734baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
744baf18dbSFrederik Gossen };
75f30f347dSTres Popp 
76f30f347dSTres Popp // Get the resulting extent in a given dimension. This is computed with any
77f30f347dSTres Popp // number of extent tensors and shifted offsets into them.
getBroadcastedDim(ImplicitLocOpBuilder lb,ValueRange extentTensors,ValueRange rankDiffs,Value outputDimension)78f30f347dSTres Popp Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
79f30f347dSTres Popp                         ValueRange rankDiffs, Value outputDimension) {
80a54f4eaeSMogball   Value one = lb.create<arith::ConstantIndexOp>(1);
81f30f347dSTres Popp   Value broadcastedDim = one;
82f30f347dSTres Popp   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
83f30f347dSTres Popp     Value shape = std::get<0>(tup);
84f30f347dSTres Popp     Value rankDiff = std::get<1>(tup);
85a54f4eaeSMogball     Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
86a54f4eaeSMogball                                                  outputDimension, rankDiff);
87f30f347dSTres Popp     Type indexTy = lb.getIndexType();
88f30f347dSTres Popp     broadcastedDim =
89f30f347dSTres Popp         lb.create<IfOp>(
90f30f347dSTres Popp               TypeRange{indexTy}, outOfBounds,
91f30f347dSTres Popp               [&](OpBuilder &b, Location loc) {
92f30f347dSTres Popp                 b.create<scf::YieldOp>(loc, broadcastedDim);
93f30f347dSTres Popp               },
94f30f347dSTres Popp               [&](OpBuilder &b, Location loc) {
95f30f347dSTres Popp                 // The broadcasting logic is:
96f30f347dSTres Popp                 // - if one extent (here we arbitrarily choose the
97f30f347dSTres Popp                 // extent from the greater-rank operand) is equal to 1,
98f30f347dSTres Popp                 // then take the extent from the other operand
99f30f347dSTres Popp                 // - otherwise, take the extent as-is.
100f30f347dSTres Popp                 // Note that this logic remains correct in the presence
101f30f347dSTres Popp                 // of dimensions of zero extent.
102a54f4eaeSMogball                 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
103a54f4eaeSMogball                     loc, indexTy, outputDimension, rankDiff);
104f30f347dSTres Popp                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
105f30f347dSTres Popp                     loc, shape, ValueRange{lesserRankOperandDimension});
106f30f347dSTres Popp 
107a54f4eaeSMogball                 Value dimIsOne =
108a54f4eaeSMogball                     b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
109f30f347dSTres Popp                                             lesserRankOperandExtent, one);
110dec8af70SRiver Riddle                 Value dim = b.create<arith::SelectOp>(
111dec8af70SRiver Riddle                     loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
112f30f347dSTres Popp                 b.create<scf::YieldOp>(loc, dim);
113f30f347dSTres Popp               })
114f30f347dSTres Popp             .getResult(0);
115f30f347dSTres Popp   }
116f30f347dSTres Popp   return broadcastedDim;
117f30f347dSTres Popp }
1184baf18dbSFrederik Gossen } // namespace
1194baf18dbSFrederik Gossen 
matchAndRewrite(BroadcastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120a70f2eb3SFrederik Gossen LogicalResult BroadcastOpConverter::matchAndRewrite(
121b54c724bSRiver Riddle     BroadcastOp op, OpAdaptor adaptor,
1224baf18dbSFrederik Gossen     ConversionPatternRewriter &rewriter) const {
123a70f2eb3SFrederik Gossen   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
124a70f2eb3SFrederik Gossen   // on shapes.
1256673c6cdSFrederik Gossen   if (op.getType().isa<ShapeType>())
1266673c6cdSFrederik Gossen     return failure();
127ac3e5c4dSFrederik Gossen 
1286673c6cdSFrederik Gossen   auto loc = op.getLoc();
129f30f347dSTres Popp   ImplicitLocOpBuilder lb(loc, rewriter);
130ac3e5c4dSFrederik Gossen 
131a54f4eaeSMogball   Value zero = lb.create<arith::ConstantIndexOp>(0);
132f30f347dSTres Popp   Type indexTy = lb.getIndexType();
133a70f2eb3SFrederik Gossen 
134f30f347dSTres Popp   // Save all the ranks for bounds checking. Because this is a tensor
135f30f347dSTres Popp   // representing the shape extents, the rank is the extent of the only
136f30f347dSTres Popp   // dimension in the tensor.
137f30f347dSTres Popp   SmallVector<Value> ranks, rankDiffs;
138cfb72fd3SJacques Pienaar   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
139c0a6318dSMatthias Springer                        return lb.create<tensor::DimOp>(v, zero);
140f30f347dSTres Popp                      }));
141f30f347dSTres Popp 
142f30f347dSTres Popp   // Find the maximum rank
143f30f347dSTres Popp   Value maxRank = ranks.front();
144f30f347dSTres Popp   for (Value v : llvm::drop_begin(ranks, 1)) {
145a54f4eaeSMogball     Value rankIsGreater =
146a54f4eaeSMogball         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
147dec8af70SRiver Riddle     maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
148f30f347dSTres Popp   }
149f30f347dSTres Popp 
150f30f347dSTres Popp   // Calculate the difference of ranks and the maximum rank for later offsets.
151f30f347dSTres Popp   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
152a54f4eaeSMogball                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
153f30f347dSTres Popp                      }));
154f30f347dSTres Popp 
155eb56fa97SFrederik Gossen   Value replacement = lb.create<tensor::GenerateOp>(
156f30f347dSTres Popp       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
15757211fd2SSean Silva       [&](OpBuilder &b, Location loc, ValueRange args) {
158cfb72fd3SJacques Pienaar         Value broadcastedDim =
159cfb72fd3SJacques Pienaar             getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
160cfb72fd3SJacques Pienaar                               rankDiffs, args[0]);
161f30f347dSTres Popp 
162f30f347dSTres Popp         b.create<tensor::YieldOp>(loc, broadcastedDim);
163eb56fa97SFrederik Gossen       });
164eb56fa97SFrederik Gossen   if (replacement.getType() != op.getType())
165eb56fa97SFrederik Gossen     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
166eb56fa97SFrederik Gossen   rewriter.replaceOp(op, replacement);
167ac3e5c4dSFrederik Gossen   return success();
168ac3e5c4dSFrederik Gossen }
169ac3e5c4dSFrederik Gossen 
1704baf18dbSFrederik Gossen namespace {
171dfcc0989SFrederik Gossen class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
172dfcc0989SFrederik Gossen public:
173dfcc0989SFrederik Gossen   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
174dfcc0989SFrederik Gossen 
175dfcc0989SFrederik Gossen   LogicalResult
176b54c724bSRiver Riddle   matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
177dfcc0989SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
178dfcc0989SFrederik Gossen };
179dfcc0989SFrederik Gossen } // namespace
180dfcc0989SFrederik Gossen 
matchAndRewrite(ConstShapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const181dfcc0989SFrederik Gossen LogicalResult ConstShapeOpConverter::matchAndRewrite(
182b54c724bSRiver Riddle     ConstShapeOp op, OpAdaptor adaptor,
183dfcc0989SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
184dfcc0989SFrederik Gossen 
185dfcc0989SFrederik Gossen   // For now, this lowering supports only extent tensors, not `shape.shape`
186dfcc0989SFrederik Gossen   // types.
187dfcc0989SFrederik Gossen   if (op.getType().isa<ShapeType>())
188dfcc0989SFrederik Gossen     return failure();
189dfcc0989SFrederik Gossen 
190dfcc0989SFrederik Gossen   auto loc = op.getLoc();
191dfcc0989SFrederik Gossen   SmallVector<Value, 4> extentOperands;
192cfb72fd3SJacques Pienaar   for (auto extent : op.getShape()) {
193dfcc0989SFrederik Gossen     extentOperands.push_back(
194a54f4eaeSMogball         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
195dfcc0989SFrederik Gossen   }
196f77e9f87SAlexander Belyaev   Type resultTy =
197f77e9f87SAlexander Belyaev       RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
19884a6da67SSean Silva   Value tensor =
199f77e9f87SAlexander Belyaev       rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
200129d6e55SSean Silva   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
201dfcc0989SFrederik Gossen   return success();
202dfcc0989SFrederik Gossen }
203dfcc0989SFrederik Gossen 
204dfcc0989SFrederik Gossen namespace {
205a70f2eb3SFrederik Gossen class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
2065d9f33aaSStephan Herhut public:
207a70f2eb3SFrederik Gossen   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
2085d9f33aaSStephan Herhut 
2095d9f33aaSStephan Herhut   LogicalResult
210b54c724bSRiver Riddle   matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
211a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
212973800dcSDavid Truby };
213973800dcSDavid Truby } // namespace
21415acdd75SFrederik Gossen 
matchAndRewrite(ConstSizeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const215a70f2eb3SFrederik Gossen LogicalResult ConstSizeOpConversion::matchAndRewrite(
216b54c724bSRiver Riddle     ConstSizeOp op, OpAdaptor adaptor,
217a70f2eb3SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
218a54f4eaeSMogball   rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
219cfb72fd3SJacques Pienaar       op, op.getValue().getSExtValue());
220a70f2eb3SFrederik Gossen   return success();
221a70f2eb3SFrederik Gossen }
222a70f2eb3SFrederik Gossen 
2235d9f33aaSStephan Herhut namespace {
224511484f2STres Popp struct IsBroadcastableOpConverter
225511484f2STres Popp     : public OpConversionPattern<IsBroadcastableOp> {
226511484f2STres Popp   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
227511484f2STres Popp 
228511484f2STres Popp   LogicalResult
229b54c724bSRiver Riddle   matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
230511484f2STres Popp                   ConversionPatternRewriter &rewriter) const override;
231511484f2STres Popp };
232511484f2STres Popp } // namespace
233511484f2STres Popp 
matchAndRewrite(IsBroadcastableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const234511484f2STres Popp LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
235b54c724bSRiver Riddle     IsBroadcastableOp op, OpAdaptor adaptor,
236511484f2STres Popp     ConversionPatternRewriter &rewriter) const {
237511484f2STres Popp   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
238511484f2STres Popp   // on shapes.
239cfb72fd3SJacques Pienaar   if (!llvm::all_of(op.getShapes(),
2403842d4b6STres Popp                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
241511484f2STres Popp     return failure();
242511484f2STres Popp 
243511484f2STres Popp   auto loc = op.getLoc();
2443842d4b6STres Popp   ImplicitLocOpBuilder lb(loc, rewriter);
245a54f4eaeSMogball   Value zero = lb.create<arith::ConstantIndexOp>(0);
246a54f4eaeSMogball   Value one = lb.create<arith::ConstantIndexOp>(1);
2473842d4b6STres Popp   Type indexTy = lb.getIndexType();
248511484f2STres Popp 
2493842d4b6STres Popp   // Save all the ranks for bounds checking. Because this is a tensor
2503842d4b6STres Popp   // representing the shape extents, the rank is the extent of the only
2513842d4b6STres Popp   // dimension in the tensor.
2523842d4b6STres Popp   SmallVector<Value> ranks, rankDiffs;
253cfb72fd3SJacques Pienaar   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
254c0a6318dSMatthias Springer                        return lb.create<tensor::DimOp>(v, zero);
2553842d4b6STres Popp                      }));
2563842d4b6STres Popp 
2573842d4b6STres Popp   // Find the maximum rank
2583842d4b6STres Popp   Value maxRank = ranks.front();
2593842d4b6STres Popp   for (Value v : llvm::drop_begin(ranks, 1)) {
260a54f4eaeSMogball     Value rankIsGreater =
261a54f4eaeSMogball         lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
262dec8af70SRiver Riddle     maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
2633842d4b6STres Popp   }
2643842d4b6STres Popp 
2653842d4b6STres Popp   // Calculate the difference of ranks and the maximum rank for later offsets.
2663842d4b6STres Popp   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
267a54f4eaeSMogball                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
2683842d4b6STres Popp                      }));
2693842d4b6STres Popp 
270511484f2STres Popp   Type i1Ty = rewriter.getI1Type();
2713842d4b6STres Popp   Value trueVal =
272a54f4eaeSMogball       rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
273511484f2STres Popp 
2743842d4b6STres Popp   auto reduceResult = lb.create<ForOp>(
2753842d4b6STres Popp       loc, zero, maxRank, one, ValueRange{trueVal},
276511484f2STres Popp       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
2773842d4b6STres Popp         // Find a non-1 dim, if it exists. Note that the first part of this
2783842d4b6STres Popp         // could reuse the Broadcast lowering entirely, but we redo the work
2793842d4b6STres Popp         // here to make optimizations easier between the two loops.
2803842d4b6STres Popp         Value broadcastedDim = getBroadcastedDim(
281cfb72fd3SJacques Pienaar             ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
2823842d4b6STres Popp 
2833842d4b6STres Popp         Value broadcastable = iterArgs[0];
284cfb72fd3SJacques Pienaar         for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
2853842d4b6STres Popp           Value shape, rankDiff;
2863842d4b6STres Popp           std::tie(shape, rankDiff) = tup;
287a54f4eaeSMogball           Value outOfBounds = b.create<arith::CmpIOp>(
288a54f4eaeSMogball               loc, arith::CmpIPredicate::ult, iv, rankDiff);
2893842d4b6STres Popp           broadcastable =
2903842d4b6STres Popp               b.create<IfOp>(
2913842d4b6STres Popp                    loc, TypeRange{i1Ty}, outOfBounds,
2923842d4b6STres Popp                    [&](OpBuilder &b, Location loc) {
2933842d4b6STres Popp                      // Non existent dimensions are always broadcastable
2943842d4b6STres Popp                      b.create<scf::YieldOp>(loc, broadcastable);
2953842d4b6STres Popp                    },
2963842d4b6STres Popp                    [&](OpBuilder &b, Location loc) {
2973842d4b6STres Popp                      // Every value needs to be either 1, or the same non-1
2983842d4b6STres Popp                      // value to be broadcastable in this dim.
2993842d4b6STres Popp                      Value operandDimension =
300a54f4eaeSMogball                          b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
3013842d4b6STres Popp                      Value dimensionExtent = b.create<tensor::ExtractOp>(
3023842d4b6STres Popp                          loc, shape, ValueRange{operandDimension});
3033842d4b6STres Popp 
304a54f4eaeSMogball                      Value equalOne = b.create<arith::CmpIOp>(
305a54f4eaeSMogball                          loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306a54f4eaeSMogball                      Value equalBroadcasted = b.create<arith::CmpIOp>(
307a54f4eaeSMogball                          loc, arith::CmpIPredicate::eq, dimensionExtent,
308a54f4eaeSMogball                          broadcastedDim);
309a54f4eaeSMogball                      Value result = b.create<arith::AndIOp>(
3103842d4b6STres Popp                          loc, broadcastable,
311a54f4eaeSMogball                          b.create<arith::OrIOp>(loc, equalOne,
312a54f4eaeSMogball                                                 equalBroadcasted));
3133842d4b6STres Popp                      b.create<scf::YieldOp>(loc, result);
3143842d4b6STres Popp                    })
3153842d4b6STres Popp                   .getResult(0);
3163842d4b6STres Popp         }
3173842d4b6STres Popp 
3183842d4b6STres Popp         b.create<scf::YieldOp>(loc, broadcastable);
319511484f2STres Popp       });
320511484f2STres Popp 
321c0342a2dSJacques Pienaar   rewriter.replaceOp(op, reduceResult.getResults().front());
322511484f2STres Popp   return success();
323511484f2STres Popp }
324511484f2STres Popp 
325511484f2STres Popp namespace {
3268577a090SFrederik Gossen class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
3278577a090SFrederik Gossen   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
3288577a090SFrederik Gossen 
3298577a090SFrederik Gossen   LogicalResult
330b54c724bSRiver Riddle   matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
3314baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
3324baf18dbSFrederik Gossen };
3334baf18dbSFrederik Gossen } // namespace
3344baf18dbSFrederik Gossen 
matchAndRewrite(GetExtentOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3354baf18dbSFrederik Gossen LogicalResult GetExtentOpConverter::matchAndRewrite(
336b54c724bSRiver Riddle     GetExtentOp op, OpAdaptor adaptor,
3374baf18dbSFrederik Gossen     ConversionPatternRewriter &rewriter) const {
3386673c6cdSFrederik Gossen   // For now, only error-free types are supported by this lowering.
3396673c6cdSFrederik Gossen   if (op.getType().isa<SizeType>())
3406673c6cdSFrederik Gossen     return failure();
3416673c6cdSFrederik Gossen 
3426673c6cdSFrederik Gossen   // Derive shape extent directly from shape origin if possible. This
3436673c6cdSFrederik Gossen   // circumvents the necessity to materialize the shape in memory.
344cfb72fd3SJacques Pienaar   if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
345cfb72fd3SJacques Pienaar     if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
346cfb72fd3SJacques Pienaar       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
347cfb72fd3SJacques Pienaar                                                  adaptor.getDim());
3488577a090SFrederik Gossen       return success();
3498577a090SFrederik Gossen     }
3506673c6cdSFrederik Gossen   }
3518577a090SFrederik Gossen 
352cfb72fd3SJacques Pienaar   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
353cfb72fd3SJacques Pienaar                                                  adaptor.getShape(),
354cfb72fd3SJacques Pienaar                                                  ValueRange{adaptor.getDim()});
3558577a090SFrederik Gossen   return success();
3568577a090SFrederik Gossen }
3578577a090SFrederik Gossen 
3584baf18dbSFrederik Gossen namespace {
35924debf5aSFrederik Gossen class RankOpConverter : public OpConversionPattern<shape::RankOp> {
36024debf5aSFrederik Gossen public:
36124debf5aSFrederik Gossen   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
36224debf5aSFrederik Gossen 
36324debf5aSFrederik Gossen   LogicalResult
364b54c724bSRiver Riddle   matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3654baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
3664baf18dbSFrederik Gossen };
3674baf18dbSFrederik Gossen } // namespace
3684baf18dbSFrederik Gossen 
3694baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(shape::RankOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const370b54c724bSRiver Riddle RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3714baf18dbSFrederik Gossen                                  ConversionPatternRewriter &rewriter) const {
372a97940d4SFrederik Gossen   // For now, this lowering supports only error-free types.
373a97940d4SFrederik Gossen   if (op.getType().isa<SizeType>())
374a97940d4SFrederik Gossen     return failure();
375a97940d4SFrederik Gossen 
376cfb72fd3SJacques Pienaar   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
37724debf5aSFrederik Gossen   return success();
37824debf5aSFrederik Gossen }
37924debf5aSFrederik Gossen 
3804baf18dbSFrederik Gossen namespace {
381a70f2eb3SFrederik Gossen /// Converts `shape.reduce` to `scf.for`.
382a70f2eb3SFrederik Gossen struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
383a70f2eb3SFrederik Gossen public:
384a70f2eb3SFrederik Gossen   using OpConversionPattern::OpConversionPattern;
385a70f2eb3SFrederik Gossen 
386a70f2eb3SFrederik Gossen   LogicalResult
387b54c724bSRiver Riddle   matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
388a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const final;
389a70f2eb3SFrederik Gossen };
390a70f2eb3SFrederik Gossen } // namespace
391a70f2eb3SFrederik Gossen 
392a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(shape::ReduceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const393b54c724bSRiver Riddle ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
394a70f2eb3SFrederik Gossen                                    ConversionPatternRewriter &rewriter) const {
395a70f2eb3SFrederik Gossen   // For now, this lowering is only defined on `tensor<?xindex>` operands.
396cfb72fd3SJacques Pienaar   if (op.getShape().getType().isa<ShapeType>())
397a70f2eb3SFrederik Gossen     return failure();
398a70f2eb3SFrederik Gossen 
399a70f2eb3SFrederik Gossen   auto loc = op.getLoc();
400a70f2eb3SFrederik Gossen 
401a54f4eaeSMogball   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
402a54f4eaeSMogball   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
403a70f2eb3SFrederik Gossen   Type indexTy = rewriter.getIndexType();
404e2310704SJulian Gross   Value rank =
405cfb72fd3SJacques Pienaar       rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
406a70f2eb3SFrederik Gossen 
407a70f2eb3SFrederik Gossen   auto loop = rewriter.create<scf::ForOp>(
408cfb72fd3SJacques Pienaar       loc, zero, rank, one, op.getInitVals(),
409a70f2eb3SFrederik Gossen       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
410cfb72fd3SJacques Pienaar         Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
411a70f2eb3SFrederik Gossen 
412a70f2eb3SFrederik Gossen         SmallVector<Value, 2> mappedValues{iv, extent};
413a70f2eb3SFrederik Gossen         mappedValues.append(args.begin(), args.end());
414a70f2eb3SFrederik Gossen 
415a70f2eb3SFrederik Gossen         BlockAndValueMapping mapping;
416a70f2eb3SFrederik Gossen         Block *reduceBody = op.getBody();
417a70f2eb3SFrederik Gossen         mapping.map(reduceBody->getArguments(), mappedValues);
418a70f2eb3SFrederik Gossen         for (auto &nested : reduceBody->without_terminator())
419a70f2eb3SFrederik Gossen           b.clone(nested, mapping);
420a70f2eb3SFrederik Gossen 
421a70f2eb3SFrederik Gossen         SmallVector<Value, 2> mappedResults;
422a70f2eb3SFrederik Gossen         for (auto result : reduceBody->getTerminator()->getOperands())
423a70f2eb3SFrederik Gossen           mappedResults.push_back(mapping.lookup(result));
424a70f2eb3SFrederik Gossen         b.create<scf::YieldOp>(loc, mappedResults);
425a70f2eb3SFrederik Gossen       });
426a70f2eb3SFrederik Gossen 
427a70f2eb3SFrederik Gossen   rewriter.replaceOp(op, loop.getResults());
428a70f2eb3SFrederik Gossen   return success();
429a70f2eb3SFrederik Gossen }
430a70f2eb3SFrederik Gossen 
431a70f2eb3SFrederik Gossen namespace {
432a70f2eb3SFrederik Gossen /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
433a70f2eb3SFrederik Gossen /// only defined on `tensor<?xindex>` operands. The test for equality first
434a70f2eb3SFrederik Gossen /// compares their size and, if equal, checks every extent for equality.
435a70f2eb3SFrederik Gossen ///
436a70f2eb3SFrederik Gossen /// Example:
437a70f2eb3SFrederik Gossen ///
438a70f2eb3SFrederik Gossen /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
439a70f2eb3SFrederik Gossen ///
440a70f2eb3SFrederik Gossen /// becomes
441a70f2eb3SFrederik Gossen ///
442cb3aa49eSMogball /// %c0 = arith.constant 0 : index
443a70f2eb3SFrederik Gossen /// %0 = dim %arg0, %c0 : tensor<?xindex>
444a70f2eb3SFrederik Gossen /// %1 = dim %arg1, %c0 : tensor<?xindex>
445a54f4eaeSMogball /// %2 = arith.cmpi "eq", %0, %1 : index
446a70f2eb3SFrederik Gossen /// %result = scf.if %2 -> (i1) {
447a54f4eaeSMogball ///   %c1 = arith.constant 1 : index
448a54f4eaeSMogball ///   %true = arith.constant true
449a70f2eb3SFrederik Gossen ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
450444822d7SSean Silva ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
451444822d7SSean Silva ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
452a54f4eaeSMogball ///     %7 = arith.cmpi "eq", %5, %6 : index
453a54f4eaeSMogball ///     %8 = arith.andi %arg3, %7 : i1
454a70f2eb3SFrederik Gossen ///     scf.yield %8 : i1
455a70f2eb3SFrederik Gossen ///   }
456a70f2eb3SFrederik Gossen ///   scf.yield %4 : i1
457a70f2eb3SFrederik Gossen /// } else {
458a54f4eaeSMogball ///   %false = arith.constant false
459a70f2eb3SFrederik Gossen ///   scf.yield %false : i1
460a70f2eb3SFrederik Gossen /// }
461a70f2eb3SFrederik Gossen ///
462a70f2eb3SFrederik Gossen struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
463a70f2eb3SFrederik Gossen   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
464a70f2eb3SFrederik Gossen 
465a70f2eb3SFrederik Gossen   LogicalResult
466b54c724bSRiver Riddle   matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
467a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
468a70f2eb3SFrederik Gossen };
469a70f2eb3SFrederik Gossen } // namespace
470a70f2eb3SFrederik Gossen 
471a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ShapeEqOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const472b54c724bSRiver Riddle ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
473a70f2eb3SFrederik Gossen                                     ConversionPatternRewriter &rewriter) const {
474cfb72fd3SJacques Pienaar   if (!llvm::all_of(op.getShapes(),
47524acadefSBenjamin Kramer                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
476a70f2eb3SFrederik Gossen     return failure();
47724acadefSBenjamin Kramer 
47824acadefSBenjamin Kramer   Type i1Ty = rewriter.getI1Type();
479cfb72fd3SJacques Pienaar   if (op.getShapes().size() <= 1) {
480a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
48124acadefSBenjamin Kramer                                                    rewriter.getBoolAttr(true));
48224acadefSBenjamin Kramer     return success();
483a70f2eb3SFrederik Gossen   }
484a70f2eb3SFrederik Gossen 
485a70f2eb3SFrederik Gossen   auto loc = op.getLoc();
486a70f2eb3SFrederik Gossen   Type indexTy = rewriter.getIndexType();
487a54f4eaeSMogball   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
488cfb72fd3SJacques Pienaar   Value firstShape = adaptor.getShapes().front();
489e2310704SJulian Gross   Value firstRank =
490c0a6318dSMatthias Springer       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
49124acadefSBenjamin Kramer   Value result = nullptr;
49224acadefSBenjamin Kramer   // Generate a linear sequence of compares, all with firstShape as lhs.
493cfb72fd3SJacques Pienaar   for (Value shape : adaptor.getShapes().drop_front(1)) {
494c0a6318dSMatthias Springer     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
495a54f4eaeSMogball     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
496a54f4eaeSMogball                                                   firstRank, rank);
49724acadefSBenjamin Kramer     auto same = rewriter.create<IfOp>(
49824acadefSBenjamin Kramer         loc, i1Ty, eqRank,
499a70f2eb3SFrederik Gossen         [&](OpBuilder &b, Location loc) {
500a54f4eaeSMogball           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
501a54f4eaeSMogball           Value init =
502a54f4eaeSMogball               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
503a70f2eb3SFrederik Gossen           auto loop = b.create<scf::ForOp>(
50424acadefSBenjamin Kramer               loc, zero, firstRank, one, ValueRange{init},
505a70f2eb3SFrederik Gossen               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
506a70f2eb3SFrederik Gossen                 Value conj = args[0];
507a70f2eb3SFrederik Gossen                 Value lhsExtent =
50824acadefSBenjamin Kramer                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
50924acadefSBenjamin Kramer                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
510a54f4eaeSMogball                 Value eqExtent = b.create<arith::CmpIOp>(
511a54f4eaeSMogball                     loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
512a54f4eaeSMogball                 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
513a70f2eb3SFrederik Gossen                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
514a70f2eb3SFrederik Gossen               });
515a70f2eb3SFrederik Gossen           b.create<scf::YieldOp>(loc, loop.getResults());
516a70f2eb3SFrederik Gossen         },
517a70f2eb3SFrederik Gossen         [&](OpBuilder &b, Location loc) {
518a54f4eaeSMogball           Value result =
519a54f4eaeSMogball               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
520a70f2eb3SFrederik Gossen           b.create<scf::YieldOp>(loc, result);
521a70f2eb3SFrederik Gossen         });
52224acadefSBenjamin Kramer     result = !result ? same.getResult(0)
523a54f4eaeSMogball                      : rewriter.create<arith::AndIOp>(loc, result,
524a54f4eaeSMogball                                                       same.getResult(0));
52524acadefSBenjamin Kramer   }
52624acadefSBenjamin Kramer   rewriter.replaceOp(op, result);
527a70f2eb3SFrederik Gossen   return success();
528a70f2eb3SFrederik Gossen }
529a70f2eb3SFrederik Gossen 
530a70f2eb3SFrederik Gossen namespace {
531a70f2eb3SFrederik Gossen class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
532a70f2eb3SFrederik Gossen public:
533a70f2eb3SFrederik Gossen   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
534a70f2eb3SFrederik Gossen 
535a70f2eb3SFrederik Gossen   LogicalResult
536b54c724bSRiver Riddle   matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
537a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
538a70f2eb3SFrederik Gossen };
539a70f2eb3SFrederik Gossen } // namespace
540a70f2eb3SFrederik Gossen 
matchAndRewrite(ShapeOfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const541a70f2eb3SFrederik Gossen LogicalResult ShapeOfOpConversion::matchAndRewrite(
542b54c724bSRiver Riddle     ShapeOfOp op, OpAdaptor adaptor,
543a70f2eb3SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
544a70f2eb3SFrederik Gossen 
545a70f2eb3SFrederik Gossen   // For now, only error-free types are supported by this lowering.
546a70f2eb3SFrederik Gossen   if (op.getType().isa<ShapeType>())
547a70f2eb3SFrederik Gossen     return failure();
548a70f2eb3SFrederik Gossen 
549be7352c0SSean Silva   // For ranked tensor arguments, lower to `tensor.from_elements`.
5505106a8b8SFrederik Gossen   auto loc = op.getLoc();
551cfb72fd3SJacques Pienaar   Value tensor = adaptor.getArg();
552a70f2eb3SFrederik Gossen   Type tensorTy = tensor.getType();
553a70f2eb3SFrederik Gossen   if (tensorTy.isa<RankedTensorType>()) {
554a70f2eb3SFrederik Gossen 
555a70f2eb3SFrederik Gossen     // Build values for individual extents.
556a70f2eb3SFrederik Gossen     SmallVector<Value, 8> extentValues;
557a70f2eb3SFrederik Gossen     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
558a70f2eb3SFrederik Gossen     int64_t rank = rankedTensorTy.getRank();
559a70f2eb3SFrederik Gossen     for (int64_t i = 0; i < rank; i++) {
560a70f2eb3SFrederik Gossen       if (rankedTensorTy.isDynamicDim(i)) {
561c0a6318dSMatthias Springer         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
562a70f2eb3SFrederik Gossen         extentValues.push_back(extent);
563a70f2eb3SFrederik Gossen       } else {
564a54f4eaeSMogball         Value extent = rewriter.create<arith::ConstantIndexOp>(
565a54f4eaeSMogball             loc, rankedTensorTy.getDimSize(i));
566a70f2eb3SFrederik Gossen         extentValues.push_back(extent);
567a70f2eb3SFrederik Gossen       }
568a70f2eb3SFrederik Gossen     }
569a70f2eb3SFrederik Gossen 
570a70f2eb3SFrederik Gossen     // Materialize extent tensor.
571be7352c0SSean Silva     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
572f77e9f87SAlexander Belyaev         loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
573f77e9f87SAlexander Belyaev         extentValues);
574129d6e55SSean Silva     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
575129d6e55SSean Silva                                                 staticExtentTensor);
576a70f2eb3SFrederik Gossen     return success();
577a70f2eb3SFrederik Gossen   }
578a70f2eb3SFrederik Gossen 
579be7352c0SSean Silva   // Lower to `tensor.generate` otherwise.
5805106a8b8SFrederik Gossen   auto *ctx = rewriter.getContext();
58115f8f3e2SAlexander Belyaev   Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
582be7352c0SSean Silva   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
5835106a8b8SFrederik Gossen       op, getExtentTensorType(ctx), ValueRange{rank},
5845106a8b8SFrederik Gossen       [&](OpBuilder &b, Location loc, ValueRange args) {
5855106a8b8SFrederik Gossen         Value dim = args.front();
586c0a6318dSMatthias Springer         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
587be7352c0SSean Silva         b.create<tensor::YieldOp>(loc, extent);
588a70f2eb3SFrederik Gossen       });
589a70f2eb3SFrederik Gossen 
590a70f2eb3SFrederik Gossen   return success();
591a70f2eb3SFrederik Gossen }
592a70f2eb3SFrederik Gossen 
593a70f2eb3SFrederik Gossen namespace {
59442c195f0SBenjamin Kramer class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
59542c195f0SBenjamin Kramer public:
59642c195f0SBenjamin Kramer   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
59742c195f0SBenjamin Kramer 
59842c195f0SBenjamin Kramer   LogicalResult
599b54c724bSRiver Riddle   matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
60042c195f0SBenjamin Kramer                   ConversionPatternRewriter &rewriter) const override;
60142c195f0SBenjamin Kramer };
60242c195f0SBenjamin Kramer } // namespace
60342c195f0SBenjamin Kramer 
matchAndRewrite(SplitAtOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const60442c195f0SBenjamin Kramer LogicalResult SplitAtOpConversion::matchAndRewrite(
605b54c724bSRiver Riddle     SplitAtOp op, OpAdaptor adaptor,
60642c195f0SBenjamin Kramer     ConversionPatternRewriter &rewriter) const {
60742c195f0SBenjamin Kramer   // Error conditions are not implemented, only lower if all operands and
60842c195f0SBenjamin Kramer   // results are extent tensors.
609cfb72fd3SJacques Pienaar   if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
61042c195f0SBenjamin Kramer                    [](Value v) { return v.getType().isa<ShapeType>(); }))
61142c195f0SBenjamin Kramer     return failure();
61242c195f0SBenjamin Kramer 
61342c195f0SBenjamin Kramer   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
614a54f4eaeSMogball   Value zero = b.create<arith::ConstantIndexOp>(0);
615cfb72fd3SJacques Pienaar   Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
61642c195f0SBenjamin Kramer 
61742c195f0SBenjamin Kramer   // index < 0 ? index + rank : index
618cfb72fd3SJacques Pienaar   Value originalIndex = adaptor.getIndex();
619a54f4eaeSMogball   Value add = b.create<arith::AddIOp>(originalIndex, rank);
62042c195f0SBenjamin Kramer   Value indexIsNegative =
621a54f4eaeSMogball       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
622dec8af70SRiver Riddle   Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
62342c195f0SBenjamin Kramer 
624a54f4eaeSMogball   Value one = b.create<arith::ConstantIndexOp>(1);
625060208b4SMatthias Springer   Value head =
626cfb72fd3SJacques Pienaar       b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
627a54f4eaeSMogball   Value tailSize = b.create<arith::SubIOp>(rank, index);
628cfb72fd3SJacques Pienaar   Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
629cfb72fd3SJacques Pienaar                                                 tailSize, one);
63042c195f0SBenjamin Kramer   rewriter.replaceOp(op, {head, tail});
63142c195f0SBenjamin Kramer   return success();
63242c195f0SBenjamin Kramer }
63342c195f0SBenjamin Kramer 
63442c195f0SBenjamin Kramer namespace {
635a70f2eb3SFrederik Gossen class ToExtentTensorOpConversion
636a70f2eb3SFrederik Gossen     : public OpConversionPattern<ToExtentTensorOp> {
637a70f2eb3SFrederik Gossen public:
638a70f2eb3SFrederik Gossen   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
639a70f2eb3SFrederik Gossen 
640a70f2eb3SFrederik Gossen   LogicalResult
matchAndRewrite(ToExtentTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const641b54c724bSRiver Riddle   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
642a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override {
643cfb72fd3SJacques Pienaar     if (!adaptor.getInput().getType().isa<RankedTensorType>())
644a70f2eb3SFrederik Gossen       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
645a70f2eb3SFrederik Gossen 
646129d6e55SSean Silva     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
647cfb72fd3SJacques Pienaar                                                 adaptor.getInput());
648a70f2eb3SFrederik Gossen     return success();
649a70f2eb3SFrederik Gossen   }
650a70f2eb3SFrederik Gossen };
651a70f2eb3SFrederik Gossen } // namespace
652a70f2eb3SFrederik Gossen 
653a70f2eb3SFrederik Gossen namespace {
654d05d4219STres Popp /// Import the Shape Ops to Std Patterns.
655d05d4219STres Popp #include "ShapeToStandard.cpp.inc"
656d05d4219STres Popp } // namespace
657d05d4219STres Popp 
658d05d4219STres Popp namespace {
6593713314bSFrederik Gossen /// Conversion pass.
6603713314bSFrederik Gossen class ConvertShapeToStandardPass
6613713314bSFrederik Gossen     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
662eaf49130SFrederik Gossen 
6634baf18dbSFrederik Gossen   void runOnOperation() override;
6644baf18dbSFrederik Gossen };
6654baf18dbSFrederik Gossen } // namespace
6664baf18dbSFrederik Gossen 
runOnOperation()6674baf18dbSFrederik Gossen void ConvertShapeToStandardPass::runOnOperation() {
6683713314bSFrederik Gossen   // Setup target legality.
669b6b9d3eaSFrederik Gossen   MLIRContext &ctx = getContext();
6703713314bSFrederik Gossen   ConversionTarget target(ctx);
6711f971e23SRiver Riddle   target.addLegalDialect<arith::ArithmeticDialect, SCFDialect,
6721f971e23SRiver Riddle                          tensor::TensorDialect>();
67358ceae95SRiver Riddle   target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
6743713314bSFrederik Gossen 
6753713314bSFrederik Gossen   // Setup conversion patterns.
676dc4e913bSChris Lattner   RewritePatternSet patterns(&ctx);
6773a506b31SChris Lattner   populateShapeToStandardConversionPatterns(patterns);
6783713314bSFrederik Gossen 
6793713314bSFrederik Gossen   // Apply conversion.
6803713314bSFrederik Gossen   auto module = getOperation();
6813fffffa8SRiver Riddle   if (failed(applyPartialConversion(module, target, std::move(patterns))))
6823713314bSFrederik Gossen     signalPassFailure();
6833713314bSFrederik Gossen }
6843713314bSFrederik Gossen 
populateShapeToStandardConversionPatterns(RewritePatternSet & patterns)68524edbdf9SFrederik Gossen void mlir::populateShapeToStandardConversionPatterns(
686dc4e913bSChris Lattner     RewritePatternSet &patterns) {
6873713314bSFrederik Gossen   // clang-format off
6881d909c9aSChris Lattner   populateWithGenerated(patterns);
689dc4e913bSChris Lattner   patterns.add<
6909df6afbbSFrederik Gossen       AnyOpConversion,
691a54f4eaeSMogball       BinaryOpConversion<AddOp, arith::AddIOp>,
692a54f4eaeSMogball       BinaryOpConversion<MulOp, arith::MulIOp>,
693a70f2eb3SFrederik Gossen       BroadcastOpConverter,
694a70f2eb3SFrederik Gossen       ConstShapeOpConverter,
6955d9f33aaSStephan Herhut       ConstSizeOpConversion,
696511484f2STres Popp       IsBroadcastableOpConverter,
6978577a090SFrederik Gossen       GetExtentOpConverter,
69824debf5aSFrederik Gossen       RankOpConverter,
699a70f2eb3SFrederik Gossen       ReduceOpConverter,
700a70f2eb3SFrederik Gossen       ShapeEqOpConverter,
7015d9f33aaSStephan Herhut       ShapeOfOpConversion,
70242c195f0SBenjamin Kramer       SplitAtOpConversion,
7033a506b31SChris Lattner       ToExtentTensorOpConversion>(patterns.getContext());
7043713314bSFrederik Gossen   // clang-format on
7053713314bSFrederik Gossen }
7063713314bSFrederik Gossen 
70724edbdf9SFrederik Gossen std::unique_ptr<OperationPass<ModuleOp>>
createConvertShapeToStandardPass()70824edbdf9SFrederik Gossen mlir::createConvertShapeToStandardPass() {
7093713314bSFrederik Gossen   return std::make_unique<ConvertShapeToStandardPass>();
7103713314bSFrederik Gossen }
711