1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 19 using namespace mlir; 20 using namespace mlir::shape; 21 using namespace mlir::scf; 22 23 /// Conversion patterns. 24 namespace { 25 class AnyOpConversion : public OpConversionPattern<AnyOp> { 26 public: 27 using OpConversionPattern<AnyOp>::OpConversionPattern; 28 29 LogicalResult 30 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 31 ConversionPatternRewriter &rewriter) const override; 32 }; 33 } // namespace 34 35 LogicalResult 36 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 37 ConversionPatternRewriter &rewriter) const { 38 AnyOp::Adaptor transformed(operands); 39 40 // Replace `any` with its first operand. 41 // Any operand would be a valid substitution. 42 rewriter.replaceOp(op, {transformed.inputs().front()}); 43 return success(); 44 } 45 46 namespace { 47 template <typename SrcOpTy, typename DstOpTy> 48 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 49 public: 50 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 51 52 LogicalResult 53 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 54 ConversionPatternRewriter &rewriter) const override { 55 typename SrcOpTy::Adaptor transformed(operands); 56 57 // For now, only error-free types are supported by this lowering. 58 if (op.getType().template isa<SizeType>()) 59 return failure(); 60 61 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 62 transformed.rhs()); 63 return success(); 64 } 65 }; 66 } // namespace 67 68 namespace { 69 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 70 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 71 72 LogicalResult 73 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 74 ConversionPatternRewriter &rewriter) const override; 75 }; 76 } // namespace 77 78 LogicalResult BroadcastOpConverter::matchAndRewrite( 79 BroadcastOp op, ArrayRef<Value> operands, 80 ConversionPatternRewriter &rewriter) const { 81 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 82 // on shapes. 83 if (op.getType().isa<ShapeType>()) 84 return failure(); 85 86 assert(!op.lhs().getType().isa<ShapeType>() && 87 !op.rhs().getType().isa<ShapeType>()); 88 auto loc = op.getLoc(); 89 BroadcastOp::Adaptor transformed(operands); 90 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 91 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 92 93 // Find smaller and greater rank and extent tensor. 94 Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero); 95 Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero); 96 Value lhsRankULE = 97 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 98 Type indexTy = rewriter.getIndexType(); 99 Value lesserRank = 100 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); 101 Value greaterRank = 102 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); 103 auto erasedRankType = 104 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 105 Value rankErasedLhs = 106 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); 107 Value rankErasedRhs = 108 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs()); 109 Value lesserRankOperand = 110 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); 111 Value greaterRankOperand = 112 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); 113 114 Value rankDiff = 115 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 116 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( 117 op, getExtentTensorType(op.getContext()), ValueRange{greaterRank}, 118 [&](OpBuilder &b, Location loc, ValueRange args) { 119 Value outputDimension = args[0]; 120 Value isUnchallengedDimension = b.create<CmpIOp>( 121 loc, CmpIPredicate::ult, outputDimension, rankDiff); 122 Value greaterRankOperandExtent = b.create<tensor::ExtractOp>( 123 loc, greaterRankOperand, outputDimension); 124 // The initial dimensions of the greater-rank operand are unchallenged, 125 // so we can take them as-is. Otherwise, we need to do a comparison. 126 // We need an actual branch here (instead of a select) because the 127 // lesser-rank operand might be rank 0, so any tensor.extract would be 128 // invalid. 129 auto ifOp = b.create<IfOp>( 130 loc, TypeRange{indexTy}, isUnchallengedDimension, 131 [&](OpBuilder &b, Location loc) { 132 b.create<scf::YieldOp>(loc, greaterRankOperandExtent); 133 }, 134 [&](OpBuilder &b, Location loc) { 135 // The broadcasting logic is: 136 // - if one extent (here we arbitrarily choose the extent from 137 // the greater-rank operand) is equal to 1, then take the extent 138 // from the other operand 139 // - otherwise, take the extent as-is. 140 // Note that this logic remains correct in the presence of 141 // dimensions of zero extent. 142 Value lesserRankOperandDimension = 143 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 144 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 145 loc, lesserRankOperand, 146 ValueRange{lesserRankOperandDimension}); 147 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 148 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 149 Value broadcastedExtent = b.create<SelectOp>( 150 loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent, 151 greaterRankOperandExtent); 152 b.create<scf::YieldOp>(loc, broadcastedExtent); 153 }); 154 b.create<mlir::YieldOp>(loc, ifOp.getResult(0)); 155 }); 156 return success(); 157 } 158 159 namespace { 160 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 161 public: 162 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 163 164 LogicalResult 165 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 166 ConversionPatternRewriter &rewriter) const override; 167 }; 168 } // namespace 169 170 LogicalResult ConstShapeOpConverter::matchAndRewrite( 171 ConstShapeOp op, ArrayRef<Value> operands, 172 ConversionPatternRewriter &rewriter) const { 173 174 // For now, this lowering supports only extent tensors, not `shape.shape` 175 // types. 176 if (op.getType().isa<ShapeType>()) 177 return failure(); 178 179 auto loc = op.getLoc(); 180 SmallVector<Value, 4> extentOperands; 181 for (auto extent : op.shape()) { 182 extentOperands.push_back( 183 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 184 } 185 Type indexTy = rewriter.getIndexType(); 186 Value tensor = 187 rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands); 188 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 189 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 190 return success(); 191 } 192 193 namespace { 194 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 195 public: 196 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 197 198 LogicalResult 199 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 200 ConversionPatternRewriter &rewriter) const override; 201 }; 202 } // namespace 203 204 LogicalResult ConstSizeOpConversion::matchAndRewrite( 205 ConstSizeOp op, ArrayRef<Value> operands, 206 ConversionPatternRewriter &rewriter) const { 207 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 208 return success(); 209 } 210 211 namespace { 212 struct IsBroadcastableOpConverter 213 : public OpConversionPattern<IsBroadcastableOp> { 214 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 215 216 LogicalResult 217 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 218 ConversionPatternRewriter &rewriter) const override; 219 }; 220 } // namespace 221 222 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 223 IsBroadcastableOp op, ArrayRef<Value> operands, 224 ConversionPatternRewriter &rewriter) const { 225 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 226 // on shapes. 227 IsBroadcastableOp::Adaptor transformed(operands); 228 if (transformed.lhs().getType().isa<ShapeType>() || 229 transformed.rhs().getType().isa<ShapeType>()) 230 return failure(); 231 232 auto loc = op.getLoc(); 233 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 234 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 235 236 // Find smaller and greater rank and extent tensor. 237 Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero); 238 Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero); 239 Value lhsRankULE = 240 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 241 Type indexTy = rewriter.getIndexType(); 242 Value lesserRank = 243 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); 244 Value greaterRank = 245 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); 246 auto erasedRankType = 247 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 248 Value rankErasedLhs = 249 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); 250 Value rankErasedRhs = 251 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs()); 252 Value lesserRankOperand = 253 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); 254 Value greaterRankOperand = 255 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); 256 Value rankDiff = 257 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 258 Type i1Ty = rewriter.getI1Type(); 259 Value init = 260 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 261 262 // Determine if all overlapping extents are broadcastable. 263 auto reduceResult = rewriter.create<ForOp>( 264 loc, rankDiff, greaterRank, one, ValueRange{init}, 265 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 266 Value greaterRankOperandExtent = b.create<tensor::ExtractOp>( 267 loc, greaterRankOperand, ValueRange{iv}); 268 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 269 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 270 Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); 271 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 272 loc, lesserRankOperand, ValueRange{ivShifted}); 273 Value lesserRankOperandExtentIsOne = b.create<CmpIOp>( 274 loc, CmpIPredicate::eq, lesserRankOperandExtent, one); 275 Value extentsAreEqual = 276 b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent, 277 lesserRankOperandExtent); 278 Value broadcastableExtents = b.create<AndOp>( 279 loc, iterArgs[0], 280 b.create<OrOp>(loc, 281 b.create<OrOp>(loc, greaterRankOperandExtentIsOne, 282 lesserRankOperandExtentIsOne), 283 extentsAreEqual)); 284 b.create<scf::YieldOp>(loc, broadcastableExtents); 285 }); 286 287 rewriter.replaceOp(op, reduceResult.results().front()); 288 return success(); 289 } 290 291 namespace { 292 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 293 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 294 295 LogicalResult 296 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 297 ConversionPatternRewriter &rewriter) const override; 298 }; 299 } // namespace 300 301 LogicalResult GetExtentOpConverter::matchAndRewrite( 302 GetExtentOp op, ArrayRef<Value> operands, 303 ConversionPatternRewriter &rewriter) const { 304 GetExtentOp::Adaptor transformed(operands); 305 306 // For now, only error-free types are supported by this lowering. 307 if (op.getType().isa<SizeType>()) 308 return failure(); 309 310 // Derive shape extent directly from shape origin if possible. This 311 // circumvents the necessity to materialize the shape in memory. 312 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 313 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 314 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 315 transformed.dim()); 316 return success(); 317 } 318 } 319 320 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), 321 transformed.shape(), 322 ValueRange{transformed.dim()}); 323 return success(); 324 } 325 326 namespace { 327 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 328 public: 329 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 330 331 LogicalResult 332 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 333 ConversionPatternRewriter &rewriter) const override; 334 }; 335 } // namespace 336 337 LogicalResult 338 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 339 ConversionPatternRewriter &rewriter) const { 340 // For now, this lowering supports only error-free types. 341 if (op.getType().isa<SizeType>()) 342 return failure(); 343 344 shape::RankOp::Adaptor transformed(operands); 345 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 346 return success(); 347 } 348 349 namespace { 350 /// Converts `shape.reduce` to `scf.for`. 351 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 352 public: 353 using OpConversionPattern::OpConversionPattern; 354 355 LogicalResult 356 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 357 ConversionPatternRewriter &rewriter) const final; 358 }; 359 } // namespace 360 361 LogicalResult 362 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 363 ConversionPatternRewriter &rewriter) const { 364 // For now, this lowering is only defined on `tensor<?xindex>` operands. 365 if (op.shape().getType().isa<ShapeType>()) 366 return failure(); 367 368 auto loc = op.getLoc(); 369 shape::ReduceOp::Adaptor transformed(operands); 370 371 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 372 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 373 Type indexTy = rewriter.getIndexType(); 374 Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero); 375 376 auto loop = rewriter.create<scf::ForOp>( 377 loc, zero, rank, one, op.initVals(), 378 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 379 Value extent = 380 b.create<tensor::ExtractOp>(loc, transformed.shape(), iv); 381 382 SmallVector<Value, 2> mappedValues{iv, extent}; 383 mappedValues.append(args.begin(), args.end()); 384 385 BlockAndValueMapping mapping; 386 Block *reduceBody = op.getBody(); 387 mapping.map(reduceBody->getArguments(), mappedValues); 388 for (auto &nested : reduceBody->without_terminator()) 389 b.clone(nested, mapping); 390 391 SmallVector<Value, 2> mappedResults; 392 for (auto result : reduceBody->getTerminator()->getOperands()) 393 mappedResults.push_back(mapping.lookup(result)); 394 b.create<scf::YieldOp>(loc, mappedResults); 395 }); 396 397 rewriter.replaceOp(op, loop.getResults()); 398 return success(); 399 } 400 401 namespace { 402 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 403 /// only defined on `tensor<?xindex>` operands. The test for equality first 404 /// compares their size and, if equal, checks every extent for equality. 405 /// 406 /// Example: 407 /// 408 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 409 /// 410 /// becomes 411 /// 412 /// %c0 = constant 0 : index 413 /// %0 = dim %arg0, %c0 : tensor<?xindex> 414 /// %1 = dim %arg1, %c0 : tensor<?xindex> 415 /// %2 = cmpi "eq", %0, %1 : index 416 /// %result = scf.if %2 -> (i1) { 417 /// %c1 = constant 1 : index 418 /// %true = constant true 419 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 420 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 421 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 422 /// %7 = cmpi "eq", %5, %6 : index 423 /// %8 = and %arg3, %7 : i1 424 /// scf.yield %8 : i1 425 /// } 426 /// scf.yield %4 : i1 427 /// } else { 428 /// %false = constant false 429 /// scf.yield %false : i1 430 /// } 431 /// 432 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 433 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 434 435 LogicalResult 436 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 437 ConversionPatternRewriter &rewriter) const override; 438 }; 439 } // namespace 440 441 LogicalResult 442 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 443 ConversionPatternRewriter &rewriter) const { 444 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 445 // on shapes. 446 if (op.lhs().getType().isa<ShapeType>() || 447 op.rhs().getType().isa<ShapeType>()) { 448 return failure(); 449 } 450 451 ShapeEqOp::Adaptor transformed(operands); 452 auto loc = op.getLoc(); 453 Type indexTy = rewriter.getIndexType(); 454 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 455 Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero); 456 Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero); 457 Value eqRank = 458 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank); 459 Type i1Ty = rewriter.getI1Type(); 460 rewriter.replaceOpWithNewOp<IfOp>( 461 op, i1Ty, eqRank, 462 [&](OpBuilder &b, Location loc) { 463 Value one = b.create<ConstantIndexOp>(loc, 1); 464 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 465 auto loop = b.create<scf::ForOp>( 466 loc, zero, lhsRank, one, ValueRange{init}, 467 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 468 Value conj = args[0]; 469 Value lhsExtent = 470 b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv); 471 Value rhsExtent = 472 b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv); 473 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 474 lhsExtent, rhsExtent); 475 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 476 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 477 }); 478 b.create<scf::YieldOp>(loc, loop.getResults()); 479 }, 480 [&](OpBuilder &b, Location loc) { 481 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 482 b.create<scf::YieldOp>(loc, result); 483 }); 484 return success(); 485 } 486 487 namespace { 488 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 489 public: 490 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 491 492 LogicalResult 493 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 494 ConversionPatternRewriter &rewriter) const override; 495 }; 496 } // namespace 497 498 LogicalResult ShapeOfOpConversion::matchAndRewrite( 499 ShapeOfOp op, ArrayRef<Value> operands, 500 ConversionPatternRewriter &rewriter) const { 501 502 // For now, only error-free types are supported by this lowering. 503 if (op.getType().isa<ShapeType>()) 504 return failure(); 505 506 // For ranked tensor arguments, lower to `tensor_from_elements`. 507 auto loc = op.getLoc(); 508 ShapeOfOp::Adaptor transformed(operands); 509 Value tensor = transformed.arg(); 510 Type tensorTy = tensor.getType(); 511 if (tensorTy.isa<RankedTensorType>()) { 512 513 // Build values for individual extents. 514 SmallVector<Value, 8> extentValues; 515 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 516 int64_t rank = rankedTensorTy.getRank(); 517 for (int64_t i = 0; i < rank; i++) { 518 if (rankedTensorTy.isDynamicDim(i)) { 519 Value extent = rewriter.create<DimOp>(loc, tensor, i); 520 extentValues.push_back(extent); 521 } else { 522 Value extent = 523 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 524 extentValues.push_back(extent); 525 } 526 } 527 528 // Materialize extent tensor. 529 Value staticExtentTensor = rewriter.create<TensorFromElementsOp>( 530 loc, rewriter.getIndexType(), extentValues); 531 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 532 op.getType()); 533 return success(); 534 } 535 536 // Lower to `dynamic_tensor_from_elements` otherwise. 537 auto *ctx = rewriter.getContext(); 538 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 539 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( 540 op, getExtentTensorType(ctx), ValueRange{rank}, 541 [&](OpBuilder &b, Location loc, ValueRange args) { 542 Value dim = args.front(); 543 Value extent = b.create<DimOp>(loc, tensor, dim); 544 b.create<mlir::YieldOp>(loc, extent); 545 }); 546 547 return success(); 548 } 549 550 namespace { 551 class ToExtentTensorOpConversion 552 : public OpConversionPattern<ToExtentTensorOp> { 553 public: 554 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 555 556 LogicalResult 557 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 558 ConversionPatternRewriter &rewriter) const override { 559 ToExtentTensorOpAdaptor adaptor(operands); 560 561 if (!adaptor.input().getType().isa<RankedTensorType>()) 562 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 563 564 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(), 565 op.getType()); 566 return success(); 567 } 568 }; 569 } // namespace 570 571 namespace { 572 /// Import the Shape Ops to Std Patterns. 573 #include "ShapeToStandard.cpp.inc" 574 } // namespace 575 576 namespace { 577 /// Conversion pass. 578 class ConvertShapeToStandardPass 579 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 580 581 void runOnOperation() override; 582 }; 583 } // namespace 584 585 void ConvertShapeToStandardPass::runOnOperation() { 586 // Setup target legality. 587 MLIRContext &ctx = getContext(); 588 ConversionTarget target(ctx); 589 target 590 .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>(); 591 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>(); 592 593 // Setup conversion patterns. 594 OwningRewritePatternList patterns; 595 populateShapeToStandardConversionPatterns(patterns, &ctx); 596 597 // Apply conversion. 598 auto module = getOperation(); 599 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 600 signalPassFailure(); 601 } 602 603 void mlir::populateShapeToStandardConversionPatterns( 604 OwningRewritePatternList &patterns, MLIRContext *ctx) { 605 // clang-format off 606 populateWithGenerated(ctx, patterns); 607 patterns.insert< 608 AnyOpConversion, 609 BinaryOpConversion<AddOp, AddIOp>, 610 BinaryOpConversion<MulOp, MulIOp>, 611 BroadcastOpConverter, 612 ConstShapeOpConverter, 613 ConstSizeOpConversion, 614 IsBroadcastableOpConverter, 615 GetExtentOpConverter, 616 RankOpConverter, 617 ReduceOpConverter, 618 ShapeEqOpConverter, 619 ShapeOfOpConversion, 620 ToExtentTensorOpConversion>(ctx); 621 // clang-format on 622 } 623 624 std::unique_ptr<OperationPass<ModuleOp>> 625 mlir::createConvertShapeToStandardPass() { 626 return std::make_unique<ConvertShapeToStandardPass>(); 627 } 628