171c10803SAlexander Belyaev //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
23713314bSFrederik Gossen //
33713314bSFrederik Gossen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43713314bSFrederik Gossen // See https://llvm.org/LICENSE.txt for license information.
53713314bSFrederik Gossen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63713314bSFrederik Gossen //
73713314bSFrederik Gossen //===----------------------------------------------------------------------===//
83713314bSFrederik Gossen
93713314bSFrederik Gossen #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
103713314bSFrederik Gossen
113713314bSFrederik Gossen #include "../PassDetail.h"
12a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1336550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
14*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
153713314bSFrederik Gossen #include "mlir/Dialect/Shape/IR/Shape.h"
16444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
17a70f2eb3SFrederik Gossen #include "mlir/IR/BlockAndValueMapping.h"
18f30f347dSTres Popp #include "mlir/IR/ImplicitLocOpBuilder.h"
193713314bSFrederik Gossen #include "mlir/Transforms/DialectConversion.h"
20f30f347dSTres Popp #include "llvm/ADT/STLExtras.h"
213713314bSFrederik Gossen
2224edbdf9SFrederik Gossen using namespace mlir;
2380be54c0SAlexander Belyaev using namespace mlir::shape;
24a70f2eb3SFrederik Gossen using namespace mlir::scf;
2524edbdf9SFrederik Gossen
263713314bSFrederik Gossen /// Conversion patterns.
274baf18dbSFrederik Gossen namespace {
289df6afbbSFrederik Gossen class AnyOpConversion : public OpConversionPattern<AnyOp> {
299df6afbbSFrederik Gossen public:
309df6afbbSFrederik Gossen using OpConversionPattern<AnyOp>::OpConversionPattern;
319df6afbbSFrederik Gossen
329df6afbbSFrederik Gossen LogicalResult
33b54c724bSRiver Riddle matchAndRewrite(AnyOp op, OpAdaptor adaptor,
344baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
354baf18dbSFrederik Gossen };
364baf18dbSFrederik Gossen } // namespace
374baf18dbSFrederik Gossen
384baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(AnyOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const39b54c724bSRiver Riddle AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
404baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
419df6afbbSFrederik Gossen // Replace `any` with its first operand.
429df6afbbSFrederik Gossen // Any operand would be a valid substitution.
43cfb72fd3SJacques Pienaar rewriter.replaceOp(op, {adaptor.getInputs().front()});
449df6afbbSFrederik Gossen return success();
459df6afbbSFrederik Gossen }
469df6afbbSFrederik Gossen
474baf18dbSFrederik Gossen namespace {
4880be54c0SAlexander Belyaev template <typename SrcOpTy, typename DstOpTy>
4980be54c0SAlexander Belyaev class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
5080be54c0SAlexander Belyaev public:
5180be54c0SAlexander Belyaev using OpConversionPattern<SrcOpTy>::OpConversionPattern;
5280be54c0SAlexander Belyaev
5380be54c0SAlexander Belyaev LogicalResult
matchAndRewrite(SrcOpTy op,typename SrcOpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const54b54c724bSRiver Riddle matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
5580be54c0SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
566673c6cdSFrederik Gossen // For now, only error-free types are supported by this lowering.
576673c6cdSFrederik Gossen if (op.getType().template isa<SizeType>())
586673c6cdSFrederik Gossen return failure();
596673c6cdSFrederik Gossen
60cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
61cfb72fd3SJacques Pienaar adaptor.getRhs());
6280be54c0SAlexander Belyaev return success();
6380be54c0SAlexander Belyaev }
6480be54c0SAlexander Belyaev };
654baf18dbSFrederik Gossen } // namespace
6680be54c0SAlexander Belyaev
674baf18dbSFrederik Gossen namespace {
68a70f2eb3SFrederik Gossen struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69a70f2eb3SFrederik Gossen using OpConversionPattern<BroadcastOp>::OpConversionPattern;
705d9f33aaSStephan Herhut
715d9f33aaSStephan Herhut LogicalResult
72b54c724bSRiver Riddle matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
734baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
744baf18dbSFrederik Gossen };
75f30f347dSTres Popp
76f30f347dSTres Popp // Get the resulting extent in a given dimension. This is computed with any
77f30f347dSTres Popp // number of extent tensors and shifted offsets into them.
getBroadcastedDim(ImplicitLocOpBuilder lb,ValueRange extentTensors,ValueRange rankDiffs,Value outputDimension)78f30f347dSTres Popp Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
79f30f347dSTres Popp ValueRange rankDiffs, Value outputDimension) {
80a54f4eaeSMogball Value one = lb.create<arith::ConstantIndexOp>(1);
81f30f347dSTres Popp Value broadcastedDim = one;
82f30f347dSTres Popp for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
83f30f347dSTres Popp Value shape = std::get<0>(tup);
84f30f347dSTres Popp Value rankDiff = std::get<1>(tup);
85a54f4eaeSMogball Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
86a54f4eaeSMogball outputDimension, rankDiff);
87f30f347dSTres Popp Type indexTy = lb.getIndexType();
88f30f347dSTres Popp broadcastedDim =
89f30f347dSTres Popp lb.create<IfOp>(
90f30f347dSTres Popp TypeRange{indexTy}, outOfBounds,
91f30f347dSTres Popp [&](OpBuilder &b, Location loc) {
92f30f347dSTres Popp b.create<scf::YieldOp>(loc, broadcastedDim);
93f30f347dSTres Popp },
94f30f347dSTres Popp [&](OpBuilder &b, Location loc) {
95f30f347dSTres Popp // The broadcasting logic is:
96f30f347dSTres Popp // - if one extent (here we arbitrarily choose the
97f30f347dSTres Popp // extent from the greater-rank operand) is equal to 1,
98f30f347dSTres Popp // then take the extent from the other operand
99f30f347dSTres Popp // - otherwise, take the extent as-is.
100f30f347dSTres Popp // Note that this logic remains correct in the presence
101f30f347dSTres Popp // of dimensions of zero extent.
102a54f4eaeSMogball Value lesserRankOperandDimension = b.create<arith::SubIOp>(
103a54f4eaeSMogball loc, indexTy, outputDimension, rankDiff);
104f30f347dSTres Popp Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
105f30f347dSTres Popp loc, shape, ValueRange{lesserRankOperandDimension});
106f30f347dSTres Popp
107a54f4eaeSMogball Value dimIsOne =
108a54f4eaeSMogball b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
109f30f347dSTres Popp lesserRankOperandExtent, one);
110dec8af70SRiver Riddle Value dim = b.create<arith::SelectOp>(
111dec8af70SRiver Riddle loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
112f30f347dSTres Popp b.create<scf::YieldOp>(loc, dim);
113f30f347dSTres Popp })
114f30f347dSTres Popp .getResult(0);
115f30f347dSTres Popp }
116f30f347dSTres Popp return broadcastedDim;
117f30f347dSTres Popp }
1184baf18dbSFrederik Gossen } // namespace
1194baf18dbSFrederik Gossen
matchAndRewrite(BroadcastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120a70f2eb3SFrederik Gossen LogicalResult BroadcastOpConverter::matchAndRewrite(
121b54c724bSRiver Riddle BroadcastOp op, OpAdaptor adaptor,
1224baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
123a70f2eb3SFrederik Gossen // For now, this lowering is only defined on `tensor<?xindex>` operands, not
124a70f2eb3SFrederik Gossen // on shapes.
1256673c6cdSFrederik Gossen if (op.getType().isa<ShapeType>())
1266673c6cdSFrederik Gossen return failure();
127ac3e5c4dSFrederik Gossen
1286673c6cdSFrederik Gossen auto loc = op.getLoc();
129f30f347dSTres Popp ImplicitLocOpBuilder lb(loc, rewriter);
130ac3e5c4dSFrederik Gossen
131a54f4eaeSMogball Value zero = lb.create<arith::ConstantIndexOp>(0);
132f30f347dSTres Popp Type indexTy = lb.getIndexType();
133a70f2eb3SFrederik Gossen
134f30f347dSTres Popp // Save all the ranks for bounds checking. Because this is a tensor
135f30f347dSTres Popp // representing the shape extents, the rank is the extent of the only
136f30f347dSTres Popp // dimension in the tensor.
137f30f347dSTres Popp SmallVector<Value> ranks, rankDiffs;
138cfb72fd3SJacques Pienaar llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
139c0a6318dSMatthias Springer return lb.create<tensor::DimOp>(v, zero);
140f30f347dSTres Popp }));
141f30f347dSTres Popp
142f30f347dSTres Popp // Find the maximum rank
143f30f347dSTres Popp Value maxRank = ranks.front();
144f30f347dSTres Popp for (Value v : llvm::drop_begin(ranks, 1)) {
145a54f4eaeSMogball Value rankIsGreater =
146a54f4eaeSMogball lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
147dec8af70SRiver Riddle maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
148f30f347dSTres Popp }
149f30f347dSTres Popp
150f30f347dSTres Popp // Calculate the difference of ranks and the maximum rank for later offsets.
151f30f347dSTres Popp llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
152a54f4eaeSMogball return lb.create<arith::SubIOp>(indexTy, maxRank, v);
153f30f347dSTres Popp }));
154f30f347dSTres Popp
155eb56fa97SFrederik Gossen Value replacement = lb.create<tensor::GenerateOp>(
156f30f347dSTres Popp getExtentTensorType(lb.getContext()), ValueRange{maxRank},
15757211fd2SSean Silva [&](OpBuilder &b, Location loc, ValueRange args) {
158cfb72fd3SJacques Pienaar Value broadcastedDim =
159cfb72fd3SJacques Pienaar getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
160cfb72fd3SJacques Pienaar rankDiffs, args[0]);
161f30f347dSTres Popp
162f30f347dSTres Popp b.create<tensor::YieldOp>(loc, broadcastedDim);
163eb56fa97SFrederik Gossen });
164eb56fa97SFrederik Gossen if (replacement.getType() != op.getType())
165eb56fa97SFrederik Gossen replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
166eb56fa97SFrederik Gossen rewriter.replaceOp(op, replacement);
167ac3e5c4dSFrederik Gossen return success();
168ac3e5c4dSFrederik Gossen }
169ac3e5c4dSFrederik Gossen
1704baf18dbSFrederik Gossen namespace {
171dfcc0989SFrederik Gossen class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
172dfcc0989SFrederik Gossen public:
173dfcc0989SFrederik Gossen using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
174dfcc0989SFrederik Gossen
175dfcc0989SFrederik Gossen LogicalResult
176b54c724bSRiver Riddle matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
177dfcc0989SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
178dfcc0989SFrederik Gossen };
179dfcc0989SFrederik Gossen } // namespace
180dfcc0989SFrederik Gossen
matchAndRewrite(ConstShapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const181dfcc0989SFrederik Gossen LogicalResult ConstShapeOpConverter::matchAndRewrite(
182b54c724bSRiver Riddle ConstShapeOp op, OpAdaptor adaptor,
183dfcc0989SFrederik Gossen ConversionPatternRewriter &rewriter) const {
184dfcc0989SFrederik Gossen
185dfcc0989SFrederik Gossen // For now, this lowering supports only extent tensors, not `shape.shape`
186dfcc0989SFrederik Gossen // types.
187dfcc0989SFrederik Gossen if (op.getType().isa<ShapeType>())
188dfcc0989SFrederik Gossen return failure();
189dfcc0989SFrederik Gossen
190dfcc0989SFrederik Gossen auto loc = op.getLoc();
191dfcc0989SFrederik Gossen SmallVector<Value, 4> extentOperands;
192cfb72fd3SJacques Pienaar for (auto extent : op.getShape()) {
193dfcc0989SFrederik Gossen extentOperands.push_back(
194a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
195dfcc0989SFrederik Gossen }
196f77e9f87SAlexander Belyaev Type resultTy =
197f77e9f87SAlexander Belyaev RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
19884a6da67SSean Silva Value tensor =
199f77e9f87SAlexander Belyaev rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
200129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
201dfcc0989SFrederik Gossen return success();
202dfcc0989SFrederik Gossen }
203dfcc0989SFrederik Gossen
204dfcc0989SFrederik Gossen namespace {
205a70f2eb3SFrederik Gossen class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
2065d9f33aaSStephan Herhut public:
207a70f2eb3SFrederik Gossen using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
2085d9f33aaSStephan Herhut
2095d9f33aaSStephan Herhut LogicalResult
210b54c724bSRiver Riddle matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
211a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
212973800dcSDavid Truby };
213973800dcSDavid Truby } // namespace
21415acdd75SFrederik Gossen
matchAndRewrite(ConstSizeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const215a70f2eb3SFrederik Gossen LogicalResult ConstSizeOpConversion::matchAndRewrite(
216b54c724bSRiver Riddle ConstSizeOp op, OpAdaptor adaptor,
217a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
218a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
219cfb72fd3SJacques Pienaar op, op.getValue().getSExtValue());
220a70f2eb3SFrederik Gossen return success();
221a70f2eb3SFrederik Gossen }
222a70f2eb3SFrederik Gossen
2235d9f33aaSStephan Herhut namespace {
224511484f2STres Popp struct IsBroadcastableOpConverter
225511484f2STres Popp : public OpConversionPattern<IsBroadcastableOp> {
226511484f2STres Popp using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
227511484f2STres Popp
228511484f2STres Popp LogicalResult
229b54c724bSRiver Riddle matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
230511484f2STres Popp ConversionPatternRewriter &rewriter) const override;
231511484f2STres Popp };
232511484f2STres Popp } // namespace
233511484f2STres Popp
matchAndRewrite(IsBroadcastableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const234511484f2STres Popp LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
235b54c724bSRiver Riddle IsBroadcastableOp op, OpAdaptor adaptor,
236511484f2STres Popp ConversionPatternRewriter &rewriter) const {
237511484f2STres Popp // For now, this lowering is only defined on `tensor<?xindex>` operands, not
238511484f2STres Popp // on shapes.
239cfb72fd3SJacques Pienaar if (!llvm::all_of(op.getShapes(),
2403842d4b6STres Popp [](Value v) { return !v.getType().isa<ShapeType>(); }))
241511484f2STres Popp return failure();
242511484f2STres Popp
243511484f2STres Popp auto loc = op.getLoc();
2443842d4b6STres Popp ImplicitLocOpBuilder lb(loc, rewriter);
245a54f4eaeSMogball Value zero = lb.create<arith::ConstantIndexOp>(0);
246a54f4eaeSMogball Value one = lb.create<arith::ConstantIndexOp>(1);
2473842d4b6STres Popp Type indexTy = lb.getIndexType();
248511484f2STres Popp
2493842d4b6STres Popp // Save all the ranks for bounds checking. Because this is a tensor
2503842d4b6STres Popp // representing the shape extents, the rank is the extent of the only
2513842d4b6STres Popp // dimension in the tensor.
2523842d4b6STres Popp SmallVector<Value> ranks, rankDiffs;
253cfb72fd3SJacques Pienaar llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
254c0a6318dSMatthias Springer return lb.create<tensor::DimOp>(v, zero);
2553842d4b6STres Popp }));
2563842d4b6STres Popp
2573842d4b6STres Popp // Find the maximum rank
2583842d4b6STres Popp Value maxRank = ranks.front();
2593842d4b6STres Popp for (Value v : llvm::drop_begin(ranks, 1)) {
260a54f4eaeSMogball Value rankIsGreater =
261a54f4eaeSMogball lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
262dec8af70SRiver Riddle maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
2633842d4b6STres Popp }
2643842d4b6STres Popp
2653842d4b6STres Popp // Calculate the difference of ranks and the maximum rank for later offsets.
2663842d4b6STres Popp llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
267a54f4eaeSMogball return lb.create<arith::SubIOp>(indexTy, maxRank, v);
2683842d4b6STres Popp }));
2693842d4b6STres Popp
270511484f2STres Popp Type i1Ty = rewriter.getI1Type();
2713842d4b6STres Popp Value trueVal =
272a54f4eaeSMogball rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
273511484f2STres Popp
2743842d4b6STres Popp auto reduceResult = lb.create<ForOp>(
2753842d4b6STres Popp loc, zero, maxRank, one, ValueRange{trueVal},
276511484f2STres Popp [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
2773842d4b6STres Popp // Find a non-1 dim, if it exists. Note that the first part of this
2783842d4b6STres Popp // could reuse the Broadcast lowering entirely, but we redo the work
2793842d4b6STres Popp // here to make optimizations easier between the two loops.
2803842d4b6STres Popp Value broadcastedDim = getBroadcastedDim(
281cfb72fd3SJacques Pienaar ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
2823842d4b6STres Popp
2833842d4b6STres Popp Value broadcastable = iterArgs[0];
284cfb72fd3SJacques Pienaar for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
2853842d4b6STres Popp Value shape, rankDiff;
2863842d4b6STres Popp std::tie(shape, rankDiff) = tup;
287a54f4eaeSMogball Value outOfBounds = b.create<arith::CmpIOp>(
288a54f4eaeSMogball loc, arith::CmpIPredicate::ult, iv, rankDiff);
2893842d4b6STres Popp broadcastable =
2903842d4b6STres Popp b.create<IfOp>(
2913842d4b6STres Popp loc, TypeRange{i1Ty}, outOfBounds,
2923842d4b6STres Popp [&](OpBuilder &b, Location loc) {
2933842d4b6STres Popp // Non existent dimensions are always broadcastable
2943842d4b6STres Popp b.create<scf::YieldOp>(loc, broadcastable);
2953842d4b6STres Popp },
2963842d4b6STres Popp [&](OpBuilder &b, Location loc) {
2973842d4b6STres Popp // Every value needs to be either 1, or the same non-1
2983842d4b6STres Popp // value to be broadcastable in this dim.
2993842d4b6STres Popp Value operandDimension =
300a54f4eaeSMogball b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
3013842d4b6STres Popp Value dimensionExtent = b.create<tensor::ExtractOp>(
3023842d4b6STres Popp loc, shape, ValueRange{operandDimension});
3033842d4b6STres Popp
304a54f4eaeSMogball Value equalOne = b.create<arith::CmpIOp>(
305a54f4eaeSMogball loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306a54f4eaeSMogball Value equalBroadcasted = b.create<arith::CmpIOp>(
307a54f4eaeSMogball loc, arith::CmpIPredicate::eq, dimensionExtent,
308a54f4eaeSMogball broadcastedDim);
309a54f4eaeSMogball Value result = b.create<arith::AndIOp>(
3103842d4b6STres Popp loc, broadcastable,
311a54f4eaeSMogball b.create<arith::OrIOp>(loc, equalOne,
312a54f4eaeSMogball equalBroadcasted));
3133842d4b6STres Popp b.create<scf::YieldOp>(loc, result);
3143842d4b6STres Popp })
3153842d4b6STres Popp .getResult(0);
3163842d4b6STres Popp }
3173842d4b6STres Popp
3183842d4b6STres Popp b.create<scf::YieldOp>(loc, broadcastable);
319511484f2STres Popp });
320511484f2STres Popp
321c0342a2dSJacques Pienaar rewriter.replaceOp(op, reduceResult.getResults().front());
322511484f2STres Popp return success();
323511484f2STres Popp }
324511484f2STres Popp
325511484f2STres Popp namespace {
3268577a090SFrederik Gossen class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
3278577a090SFrederik Gossen using OpConversionPattern<GetExtentOp>::OpConversionPattern;
3288577a090SFrederik Gossen
3298577a090SFrederik Gossen LogicalResult
330b54c724bSRiver Riddle matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
3314baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
3324baf18dbSFrederik Gossen };
3334baf18dbSFrederik Gossen } // namespace
3344baf18dbSFrederik Gossen
matchAndRewrite(GetExtentOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3354baf18dbSFrederik Gossen LogicalResult GetExtentOpConverter::matchAndRewrite(
336b54c724bSRiver Riddle GetExtentOp op, OpAdaptor adaptor,
3374baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
3386673c6cdSFrederik Gossen // For now, only error-free types are supported by this lowering.
3396673c6cdSFrederik Gossen if (op.getType().isa<SizeType>())
3406673c6cdSFrederik Gossen return failure();
3416673c6cdSFrederik Gossen
3426673c6cdSFrederik Gossen // Derive shape extent directly from shape origin if possible. This
3436673c6cdSFrederik Gossen // circumvents the necessity to materialize the shape in memory.
344cfb72fd3SJacques Pienaar if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
345cfb72fd3SJacques Pienaar if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
346cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
347cfb72fd3SJacques Pienaar adaptor.getDim());
3488577a090SFrederik Gossen return success();
3498577a090SFrederik Gossen }
3506673c6cdSFrederik Gossen }
3518577a090SFrederik Gossen
352cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
353cfb72fd3SJacques Pienaar adaptor.getShape(),
354cfb72fd3SJacques Pienaar ValueRange{adaptor.getDim()});
3558577a090SFrederik Gossen return success();
3568577a090SFrederik Gossen }
3578577a090SFrederik Gossen
3584baf18dbSFrederik Gossen namespace {
35924debf5aSFrederik Gossen class RankOpConverter : public OpConversionPattern<shape::RankOp> {
36024debf5aSFrederik Gossen public:
36124debf5aSFrederik Gossen using OpConversionPattern<shape::RankOp>::OpConversionPattern;
36224debf5aSFrederik Gossen
36324debf5aSFrederik Gossen LogicalResult
364b54c724bSRiver Riddle matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3654baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
3664baf18dbSFrederik Gossen };
3674baf18dbSFrederik Gossen } // namespace
3684baf18dbSFrederik Gossen
3694baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(shape::RankOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const370b54c724bSRiver Riddle RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3714baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
372a97940d4SFrederik Gossen // For now, this lowering supports only error-free types.
373a97940d4SFrederik Gossen if (op.getType().isa<SizeType>())
374a97940d4SFrederik Gossen return failure();
375a97940d4SFrederik Gossen
376cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
37724debf5aSFrederik Gossen return success();
37824debf5aSFrederik Gossen }
37924debf5aSFrederik Gossen
3804baf18dbSFrederik Gossen namespace {
381a70f2eb3SFrederik Gossen /// Converts `shape.reduce` to `scf.for`.
382a70f2eb3SFrederik Gossen struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
383a70f2eb3SFrederik Gossen public:
384a70f2eb3SFrederik Gossen using OpConversionPattern::OpConversionPattern;
385a70f2eb3SFrederik Gossen
386a70f2eb3SFrederik Gossen LogicalResult
387b54c724bSRiver Riddle matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
388a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const final;
389a70f2eb3SFrederik Gossen };
390a70f2eb3SFrederik Gossen } // namespace
391a70f2eb3SFrederik Gossen
392a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(shape::ReduceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const393b54c724bSRiver Riddle ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
394a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
395a70f2eb3SFrederik Gossen // For now, this lowering is only defined on `tensor<?xindex>` operands.
396cfb72fd3SJacques Pienaar if (op.getShape().getType().isa<ShapeType>())
397a70f2eb3SFrederik Gossen return failure();
398a70f2eb3SFrederik Gossen
399a70f2eb3SFrederik Gossen auto loc = op.getLoc();
400a70f2eb3SFrederik Gossen
401a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
402a54f4eaeSMogball Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
403a70f2eb3SFrederik Gossen Type indexTy = rewriter.getIndexType();
404e2310704SJulian Gross Value rank =
405cfb72fd3SJacques Pienaar rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
406a70f2eb3SFrederik Gossen
407a70f2eb3SFrederik Gossen auto loop = rewriter.create<scf::ForOp>(
408cfb72fd3SJacques Pienaar loc, zero, rank, one, op.getInitVals(),
409a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
410cfb72fd3SJacques Pienaar Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
411a70f2eb3SFrederik Gossen
412a70f2eb3SFrederik Gossen SmallVector<Value, 2> mappedValues{iv, extent};
413a70f2eb3SFrederik Gossen mappedValues.append(args.begin(), args.end());
414a70f2eb3SFrederik Gossen
415a70f2eb3SFrederik Gossen BlockAndValueMapping mapping;
416a70f2eb3SFrederik Gossen Block *reduceBody = op.getBody();
417a70f2eb3SFrederik Gossen mapping.map(reduceBody->getArguments(), mappedValues);
418a70f2eb3SFrederik Gossen for (auto &nested : reduceBody->without_terminator())
419a70f2eb3SFrederik Gossen b.clone(nested, mapping);
420a70f2eb3SFrederik Gossen
421a70f2eb3SFrederik Gossen SmallVector<Value, 2> mappedResults;
422a70f2eb3SFrederik Gossen for (auto result : reduceBody->getTerminator()->getOperands())
423a70f2eb3SFrederik Gossen mappedResults.push_back(mapping.lookup(result));
424a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, mappedResults);
425a70f2eb3SFrederik Gossen });
426a70f2eb3SFrederik Gossen
427a70f2eb3SFrederik Gossen rewriter.replaceOp(op, loop.getResults());
428a70f2eb3SFrederik Gossen return success();
429a70f2eb3SFrederik Gossen }
430a70f2eb3SFrederik Gossen
431a70f2eb3SFrederik Gossen namespace {
432a70f2eb3SFrederik Gossen /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
433a70f2eb3SFrederik Gossen /// only defined on `tensor<?xindex>` operands. The test for equality first
434a70f2eb3SFrederik Gossen /// compares their size and, if equal, checks every extent for equality.
435a70f2eb3SFrederik Gossen ///
436a70f2eb3SFrederik Gossen /// Example:
437a70f2eb3SFrederik Gossen ///
438a70f2eb3SFrederik Gossen /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
439a70f2eb3SFrederik Gossen ///
440a70f2eb3SFrederik Gossen /// becomes
441a70f2eb3SFrederik Gossen ///
442cb3aa49eSMogball /// %c0 = arith.constant 0 : index
443a70f2eb3SFrederik Gossen /// %0 = dim %arg0, %c0 : tensor<?xindex>
444a70f2eb3SFrederik Gossen /// %1 = dim %arg1, %c0 : tensor<?xindex>
445a54f4eaeSMogball /// %2 = arith.cmpi "eq", %0, %1 : index
446a70f2eb3SFrederik Gossen /// %result = scf.if %2 -> (i1) {
447a54f4eaeSMogball /// %c1 = arith.constant 1 : index
448a54f4eaeSMogball /// %true = arith.constant true
449a70f2eb3SFrederik Gossen /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
450444822d7SSean Silva /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
451444822d7SSean Silva /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
452a54f4eaeSMogball /// %7 = arith.cmpi "eq", %5, %6 : index
453a54f4eaeSMogball /// %8 = arith.andi %arg3, %7 : i1
454a70f2eb3SFrederik Gossen /// scf.yield %8 : i1
455a70f2eb3SFrederik Gossen /// }
456a70f2eb3SFrederik Gossen /// scf.yield %4 : i1
457a70f2eb3SFrederik Gossen /// } else {
458a54f4eaeSMogball /// %false = arith.constant false
459a70f2eb3SFrederik Gossen /// scf.yield %false : i1
460a70f2eb3SFrederik Gossen /// }
461a70f2eb3SFrederik Gossen ///
462a70f2eb3SFrederik Gossen struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
463a70f2eb3SFrederik Gossen using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
464a70f2eb3SFrederik Gossen
465a70f2eb3SFrederik Gossen LogicalResult
466b54c724bSRiver Riddle matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
467a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
468a70f2eb3SFrederik Gossen };
469a70f2eb3SFrederik Gossen } // namespace
470a70f2eb3SFrederik Gossen
471a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ShapeEqOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const472b54c724bSRiver Riddle ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
473a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
474cfb72fd3SJacques Pienaar if (!llvm::all_of(op.getShapes(),
47524acadefSBenjamin Kramer [](Value v) { return !v.getType().isa<ShapeType>(); }))
476a70f2eb3SFrederik Gossen return failure();
47724acadefSBenjamin Kramer
47824acadefSBenjamin Kramer Type i1Ty = rewriter.getI1Type();
479cfb72fd3SJacques Pienaar if (op.getShapes().size() <= 1) {
480a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
48124acadefSBenjamin Kramer rewriter.getBoolAttr(true));
48224acadefSBenjamin Kramer return success();
483a70f2eb3SFrederik Gossen }
484a70f2eb3SFrederik Gossen
485a70f2eb3SFrederik Gossen auto loc = op.getLoc();
486a70f2eb3SFrederik Gossen Type indexTy = rewriter.getIndexType();
487a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
488cfb72fd3SJacques Pienaar Value firstShape = adaptor.getShapes().front();
489e2310704SJulian Gross Value firstRank =
490c0a6318dSMatthias Springer rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
49124acadefSBenjamin Kramer Value result = nullptr;
49224acadefSBenjamin Kramer // Generate a linear sequence of compares, all with firstShape as lhs.
493cfb72fd3SJacques Pienaar for (Value shape : adaptor.getShapes().drop_front(1)) {
494c0a6318dSMatthias Springer Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
495a54f4eaeSMogball Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
496a54f4eaeSMogball firstRank, rank);
49724acadefSBenjamin Kramer auto same = rewriter.create<IfOp>(
49824acadefSBenjamin Kramer loc, i1Ty, eqRank,
499a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc) {
500a54f4eaeSMogball Value one = b.create<arith::ConstantIndexOp>(loc, 1);
501a54f4eaeSMogball Value init =
502a54f4eaeSMogball b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
503a70f2eb3SFrederik Gossen auto loop = b.create<scf::ForOp>(
50424acadefSBenjamin Kramer loc, zero, firstRank, one, ValueRange{init},
505a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
506a70f2eb3SFrederik Gossen Value conj = args[0];
507a70f2eb3SFrederik Gossen Value lhsExtent =
50824acadefSBenjamin Kramer b.create<tensor::ExtractOp>(loc, firstShape, iv);
50924acadefSBenjamin Kramer Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
510a54f4eaeSMogball Value eqExtent = b.create<arith::CmpIOp>(
511a54f4eaeSMogball loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
512a54f4eaeSMogball Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
513a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
514a70f2eb3SFrederik Gossen });
515a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, loop.getResults());
516a70f2eb3SFrederik Gossen },
517a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc) {
518a54f4eaeSMogball Value result =
519a54f4eaeSMogball b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
520a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, result);
521a70f2eb3SFrederik Gossen });
52224acadefSBenjamin Kramer result = !result ? same.getResult(0)
523a54f4eaeSMogball : rewriter.create<arith::AndIOp>(loc, result,
524a54f4eaeSMogball same.getResult(0));
52524acadefSBenjamin Kramer }
52624acadefSBenjamin Kramer rewriter.replaceOp(op, result);
527a70f2eb3SFrederik Gossen return success();
528a70f2eb3SFrederik Gossen }
529a70f2eb3SFrederik Gossen
530a70f2eb3SFrederik Gossen namespace {
531a70f2eb3SFrederik Gossen class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
532a70f2eb3SFrederik Gossen public:
533a70f2eb3SFrederik Gossen using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
534a70f2eb3SFrederik Gossen
535a70f2eb3SFrederik Gossen LogicalResult
536b54c724bSRiver Riddle matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
537a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
538a70f2eb3SFrederik Gossen };
539a70f2eb3SFrederik Gossen } // namespace
540a70f2eb3SFrederik Gossen
matchAndRewrite(ShapeOfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const541a70f2eb3SFrederik Gossen LogicalResult ShapeOfOpConversion::matchAndRewrite(
542b54c724bSRiver Riddle ShapeOfOp op, OpAdaptor adaptor,
543a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
544a70f2eb3SFrederik Gossen
545a70f2eb3SFrederik Gossen // For now, only error-free types are supported by this lowering.
546a70f2eb3SFrederik Gossen if (op.getType().isa<ShapeType>())
547a70f2eb3SFrederik Gossen return failure();
548a70f2eb3SFrederik Gossen
549be7352c0SSean Silva // For ranked tensor arguments, lower to `tensor.from_elements`.
5505106a8b8SFrederik Gossen auto loc = op.getLoc();
551cfb72fd3SJacques Pienaar Value tensor = adaptor.getArg();
552a70f2eb3SFrederik Gossen Type tensorTy = tensor.getType();
553a70f2eb3SFrederik Gossen if (tensorTy.isa<RankedTensorType>()) {
554a70f2eb3SFrederik Gossen
555a70f2eb3SFrederik Gossen // Build values for individual extents.
556a70f2eb3SFrederik Gossen SmallVector<Value, 8> extentValues;
557a70f2eb3SFrederik Gossen RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
558a70f2eb3SFrederik Gossen int64_t rank = rankedTensorTy.getRank();
559a70f2eb3SFrederik Gossen for (int64_t i = 0; i < rank; i++) {
560a70f2eb3SFrederik Gossen if (rankedTensorTy.isDynamicDim(i)) {
561c0a6318dSMatthias Springer Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
562a70f2eb3SFrederik Gossen extentValues.push_back(extent);
563a70f2eb3SFrederik Gossen } else {
564a54f4eaeSMogball Value extent = rewriter.create<arith::ConstantIndexOp>(
565a54f4eaeSMogball loc, rankedTensorTy.getDimSize(i));
566a70f2eb3SFrederik Gossen extentValues.push_back(extent);
567a70f2eb3SFrederik Gossen }
568a70f2eb3SFrederik Gossen }
569a70f2eb3SFrederik Gossen
570a70f2eb3SFrederik Gossen // Materialize extent tensor.
571be7352c0SSean Silva Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
572f77e9f87SAlexander Belyaev loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
573f77e9f87SAlexander Belyaev extentValues);
574129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
575129d6e55SSean Silva staticExtentTensor);
576a70f2eb3SFrederik Gossen return success();
577a70f2eb3SFrederik Gossen }
578a70f2eb3SFrederik Gossen
579be7352c0SSean Silva // Lower to `tensor.generate` otherwise.
5805106a8b8SFrederik Gossen auto *ctx = rewriter.getContext();
58115f8f3e2SAlexander Belyaev Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
582be7352c0SSean Silva rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
5835106a8b8SFrederik Gossen op, getExtentTensorType(ctx), ValueRange{rank},
5845106a8b8SFrederik Gossen [&](OpBuilder &b, Location loc, ValueRange args) {
5855106a8b8SFrederik Gossen Value dim = args.front();
586c0a6318dSMatthias Springer Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
587be7352c0SSean Silva b.create<tensor::YieldOp>(loc, extent);
588a70f2eb3SFrederik Gossen });
589a70f2eb3SFrederik Gossen
590a70f2eb3SFrederik Gossen return success();
591a70f2eb3SFrederik Gossen }
592a70f2eb3SFrederik Gossen
593a70f2eb3SFrederik Gossen namespace {
59442c195f0SBenjamin Kramer class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
59542c195f0SBenjamin Kramer public:
59642c195f0SBenjamin Kramer using OpConversionPattern<SplitAtOp>::OpConversionPattern;
59742c195f0SBenjamin Kramer
59842c195f0SBenjamin Kramer LogicalResult
599b54c724bSRiver Riddle matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
60042c195f0SBenjamin Kramer ConversionPatternRewriter &rewriter) const override;
60142c195f0SBenjamin Kramer };
60242c195f0SBenjamin Kramer } // namespace
60342c195f0SBenjamin Kramer
matchAndRewrite(SplitAtOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const60442c195f0SBenjamin Kramer LogicalResult SplitAtOpConversion::matchAndRewrite(
605b54c724bSRiver Riddle SplitAtOp op, OpAdaptor adaptor,
60642c195f0SBenjamin Kramer ConversionPatternRewriter &rewriter) const {
60742c195f0SBenjamin Kramer // Error conditions are not implemented, only lower if all operands and
60842c195f0SBenjamin Kramer // results are extent tensors.
609cfb72fd3SJacques Pienaar if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
61042c195f0SBenjamin Kramer [](Value v) { return v.getType().isa<ShapeType>(); }))
61142c195f0SBenjamin Kramer return failure();
61242c195f0SBenjamin Kramer
61342c195f0SBenjamin Kramer ImplicitLocOpBuilder b(op.getLoc(), rewriter);
614a54f4eaeSMogball Value zero = b.create<arith::ConstantIndexOp>(0);
615cfb72fd3SJacques Pienaar Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
61642c195f0SBenjamin Kramer
61742c195f0SBenjamin Kramer // index < 0 ? index + rank : index
618cfb72fd3SJacques Pienaar Value originalIndex = adaptor.getIndex();
619a54f4eaeSMogball Value add = b.create<arith::AddIOp>(originalIndex, rank);
62042c195f0SBenjamin Kramer Value indexIsNegative =
621a54f4eaeSMogball b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
622dec8af70SRiver Riddle Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
62342c195f0SBenjamin Kramer
624a54f4eaeSMogball Value one = b.create<arith::ConstantIndexOp>(1);
625060208b4SMatthias Springer Value head =
626cfb72fd3SJacques Pienaar b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
627a54f4eaeSMogball Value tailSize = b.create<arith::SubIOp>(rank, index);
628cfb72fd3SJacques Pienaar Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
629cfb72fd3SJacques Pienaar tailSize, one);
63042c195f0SBenjamin Kramer rewriter.replaceOp(op, {head, tail});
63142c195f0SBenjamin Kramer return success();
63242c195f0SBenjamin Kramer }
63342c195f0SBenjamin Kramer
63442c195f0SBenjamin Kramer namespace {
635a70f2eb3SFrederik Gossen class ToExtentTensorOpConversion
636a70f2eb3SFrederik Gossen : public OpConversionPattern<ToExtentTensorOp> {
637a70f2eb3SFrederik Gossen public:
638a70f2eb3SFrederik Gossen using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
639a70f2eb3SFrederik Gossen
640a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ToExtentTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const641b54c724bSRiver Riddle matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
642a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override {
643cfb72fd3SJacques Pienaar if (!adaptor.getInput().getType().isa<RankedTensorType>())
644a70f2eb3SFrederik Gossen return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
645a70f2eb3SFrederik Gossen
646129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
647cfb72fd3SJacques Pienaar adaptor.getInput());
648a70f2eb3SFrederik Gossen return success();
649a70f2eb3SFrederik Gossen }
650a70f2eb3SFrederik Gossen };
651a70f2eb3SFrederik Gossen } // namespace
652a70f2eb3SFrederik Gossen
653a70f2eb3SFrederik Gossen namespace {
654d05d4219STres Popp /// Import the Shape Ops to Std Patterns.
655d05d4219STres Popp #include "ShapeToStandard.cpp.inc"
656d05d4219STres Popp } // namespace
657d05d4219STres Popp
658d05d4219STres Popp namespace {
6593713314bSFrederik Gossen /// Conversion pass.
6603713314bSFrederik Gossen class ConvertShapeToStandardPass
6613713314bSFrederik Gossen : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
662eaf49130SFrederik Gossen
6634baf18dbSFrederik Gossen void runOnOperation() override;
6644baf18dbSFrederik Gossen };
6654baf18dbSFrederik Gossen } // namespace
6664baf18dbSFrederik Gossen
runOnOperation()6674baf18dbSFrederik Gossen void ConvertShapeToStandardPass::runOnOperation() {
6683713314bSFrederik Gossen // Setup target legality.
669b6b9d3eaSFrederik Gossen MLIRContext &ctx = getContext();
6703713314bSFrederik Gossen ConversionTarget target(ctx);
6711f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect, SCFDialect,
6721f971e23SRiver Riddle tensor::TensorDialect>();
67358ceae95SRiver Riddle target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
6743713314bSFrederik Gossen
6753713314bSFrederik Gossen // Setup conversion patterns.
676dc4e913bSChris Lattner RewritePatternSet patterns(&ctx);
6773a506b31SChris Lattner populateShapeToStandardConversionPatterns(patterns);
6783713314bSFrederik Gossen
6793713314bSFrederik Gossen // Apply conversion.
6803713314bSFrederik Gossen auto module = getOperation();
6813fffffa8SRiver Riddle if (failed(applyPartialConversion(module, target, std::move(patterns))))
6823713314bSFrederik Gossen signalPassFailure();
6833713314bSFrederik Gossen }
6843713314bSFrederik Gossen
populateShapeToStandardConversionPatterns(RewritePatternSet & patterns)68524edbdf9SFrederik Gossen void mlir::populateShapeToStandardConversionPatterns(
686dc4e913bSChris Lattner RewritePatternSet &patterns) {
6873713314bSFrederik Gossen // clang-format off
6881d909c9aSChris Lattner populateWithGenerated(patterns);
689dc4e913bSChris Lattner patterns.add<
6909df6afbbSFrederik Gossen AnyOpConversion,
691a54f4eaeSMogball BinaryOpConversion<AddOp, arith::AddIOp>,
692a54f4eaeSMogball BinaryOpConversion<MulOp, arith::MulIOp>,
693a70f2eb3SFrederik Gossen BroadcastOpConverter,
694a70f2eb3SFrederik Gossen ConstShapeOpConverter,
6955d9f33aaSStephan Herhut ConstSizeOpConversion,
696511484f2STres Popp IsBroadcastableOpConverter,
6978577a090SFrederik Gossen GetExtentOpConverter,
69824debf5aSFrederik Gossen RankOpConverter,
699a70f2eb3SFrederik Gossen ReduceOpConverter,
700a70f2eb3SFrederik Gossen ShapeEqOpConverter,
7015d9f33aaSStephan Herhut ShapeOfOpConversion,
70242c195f0SBenjamin Kramer SplitAtOpConversion,
7033a506b31SChris Lattner ToExtentTensorOpConversion>(patterns.getContext());
7043713314bSFrederik Gossen // clang-format on
7053713314bSFrederik Gossen }
7063713314bSFrederik Gossen
70724edbdf9SFrederik Gossen std::unique_ptr<OperationPass<ModuleOp>>
createConvertShapeToStandardPass()70824edbdf9SFrederik Gossen mlir::createConvertShapeToStandardPass() {
7093713314bSFrederik Gossen return std::make_unique<ConvertShapeToStandardPass>();
7103713314bSFrederik Gossen }
711