1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 using namespace mlir; 22 using namespace mlir::shape; 23 using namespace mlir::scf; 24 25 /// Conversion patterns. 26 namespace { 27 class AnyOpConversion : public OpConversionPattern<AnyOp> { 28 public: 29 using OpConversionPattern<AnyOp>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override; 34 }; 35 } // namespace 36 37 LogicalResult 38 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 39 ConversionPatternRewriter &rewriter) const { 40 AnyOp::Adaptor transformed(operands); 41 42 // Replace `any` with its first operand. 43 // Any operand would be a valid substitution. 44 rewriter.replaceOp(op, {transformed.inputs().front()}); 45 return success(); 46 } 47 48 namespace { 49 template <typename SrcOpTy, typename DstOpTy> 50 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 51 public: 52 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 53 54 LogicalResult 55 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 56 ConversionPatternRewriter &rewriter) const override { 57 typename SrcOpTy::Adaptor transformed(operands); 58 59 // For now, only error-free types are supported by this lowering. 60 if (op.getType().template isa<SizeType>()) 61 return failure(); 62 63 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 64 transformed.rhs()); 65 return success(); 66 } 67 }; 68 } // namespace 69 70 namespace { 71 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 72 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 73 74 LogicalResult 75 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 76 ConversionPatternRewriter &rewriter) const override; 77 }; 78 79 // Get the resulting extent in a given dimension. This is computed with any 80 // number of extent tensors and shifted offsets into them. 81 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, 82 ValueRange rankDiffs, Value outputDimension) { 83 Value one = lb.create<ConstantIndexOp>(1); 84 Value broadcastedDim = one; 85 for (auto tup : llvm::zip(extentTensors, rankDiffs)) { 86 Value shape = std::get<0>(tup); 87 Value rankDiff = std::get<1>(tup); 88 Value outOfBounds = 89 lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff); 90 Type indexTy = lb.getIndexType(); 91 broadcastedDim = 92 lb.create<IfOp>( 93 TypeRange{indexTy}, outOfBounds, 94 [&](OpBuilder &b, Location loc) { 95 b.create<scf::YieldOp>(loc, broadcastedDim); 96 }, 97 [&](OpBuilder &b, Location loc) { 98 // The broadcasting logic is: 99 // - if one extent (here we arbitrarily choose the 100 // extent from the greater-rank operand) is equal to 1, 101 // then take the extent from the other operand 102 // - otherwise, take the extent as-is. 103 // Note that this logic remains correct in the presence 104 // of dimensions of zero extent. 105 Value lesserRankOperandDimension = 106 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 107 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 108 loc, shape, ValueRange{lesserRankOperandDimension}); 109 110 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 111 lesserRankOperandExtent, one); 112 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim, 113 lesserRankOperandExtent); 114 b.create<scf::YieldOp>(loc, dim); 115 }) 116 .getResult(0); 117 } 118 return broadcastedDim; 119 } 120 } // namespace 121 122 LogicalResult BroadcastOpConverter::matchAndRewrite( 123 BroadcastOp op, ArrayRef<Value> operands, 124 ConversionPatternRewriter &rewriter) const { 125 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 126 // on shapes. 127 if (op.getType().isa<ShapeType>()) 128 return failure(); 129 130 auto loc = op.getLoc(); 131 ImplicitLocOpBuilder lb(loc, rewriter); 132 BroadcastOp::Adaptor transformed(operands); 133 134 Value zero = lb.create<ConstantIndexOp>(0); 135 Type indexTy = lb.getIndexType(); 136 137 // Save all the ranks for bounds checking. Because this is a tensor 138 // representing the shape extents, the rank is the extent of the only 139 // dimension in the tensor. 140 SmallVector<Value> ranks, rankDiffs; 141 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 142 return lb.create<DimOp>(v, zero); 143 })); 144 145 // Find the maximum rank 146 Value maxRank = ranks.front(); 147 for (Value v : llvm::drop_begin(ranks, 1)) { 148 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 149 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 150 } 151 152 // Calculate the difference of ranks and the maximum rank for later offsets. 153 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 154 return lb.create<SubIOp>(indexTy, maxRank, v); 155 })); 156 157 rewriter.replaceOp( 158 op, lb.create<tensor::GenerateOp>( 159 getExtentTensorType(lb.getContext()), ValueRange{maxRank}, 160 [&](OpBuilder &b, Location loc, ValueRange args) { 161 Value broadcastedDim = getBroadcastedDim( 162 ImplicitLocOpBuilder(loc, b), transformed.shapes(), 163 rankDiffs, args[0]); 164 165 b.create<tensor::YieldOp>(loc, broadcastedDim); 166 }) 167 ->getResults()); 168 return success(); 169 } 170 171 namespace { 172 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 173 public: 174 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 175 176 LogicalResult 177 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 178 ConversionPatternRewriter &rewriter) const override; 179 }; 180 } // namespace 181 182 LogicalResult ConstShapeOpConverter::matchAndRewrite( 183 ConstShapeOp op, ArrayRef<Value> operands, 184 ConversionPatternRewriter &rewriter) const { 185 186 // For now, this lowering supports only extent tensors, not `shape.shape` 187 // types. 188 if (op.getType().isa<ShapeType>()) 189 return failure(); 190 191 auto loc = op.getLoc(); 192 SmallVector<Value, 4> extentOperands; 193 for (auto extent : op.shape()) { 194 extentOperands.push_back( 195 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 196 } 197 Type indexTy = rewriter.getIndexType(); 198 Value tensor = 199 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 200 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 201 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 202 return success(); 203 } 204 205 namespace { 206 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 207 public: 208 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 209 210 LogicalResult 211 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 212 ConversionPatternRewriter &rewriter) const override; 213 }; 214 } // namespace 215 216 LogicalResult ConstSizeOpConversion::matchAndRewrite( 217 ConstSizeOp op, ArrayRef<Value> operands, 218 ConversionPatternRewriter &rewriter) const { 219 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 220 return success(); 221 } 222 223 namespace { 224 struct IsBroadcastableOpConverter 225 : public OpConversionPattern<IsBroadcastableOp> { 226 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 227 228 LogicalResult 229 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 230 ConversionPatternRewriter &rewriter) const override; 231 }; 232 } // namespace 233 234 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 235 IsBroadcastableOp op, ArrayRef<Value> operands, 236 ConversionPatternRewriter &rewriter) const { 237 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 238 // on shapes. 239 IsBroadcastableOp::Adaptor transformed(operands); 240 if (!llvm::all_of(op.shapes(), 241 [](Value v) { return !v.getType().isa<ShapeType>(); })) 242 return failure(); 243 244 auto loc = op.getLoc(); 245 ImplicitLocOpBuilder lb(loc, rewriter); 246 Value zero = lb.create<ConstantIndexOp>(0); 247 Value one = lb.create<ConstantIndexOp>(1); 248 Type indexTy = lb.getIndexType(); 249 250 // Save all the ranks for bounds checking. Because this is a tensor 251 // representing the shape extents, the rank is the extent of the only 252 // dimension in the tensor. 253 SmallVector<Value> ranks, rankDiffs; 254 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 255 return lb.create<DimOp>(v, zero); 256 })); 257 258 // Find the maximum rank 259 Value maxRank = ranks.front(); 260 for (Value v : llvm::drop_begin(ranks, 1)) { 261 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 262 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 263 } 264 265 // Calculate the difference of ranks and the maximum rank for later offsets. 266 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 267 return lb.create<SubIOp>(indexTy, maxRank, v); 268 })); 269 270 Type i1Ty = rewriter.getI1Type(); 271 Value trueVal = 272 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 273 274 auto reduceResult = lb.create<ForOp>( 275 loc, zero, maxRank, one, ValueRange{trueVal}, 276 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 277 // Find a non-1 dim, if it exists. Note that the first part of this 278 // could reuse the Broadcast lowering entirely, but we redo the work 279 // here to make optimizations easier between the two loops. 280 Value broadcastedDim = getBroadcastedDim( 281 ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv); 282 283 Value broadcastable = iterArgs[0]; 284 for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) { 285 Value shape, rankDiff; 286 std::tie(shape, rankDiff) = tup; 287 Value outOfBounds = 288 b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff); 289 broadcastable = 290 b.create<IfOp>( 291 loc, TypeRange{i1Ty}, outOfBounds, 292 [&](OpBuilder &b, Location loc) { 293 // Non existent dimensions are always broadcastable 294 b.create<scf::YieldOp>(loc, broadcastable); 295 }, 296 [&](OpBuilder &b, Location loc) { 297 // Every value needs to be either 1, or the same non-1 298 // value to be broadcastable in this dim. 299 Value operandDimension = 300 b.create<SubIOp>(loc, indexTy, iv, rankDiff); 301 Value dimensionExtent = b.create<tensor::ExtractOp>( 302 loc, shape, ValueRange{operandDimension}); 303 304 Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 305 dimensionExtent, one); 306 Value equalBroadcasted = 307 b.create<CmpIOp>(loc, CmpIPredicate::eq, 308 dimensionExtent, broadcastedDim); 309 Value result = b.create<AndOp>( 310 loc, broadcastable, 311 b.create<OrOp>(loc, equalOne, equalBroadcasted)); 312 b.create<scf::YieldOp>(loc, result); 313 }) 314 .getResult(0); 315 } 316 317 b.create<scf::YieldOp>(loc, broadcastable); 318 }); 319 320 rewriter.replaceOp(op, reduceResult.results().front()); 321 return success(); 322 } 323 324 namespace { 325 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 326 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 327 328 LogicalResult 329 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 330 ConversionPatternRewriter &rewriter) const override; 331 }; 332 } // namespace 333 334 LogicalResult GetExtentOpConverter::matchAndRewrite( 335 GetExtentOp op, ArrayRef<Value> operands, 336 ConversionPatternRewriter &rewriter) const { 337 GetExtentOp::Adaptor transformed(operands); 338 339 // For now, only error-free types are supported by this lowering. 340 if (op.getType().isa<SizeType>()) 341 return failure(); 342 343 // Derive shape extent directly from shape origin if possible. This 344 // circumvents the necessity to materialize the shape in memory. 345 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 346 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 347 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 348 transformed.dim()); 349 return success(); 350 } 351 } 352 353 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), 354 transformed.shape(), 355 ValueRange{transformed.dim()}); 356 return success(); 357 } 358 359 namespace { 360 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 361 public: 362 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 363 364 LogicalResult 365 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 366 ConversionPatternRewriter &rewriter) const override; 367 }; 368 } // namespace 369 370 LogicalResult 371 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 372 ConversionPatternRewriter &rewriter) const { 373 // For now, this lowering supports only error-free types. 374 if (op.getType().isa<SizeType>()) 375 return failure(); 376 377 shape::RankOp::Adaptor transformed(operands); 378 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 379 return success(); 380 } 381 382 namespace { 383 /// Converts `shape.reduce` to `scf.for`. 384 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 385 public: 386 using OpConversionPattern::OpConversionPattern; 387 388 LogicalResult 389 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 390 ConversionPatternRewriter &rewriter) const final; 391 }; 392 } // namespace 393 394 LogicalResult 395 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 396 ConversionPatternRewriter &rewriter) const { 397 // For now, this lowering is only defined on `tensor<?xindex>` operands. 398 if (op.shape().getType().isa<ShapeType>()) 399 return failure(); 400 401 auto loc = op.getLoc(); 402 shape::ReduceOp::Adaptor transformed(operands); 403 404 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 405 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 406 Type indexTy = rewriter.getIndexType(); 407 Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero); 408 409 auto loop = rewriter.create<scf::ForOp>( 410 loc, zero, rank, one, op.initVals(), 411 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 412 Value extent = 413 b.create<tensor::ExtractOp>(loc, transformed.shape(), iv); 414 415 SmallVector<Value, 2> mappedValues{iv, extent}; 416 mappedValues.append(args.begin(), args.end()); 417 418 BlockAndValueMapping mapping; 419 Block *reduceBody = op.getBody(); 420 mapping.map(reduceBody->getArguments(), mappedValues); 421 for (auto &nested : reduceBody->without_terminator()) 422 b.clone(nested, mapping); 423 424 SmallVector<Value, 2> mappedResults; 425 for (auto result : reduceBody->getTerminator()->getOperands()) 426 mappedResults.push_back(mapping.lookup(result)); 427 b.create<scf::YieldOp>(loc, mappedResults); 428 }); 429 430 rewriter.replaceOp(op, loop.getResults()); 431 return success(); 432 } 433 434 namespace { 435 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 436 /// only defined on `tensor<?xindex>` operands. The test for equality first 437 /// compares their size and, if equal, checks every extent for equality. 438 /// 439 /// Example: 440 /// 441 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 442 /// 443 /// becomes 444 /// 445 /// %c0 = constant 0 : index 446 /// %0 = dim %arg0, %c0 : tensor<?xindex> 447 /// %1 = dim %arg1, %c0 : tensor<?xindex> 448 /// %2 = cmpi "eq", %0, %1 : index 449 /// %result = scf.if %2 -> (i1) { 450 /// %c1 = constant 1 : index 451 /// %true = constant true 452 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 453 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 454 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 455 /// %7 = cmpi "eq", %5, %6 : index 456 /// %8 = and %arg3, %7 : i1 457 /// scf.yield %8 : i1 458 /// } 459 /// scf.yield %4 : i1 460 /// } else { 461 /// %false = constant false 462 /// scf.yield %false : i1 463 /// } 464 /// 465 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 466 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 467 468 LogicalResult 469 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 470 ConversionPatternRewriter &rewriter) const override; 471 }; 472 } // namespace 473 474 LogicalResult 475 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 476 ConversionPatternRewriter &rewriter) const { 477 if (!llvm::all_of(op.shapes(), 478 [](Value v) { return !v.getType().isa<ShapeType>(); })) 479 return failure(); 480 481 Type i1Ty = rewriter.getI1Type(); 482 if (op.shapes().size() <= 1) { 483 rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty, 484 rewriter.getBoolAttr(true)); 485 return success(); 486 } 487 488 ShapeEqOp::Adaptor transformed(operands); 489 auto loc = op.getLoc(); 490 Type indexTy = rewriter.getIndexType(); 491 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 492 Value firstShape = transformed.shapes().front(); 493 Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero); 494 Value result = nullptr; 495 // Generate a linear sequence of compares, all with firstShape as lhs. 496 for (Value shape : transformed.shapes().drop_front(1)) { 497 Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero); 498 Value eqRank = 499 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank); 500 auto same = rewriter.create<IfOp>( 501 loc, i1Ty, eqRank, 502 [&](OpBuilder &b, Location loc) { 503 Value one = b.create<ConstantIndexOp>(loc, 1); 504 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 505 auto loop = b.create<scf::ForOp>( 506 loc, zero, firstRank, one, ValueRange{init}, 507 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 508 Value conj = args[0]; 509 Value lhsExtent = 510 b.create<tensor::ExtractOp>(loc, firstShape, iv); 511 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); 512 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 513 lhsExtent, rhsExtent); 514 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 515 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 516 }); 517 b.create<scf::YieldOp>(loc, loop.getResults()); 518 }, 519 [&](OpBuilder &b, Location loc) { 520 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 521 b.create<scf::YieldOp>(loc, result); 522 }); 523 result = !result ? same.getResult(0) 524 : rewriter.create<AndOp>(loc, result, same.getResult(0)); 525 } 526 rewriter.replaceOp(op, result); 527 return success(); 528 } 529 530 namespace { 531 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 532 public: 533 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 534 535 LogicalResult 536 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 537 ConversionPatternRewriter &rewriter) const override; 538 }; 539 } // namespace 540 541 LogicalResult ShapeOfOpConversion::matchAndRewrite( 542 ShapeOfOp op, ArrayRef<Value> operands, 543 ConversionPatternRewriter &rewriter) const { 544 545 // For now, only error-free types are supported by this lowering. 546 if (op.getType().isa<ShapeType>()) 547 return failure(); 548 549 // For ranked tensor arguments, lower to `tensor.from_elements`. 550 auto loc = op.getLoc(); 551 ShapeOfOp::Adaptor transformed(operands); 552 Value tensor = transformed.arg(); 553 Type tensorTy = tensor.getType(); 554 if (tensorTy.isa<RankedTensorType>()) { 555 556 // Build values for individual extents. 557 SmallVector<Value, 8> extentValues; 558 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 559 int64_t rank = rankedTensorTy.getRank(); 560 for (int64_t i = 0; i < rank; i++) { 561 if (rankedTensorTy.isDynamicDim(i)) { 562 Value extent = rewriter.create<DimOp>(loc, tensor, i); 563 extentValues.push_back(extent); 564 } else { 565 Value extent = 566 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 567 extentValues.push_back(extent); 568 } 569 } 570 571 // Materialize extent tensor. 572 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( 573 loc, rewriter.getIndexType(), extentValues); 574 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 575 staticExtentTensor); 576 return success(); 577 } 578 579 // Lower to `tensor.generate` otherwise. 580 auto *ctx = rewriter.getContext(); 581 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 582 rewriter.replaceOpWithNewOp<tensor::GenerateOp>( 583 op, getExtentTensorType(ctx), ValueRange{rank}, 584 [&](OpBuilder &b, Location loc, ValueRange args) { 585 Value dim = args.front(); 586 Value extent = b.create<DimOp>(loc, tensor, dim); 587 b.create<tensor::YieldOp>(loc, extent); 588 }); 589 590 return success(); 591 } 592 593 namespace { 594 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { 595 public: 596 using OpConversionPattern<SplitAtOp>::OpConversionPattern; 597 598 LogicalResult 599 matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands, 600 ConversionPatternRewriter &rewriter) const override; 601 }; 602 } // namespace 603 604 LogicalResult SplitAtOpConversion::matchAndRewrite( 605 SplitAtOp op, ArrayRef<Value> operands, 606 ConversionPatternRewriter &rewriter) const { 607 // Error conditions are not implemented, only lower if all operands and 608 // results are extent tensors. 609 if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()}, 610 [](Value v) { return v.getType().isa<ShapeType>(); })) 611 return failure(); 612 613 SplitAtOp::Adaptor transformed(op); 614 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 615 Value zero = b.create<ConstantIndexOp>(0); 616 Value rank = b.create<DimOp>(transformed.operand(), zero); 617 618 // index < 0 ? index + rank : index 619 Value originalIndex = transformed.index(); 620 Value add = b.create<AddIOp>(originalIndex, rank); 621 Value indexIsNegative = 622 b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero); 623 Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex); 624 625 Value one = b.create<ConstantIndexOp>(1); 626 Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one); 627 Value tailSize = b.create<SubIOp>(rank, index); 628 Value tail = 629 b.create<SubTensorOp>(transformed.operand(), index, tailSize, one); 630 rewriter.replaceOp(op, {head, tail}); 631 return success(); 632 } 633 634 namespace { 635 class ToExtentTensorOpConversion 636 : public OpConversionPattern<ToExtentTensorOp> { 637 public: 638 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 639 640 LogicalResult 641 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 642 ConversionPatternRewriter &rewriter) const override { 643 ToExtentTensorOpAdaptor adaptor(operands); 644 645 if (!adaptor.input().getType().isa<RankedTensorType>()) 646 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 647 648 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 649 adaptor.input()); 650 return success(); 651 } 652 }; 653 } // namespace 654 655 namespace { 656 /// Import the Shape Ops to Std Patterns. 657 #include "ShapeToStandard.cpp.inc" 658 } // namespace 659 660 namespace { 661 /// Conversion pass. 662 class ConvertShapeToStandardPass 663 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 664 665 void runOnOperation() override; 666 }; 667 } // namespace 668 669 void ConvertShapeToStandardPass::runOnOperation() { 670 // Setup target legality. 671 MLIRContext &ctx = getContext(); 672 ConversionTarget target(ctx); 673 target 674 .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>(); 675 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>(); 676 677 // Setup conversion patterns. 678 OwningRewritePatternList patterns; 679 populateShapeToStandardConversionPatterns(patterns, &ctx); 680 681 // Apply conversion. 682 auto module = getOperation(); 683 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 684 signalPassFailure(); 685 } 686 687 void mlir::populateShapeToStandardConversionPatterns( 688 OwningRewritePatternList &patterns, MLIRContext *ctx) { 689 // clang-format off 690 populateWithGenerated(ctx, patterns); 691 patterns.insert< 692 AnyOpConversion, 693 BinaryOpConversion<AddOp, AddIOp>, 694 BinaryOpConversion<MulOp, MulIOp>, 695 BroadcastOpConverter, 696 ConstShapeOpConverter, 697 ConstSizeOpConversion, 698 IsBroadcastableOpConverter, 699 GetExtentOpConverter, 700 RankOpConverter, 701 ReduceOpConverter, 702 ShapeEqOpConverter, 703 ShapeOfOpConversion, 704 SplitAtOpConversion, 705 ToExtentTensorOpConversion>(ctx); 706 // clang-format on 707 } 708 709 std::unique_ptr<OperationPass<ModuleOp>> 710 mlir::createConvertShapeToStandardPass() { 711 return std::make_unique<ConvertShapeToStandardPass>(); 712 } 713