1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/STLExtras.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 using namespace mlir::scf; 25 26 /// Conversion patterns. 27 namespace { 28 class AnyOpConversion : public OpConversionPattern<AnyOp> { 29 public: 30 using OpConversionPattern<AnyOp>::OpConversionPattern; 31 32 LogicalResult 33 matchAndRewrite(AnyOp op, OpAdaptor adaptor, 34 ConversionPatternRewriter &rewriter) const override; 35 }; 36 } // namespace 37 38 LogicalResult 39 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, 40 ConversionPatternRewriter &rewriter) const { 41 // Replace `any` with its first operand. 42 // Any operand would be a valid substitution. 
43 rewriter.replaceOp(op, {adaptor.getInputs().front()}); 44 return success(); 45 } 46 47 namespace { 48 template <typename SrcOpTy, typename DstOpTy> 49 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 50 public: 51 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 52 53 LogicalResult 54 matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 // For now, only error-free types are supported by this lowering. 57 if (op.getType().template isa<SizeType>()) 58 return failure(); 59 60 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(), 61 adaptor.getRhs()); 62 return success(); 63 } 64 }; 65 } // namespace 66 67 namespace { 68 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 69 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 70 71 LogicalResult 72 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, 73 ConversionPatternRewriter &rewriter) const override; 74 }; 75 76 // Get the resulting extent in a given dimension. This is computed with any 77 // number of extent tensors and shifted offsets into them. 
/// Emits IR that computes the broadcasted extent for output dimension
/// `outputDimension` from any number of operand extent tensors.
/// `rankDiffs[i]` holds `maxRank - rank(extentTensors[i])`, i.e. how far
/// operand i's dimensions are shifted relative to the output.
/// NOTE(review): the builder is taken by value; creations go through the
/// copy, which shares the caller's insertion point at call time.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  // Start from extent 1: a missing dimension behaves like extent 1 under
  // broadcasting, and 1 is the identity of the combining rule below.
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    // This operand has no extent in `outputDimension` if the output dimension
    // index is smaller than the operand's rank offset.
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // Dimension out of range for this operand: keep the extent
                // accumulated from previous operands.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    // Unsigned compare is fine here: ranks are non-negative indices.
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  // Materialize the broadcasted shape as a tensor.generate whose body
  // computes each output extent via getBroadcastedDim.
  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  // Cast back if the op's declared result type is more static than the
  // generated `tensor<?xindex>`.
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
/// Lowers `shape.const_shape` to a constant extent tensor (implementation is
/// defined out of line below).
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  // Materialize each static extent as an index constant.
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
  // Cast the statically-shaped from_elements result to the op's result type
  // (typically the dynamic `tensor<?xindex>`).
  Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
/// Lowers `shape.const_size` to an index constant (implementation is defined
/// out of line below).
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // A const_size is just an index constant.
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
/// Lowers `shape.is_broadcastable` on extent tensors to an scf.for reduction
/// over the output dimensions (implementation is defined out of line below).
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Loop over all output dimensions, AND-ing the per-dimension
  // broadcastability into a single i1 iter-arg (starting from true).
  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          // This operand has no extent in dimension `iv` when iv < rankDiff.
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non existent dimensions are always broadcastable
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.results().front());
  return success();
}

namespace {
/// Lowers `shape.get_extent` to tensor.extract (or tensor.dim when the shape
/// comes directly from shape_of); implementation is defined out of line below.
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  // Otherwise read the extent out of the materialized extent tensor.
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
/// Lowers `shape.rank` to tensor.dim on the extent tensor (implementation is
/// defined out of line below).
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  // The rank of a shape is the number of elements in its extent tensor, i.e.
  // the extent of the tensor's single dimension.
  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.getShape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  // Iterate over every extent; the loop carries the reduction's accumulator
  // values (initialized from the op's init values).
  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        // The reduce body's block arguments are (index, extent, accs...);
        // map them to the loop's induction variable, the extracted extent,
        // and the current iter args.
        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone the reduce body into the loop, remapping block arguments.
        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        // Yield the (remapped) operands of the shape.reduce terminator.
        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  // Only applies when every operand is already an extent tensor.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  // Zero or one shapes are trivially all equal.
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Ranks are equal: AND together an element-wise extent comparison.
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              // NOTE(review): `nestedLoc` is unused; the body ops reuse the
              // enclosing `loc`. Harmless, but presumably unintended.
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Different ranks: shapes cannot be equal.
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    // Fold this comparison into the running conjunction (nullptr on the
    // first iteration).
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
/// Lowers `shape.shape_of` to tensor.from_elements (ranked operands) or
/// tensor.generate (unranked); implementation is defined out of line below.
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        // Dynamic extents must be queried at runtime.
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        // Static extents become index constants.
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  // Unranked operand: query the rank at runtime and generate one extent per
  // dimension via tensor.dim.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
/// Lowers `shape.split_at` to a pair of tensor.extract_slice ops
/// (implementation is defined out of line below).
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented, only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // Normalize a negative split index, Python-style:
  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);

  // head = operand[0 .. index), tail = operand[index .. rank), both stride 1.
  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
/// Lowers `shape.to_extent_tensor` to a tensor.cast; only applies once the
/// input has already been converted to a ranked tensor.
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!adaptor.getInput().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect,
                         SCFDialect, tensor::TensorDialect>();
  // `shape.cstr_require` stays legal; it is handled elsewhere.
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion. Partial conversion: ops not covered by the patterns
  // (and not marked illegal) are left in place.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

/// Populates all Shape-to-Standard lowering patterns declared in this file,
/// plus the declaratively generated ones from ShapeToStandard.cpp.inc.
void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

/// Factory for the pass, exposed through the conversion pass registry.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}