//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
/// Lowers `shape.any` by forwarding one of its operands.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
/// Lowers a binary shape op (e.g. `shape.add`, `shape.mul`) to the
/// corresponding `arith` op on index operands.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    // `shape.size` can carry an error value, which has no standard-dialect
    // counterpart.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
// `rankDiffs[i]` is `maxRank - rank(extentTensors[i])`; a shape contributes to
// output dimension `d` only if `d >= rankDiffs[i]` (i.e. the dimension exists
// in that shape after right-alignment).
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  // Fold over all shapes, keeping the first non-1 extent seen so far.
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // This shape has no extent in this dimension; keep the
                // accumulated value.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  // Materialize the broadcasted shape as a `tensor.generate` of size maxRank,
  // computing each output extent with the shared helper.
  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  // Materialize one index constant per static extent, then assemble them into
  // a static extent tensor via `tensor.from_elements`.
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // `shape.const_size` becomes a plain index constant.
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater =
        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  // Loop over all output dimensions, and-accumulating per-dimension
  // broadcastability into a single i1 result.
  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, TypeRange{i1Ty}, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non existent dimensions are always broadcastable
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  // Otherwise, read the extent out of the materialized extent tensor.
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  // The rank of a shape is the extent of its single dimension.
  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.getShape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        // The reduce body's block arguments are (index, extent, accumulators);
        // map them to the loop's values and clone the body into the loop.
        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        // Yield the (mapped) operands of the reduce body's terminator as the
        // next iteration's accumulators.
        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  // Zero or one shapes are trivially equal.
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Equal ranks: compare the extents element-wise, and-accumulating.
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Mismatched ranks: the shapes cannot be equal.
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        // Dynamic extents are queried at runtime via `tensor.dim`.
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        // Static extents become constants.
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented, only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // Negative indices count from the back:
  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  // Split into head = operand[0..index) and tail = operand[index..rank).
  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
/// Lowers `shape.to_extent_tensor` to a `tensor.cast` when the input is
/// already a ranked tensor.
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!adaptor.getInput().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithmeticDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}