//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
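                // For example, combining the extents 1 and 5 yields 5, and
                // an operand whose rank does not reach this output dimension
                // (the out-of-bounds case above) leaves the result unchanged.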
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
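  // For example, a constant shape with extents [1, 2, 3] is materialized as a
  // `tensor.from_elements` of the index constants 1, 2, and 3.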
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
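                     // For example, extents 4 and 4, or 4 and 1, are
                     // compatible, whereas 4 and 3 are not.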
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
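  // For example, the extent of a `shape.shape_of` result can be read with
  // `tensor.dim` on the original tensor instead of building the extent tensor.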
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
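/// The reduce body is cloned into the loop body; the induction variable, the
/// current extent, and the iteration arguments are mapped onto the reduce
/// body's block arguments.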
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.getShape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
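  // Static dimensions become index constants and dynamic dimensions are read
  // with `tensor.dim`, e.g. a `tensor<2x?x4xf32>` argument yields the
  // constants 2 and 4 plus one `tensor.dim` for the dynamic extent.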
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
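  // A negative index counts from the back, e.g. for a rank-4 extent tensor an
  // index of -1 is normalized to 3, so the head holds the first three extents
  // and the tail the last one.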
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!adaptor.getInput().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithmeticDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Set up conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}