1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 18 using namespace mlir; 19 using namespace mlir::shape; 20 using namespace mlir::scf; 21 22 /// Conversion patterns. 23 namespace { 24 class AnyOpConversion : public OpConversionPattern<AnyOp> { 25 public: 26 using OpConversionPattern<AnyOp>::OpConversionPattern; 27 28 LogicalResult 29 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override; 31 }; 32 } // namespace 33 34 LogicalResult 35 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 36 ConversionPatternRewriter &rewriter) const { 37 AnyOp::Adaptor transformed(operands); 38 39 // Replace `any` with its first operand. 40 // Any operand would be a valid substitution. 41 rewriter.replaceOp(op, {transformed.inputs().front()}); 42 return success(); 43 } 44 45 namespace { 46 template <typename SrcOpTy, typename DstOpTy> 47 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 48 public: 49 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 50 51 LogicalResult 52 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 53 ConversionPatternRewriter &rewriter) const override { 54 typename SrcOpTy::Adaptor transformed(operands); 55 56 // For now, only error-free types are supported by this lowering. 57 if (op.getType().template isa<SizeType>()) 58 return failure(); 59 60 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 61 transformed.rhs()); 62 return success(); 63 } 64 }; 65 } // namespace 66 67 namespace { 68 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 69 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 70 71 LogicalResult 72 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 73 ConversionPatternRewriter &rewriter) const override; 74 }; 75 } // namespace 76 77 LogicalResult BroadcastOpConverter::matchAndRewrite( 78 BroadcastOp op, ArrayRef<Value> operands, 79 ConversionPatternRewriter &rewriter) const { 80 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 81 // on shapes. 82 if (op.getType().isa<ShapeType>()) 83 return failure(); 84 85 assert(!op.lhs().getType().isa<ShapeType>() && 86 !op.rhs().getType().isa<ShapeType>()); 87 auto loc = op.getLoc(); 88 BroadcastOp::Adaptor transformed(operands); 89 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 90 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 91 92 // Find smaller and greater rank and extent tensor. 93 Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero); 94 Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero); 95 Value lhsRankULE = 96 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 97 Type indexTy = rewriter.getIndexType(); 98 Value lesserRank = 99 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); 100 Value greaterRank = 101 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); 102 auto erasedRankType = 103 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 104 Value rankErasedLhs = 105 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); 106 Value rankErasedRhs = 107 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs()); 108 Value lesserRankOperand = 109 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); 110 Value greaterRankOperand = 111 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); 112 113 Value rankDiff = 114 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 115 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( 116 op, getExtentTensorType(op.getContext()), ValueRange{greaterRank}, 117 [&](OpBuilder &b, Location loc, ValueRange args) { 118 Value outputDimension = args[0]; 119 Value isUnchallengedDimension = b.create<CmpIOp>( 120 loc, CmpIPredicate::ult, outputDimension, rankDiff); 121 Value greaterRankOperandExtent = b.create<ExtractElementOp>( 122 loc, greaterRankOperand, outputDimension); 123 // The initial dimensions of the greater-rank operand are unchallenged, 124 // so we can take them as-is. Otherwise, we need to do a comparison. 125 // We need an actual branch here (instead of a select) because the 126 // lesser-rank operand might be rank 0, so any extract_element would be 127 // invalid. 128 auto ifOp = b.create<IfOp>( 129 loc, TypeRange{indexTy}, isUnchallengedDimension, 130 [&](OpBuilder &b, Location loc) { 131 b.create<scf::YieldOp>(loc, greaterRankOperandExtent); 132 }, 133 [&](OpBuilder &b, Location loc) { 134 // The broadcasting logic is: 135 // - if one extent (here we arbitrarily choose the extent from 136 // the greater-rank operand) is equal to 1, then take the extent 137 // from the other operand 138 // - otherwise, take the extent as-is. 139 // Note that this logic remains correct in the presence of 140 // dimensions of zero extent. 141 Value lesserRankOperandDimension = 142 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 143 Value lesserRankOperandExtent = b.create<ExtractElementOp>( 144 loc, lesserRankOperand, 145 ValueRange{lesserRankOperandDimension}); 146 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 147 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 148 Value broadcastedExtent = b.create<SelectOp>( 149 loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent, 150 greaterRankOperandExtent); 151 b.create<scf::YieldOp>(loc, broadcastedExtent); 152 }); 153 b.create<mlir::YieldOp>(loc, ifOp.getResult(0)); 154 }); 155 return success(); 156 } 157 158 namespace { 159 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 160 public: 161 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 162 163 LogicalResult 164 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 165 ConversionPatternRewriter &rewriter) const override; 166 }; 167 } // namespace 168 169 LogicalResult ConstShapeOpConverter::matchAndRewrite( 170 ConstShapeOp op, ArrayRef<Value> operands, 171 ConversionPatternRewriter &rewriter) const { 172 173 // For now, this lowering supports only extent tensors, not `shape.shape` 174 // types. 175 if (op.getType().isa<ShapeType>()) 176 return failure(); 177 178 auto loc = op.getLoc(); 179 SmallVector<Value, 4> extentOperands; 180 for (auto extent : op.shape()) { 181 extentOperands.push_back( 182 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 183 } 184 Type indexTy = rewriter.getIndexType(); 185 Value tensor = 186 rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands); 187 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 188 rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy); 189 return success(); 190 } 191 192 namespace { 193 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 194 public: 195 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 196 197 LogicalResult 198 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 199 ConversionPatternRewriter &rewriter) const override; 200 }; 201 } // namespace 202 203 LogicalResult ConstSizeOpConversion::matchAndRewrite( 204 ConstSizeOp op, ArrayRef<Value> operands, 205 ConversionPatternRewriter &rewriter) const { 206 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 207 return success(); 208 } 209 210 namespace { 211 struct IsBroadcastableOpConverter 212 : public OpConversionPattern<IsBroadcastableOp> { 213 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 214 215 LogicalResult 216 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 217 ConversionPatternRewriter &rewriter) const override; 218 }; 219 } // namespace 220 221 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 222 IsBroadcastableOp op, ArrayRef<Value> operands, 223 ConversionPatternRewriter &rewriter) const { 224 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 225 // on shapes. 226 IsBroadcastableOp::Adaptor transformed(operands); 227 if (transformed.lhs().getType().isa<ShapeType>() || 228 transformed.rhs().getType().isa<ShapeType>()) 229 return failure(); 230 231 auto loc = op.getLoc(); 232 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 233 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 234 235 // Find smaller and greater rank and extent tensor. 236 Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero); 237 Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero); 238 Value lhsRankULE = 239 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 240 Type indexTy = rewriter.getIndexType(); 241 Value lesserRank = 242 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); 243 Value greaterRank = 244 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); 245 auto erasedRankType = 246 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 247 Value rankErasedLhs = 248 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); 249 Value rankErasedRhs = 250 rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs()); 251 Value lesserRankOperand = 252 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); 253 Value greaterRankOperand = 254 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); 255 Value rankDiff = 256 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 257 Type i1Ty = rewriter.getI1Type(); 258 Value init = 259 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 260 261 // Determine if all overlapping extents are broadcastable. 262 auto reduceResult = rewriter.create<ForOp>( 263 loc, rankDiff, greaterRank, one, ValueRange{init}, 264 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 265 Value greaterRankOperandExtent = 266 b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv}); 267 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 268 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 269 Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); 270 Value lesserRankOperandExtent = b.create<ExtractElementOp>( 271 loc, lesserRankOperand, ValueRange{ivShifted}); 272 Value lesserRankOperandExtentIsOne = b.create<CmpIOp>( 273 loc, CmpIPredicate::eq, lesserRankOperandExtent, one); 274 Value extentsAreEqual = 275 b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent, 276 lesserRankOperandExtent); 277 Value broadcastableExtents = b.create<AndOp>( 278 loc, iterArgs[0], 279 b.create<OrOp>(loc, 280 b.create<OrOp>(loc, greaterRankOperandExtentIsOne, 281 lesserRankOperandExtentIsOne), 282 extentsAreEqual)); 283 b.create<scf::YieldOp>(loc, broadcastableExtents); 284 }); 285 286 rewriter.replaceOp(op, reduceResult.results().front()); 287 return success(); 288 } 289 290 namespace { 291 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 292 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 293 294 LogicalResult 295 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 296 ConversionPatternRewriter &rewriter) const override; 297 }; 298 } // namespace 299 300 LogicalResult GetExtentOpConverter::matchAndRewrite( 301 GetExtentOp op, ArrayRef<Value> operands, 302 ConversionPatternRewriter &rewriter) const { 303 GetExtentOp::Adaptor transformed(operands); 304 305 // For now, only error-free types are supported by this lowering. 306 if (op.getType().isa<SizeType>()) 307 return failure(); 308 309 // Derive shape extent directly from shape origin if possible. This 310 // circumvents the necessity to materialize the shape in memory. 311 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 312 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 313 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 314 transformed.dim()); 315 return success(); 316 } 317 } 318 319 rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(), 320 transformed.shape(), 321 ValueRange{transformed.dim()}); 322 return success(); 323 } 324 325 namespace { 326 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 327 public: 328 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 329 330 LogicalResult 331 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 332 ConversionPatternRewriter &rewriter) const override; 333 }; 334 } // namespace 335 336 LogicalResult 337 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 338 ConversionPatternRewriter &rewriter) const { 339 // For now, this lowering supports only error-free types. 340 if (op.getType().isa<SizeType>()) 341 return failure(); 342 343 shape::RankOp::Adaptor transformed(operands); 344 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 345 return success(); 346 } 347 348 namespace { 349 /// Converts `shape.reduce` to `scf.for`. 350 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 351 public: 352 using OpConversionPattern::OpConversionPattern; 353 354 LogicalResult 355 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const final; 357 }; 358 } // namespace 359 360 LogicalResult 361 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 362 ConversionPatternRewriter &rewriter) const { 363 // For now, this lowering is only defined on `tensor<?xindex>` operands. 364 if (op.shape().getType().isa<ShapeType>()) 365 return failure(); 366 367 auto loc = op.getLoc(); 368 shape::ReduceOp::Adaptor transformed(operands); 369 370 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 371 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 372 Type indexTy = rewriter.getIndexType(); 373 Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero); 374 375 auto loop = rewriter.create<scf::ForOp>( 376 loc, zero, rank, one, op.initVals(), 377 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 378 Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv); 379 380 SmallVector<Value, 2> mappedValues{iv, extent}; 381 mappedValues.append(args.begin(), args.end()); 382 383 BlockAndValueMapping mapping; 384 Block *reduceBody = op.getBody(); 385 mapping.map(reduceBody->getArguments(), mappedValues); 386 for (auto &nested : reduceBody->without_terminator()) 387 b.clone(nested, mapping); 388 389 SmallVector<Value, 2> mappedResults; 390 for (auto result : reduceBody->getTerminator()->getOperands()) 391 mappedResults.push_back(mapping.lookup(result)); 392 b.create<scf::YieldOp>(loc, mappedResults); 393 }); 394 395 rewriter.replaceOp(op, loop.getResults()); 396 return success(); 397 } 398 399 namespace { 400 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 401 /// only defined on `tensor<?xindex>` operands. The test for equality first 402 /// compares their size and, if equal, checks every extent for equality. 403 /// 404 /// Example: 405 /// 406 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 407 /// 408 /// becomes 409 /// 410 /// %c0 = constant 0 : index 411 /// %0 = dim %arg0, %c0 : tensor<?xindex> 412 /// %1 = dim %arg1, %c0 : tensor<?xindex> 413 /// %2 = cmpi "eq", %0, %1 : index 414 /// %result = scf.if %2 -> (i1) { 415 /// %c1 = constant 1 : index 416 /// %true = constant true 417 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 418 /// %5 = extract_element %arg0[%arg2] : tensor<?xindex> 419 /// %6 = extract_element %arg1[%arg2] : tensor<?xindex> 420 /// %7 = cmpi "eq", %5, %6 : index 421 /// %8 = and %arg3, %7 : i1 422 /// scf.yield %8 : i1 423 /// } 424 /// scf.yield %4 : i1 425 /// } else { 426 /// %false = constant false 427 /// scf.yield %false : i1 428 /// } 429 /// 430 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 431 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 432 433 LogicalResult 434 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 435 ConversionPatternRewriter &rewriter) const override; 436 }; 437 } // namespace 438 439 LogicalResult 440 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 441 ConversionPatternRewriter &rewriter) const { 442 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 443 // on shapes. 444 if (op.lhs().getType().isa<ShapeType>() || 445 op.rhs().getType().isa<ShapeType>()) { 446 return failure(); 447 } 448 449 ShapeEqOp::Adaptor transformed(operands); 450 auto loc = op.getLoc(); 451 Type indexTy = rewriter.getIndexType(); 452 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 453 Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero); 454 Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero); 455 Value eqRank = 456 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank); 457 Type i1Ty = rewriter.getI1Type(); 458 rewriter.replaceOpWithNewOp<IfOp>( 459 op, i1Ty, eqRank, 460 [&](OpBuilder &b, Location loc) { 461 Value one = b.create<ConstantIndexOp>(loc, 1); 462 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 463 auto loop = b.create<scf::ForOp>( 464 loc, zero, lhsRank, one, ValueRange{init}, 465 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 466 Value conj = args[0]; 467 Value lhsExtent = 468 b.create<ExtractElementOp>(loc, transformed.lhs(), iv); 469 Value rhsExtent = 470 b.create<ExtractElementOp>(loc, transformed.rhs(), iv); 471 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 472 lhsExtent, rhsExtent); 473 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 474 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 475 }); 476 b.create<scf::YieldOp>(loc, loop.getResults()); 477 }, 478 [&](OpBuilder &b, Location loc) { 479 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 480 b.create<scf::YieldOp>(loc, result); 481 }); 482 return success(); 483 } 484 485 namespace { 486 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 487 public: 488 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 489 490 LogicalResult 491 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 492 ConversionPatternRewriter &rewriter) const override; 493 }; 494 } // namespace 495 496 LogicalResult ShapeOfOpConversion::matchAndRewrite( 497 ShapeOfOp op, ArrayRef<Value> operands, 498 ConversionPatternRewriter &rewriter) const { 499 500 // For now, only error-free types are supported by this lowering. 501 if (op.getType().isa<ShapeType>()) 502 return failure(); 503 504 // For ranked tensor arguments, lower to `tensor_from_elements`. 505 auto loc = op.getLoc(); 506 ShapeOfOp::Adaptor transformed(operands); 507 Value tensor = transformed.arg(); 508 Type tensorTy = tensor.getType(); 509 if (tensorTy.isa<RankedTensorType>()) { 510 511 // Build values for individual extents. 512 SmallVector<Value, 8> extentValues; 513 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 514 int64_t rank = rankedTensorTy.getRank(); 515 for (int64_t i = 0; i < rank; i++) { 516 if (rankedTensorTy.isDynamicDim(i)) { 517 Value extent = rewriter.create<DimOp>(loc, tensor, i); 518 extentValues.push_back(extent); 519 } else { 520 Value extent = 521 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 522 extentValues.push_back(extent); 523 } 524 } 525 526 // Materialize extent tensor. 527 Value staticExtentTensor = rewriter.create<TensorFromElementsOp>( 528 loc, rewriter.getIndexType(), extentValues); 529 rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor, 530 op.getType()); 531 return success(); 532 } 533 534 // Lower to `dynamic_tensor_from_elements` otherwise. 535 auto *ctx = rewriter.getContext(); 536 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 537 rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>( 538 op, getExtentTensorType(ctx), ValueRange{rank}, 539 [&](OpBuilder &b, Location loc, ValueRange args) { 540 Value dim = args.front(); 541 Value extent = b.create<DimOp>(loc, tensor, dim); 542 b.create<mlir::YieldOp>(loc, extent); 543 }); 544 545 return success(); 546 } 547 548 namespace { 549 class ToExtentTensorOpConversion 550 : public OpConversionPattern<ToExtentTensorOp> { 551 public: 552 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 553 554 LogicalResult 555 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 556 ConversionPatternRewriter &rewriter) const override { 557 ToExtentTensorOpAdaptor adaptor(operands); 558 559 if (!adaptor.input().getType().isa<RankedTensorType>()) 560 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 561 562 rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(), 563 op.getType()); 564 return success(); 565 } 566 }; 567 } // namespace 568 569 namespace { 570 /// Import the Shape Ops to Std Patterns. 571 #include "ShapeToStandard.cpp.inc" 572 } // namespace 573 574 namespace { 575 /// Conversion pass. 576 class ConvertShapeToStandardPass 577 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 578 579 void runOnOperation() override; 580 }; 581 } // namespace 582 583 void ConvertShapeToStandardPass::runOnOperation() { 584 // Setup target legality. 585 MLIRContext &ctx = getContext(); 586 ConversionTarget target(ctx); 587 target.addLegalDialect<StandardOpsDialect, SCFDialect>(); 588 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>(); 589 590 // Setup conversion patterns. 591 OwningRewritePatternList patterns; 592 populateShapeToStandardConversionPatterns(patterns, &ctx); 593 594 // Apply conversion. 595 auto module = getOperation(); 596 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 597 signalPassFailure(); 598 } 599 600 void mlir::populateShapeToStandardConversionPatterns( 601 OwningRewritePatternList &patterns, MLIRContext *ctx) { 602 // clang-format off 603 populateWithGenerated(ctx, patterns); 604 patterns.insert< 605 AnyOpConversion, 606 BinaryOpConversion<AddOp, AddIOp>, 607 BinaryOpConversion<MulOp, MulIOp>, 608 BroadcastOpConverter, 609 ConstShapeOpConverter, 610 ConstSizeOpConversion, 611 IsBroadcastableOpConverter, 612 GetExtentOpConverter, 613 RankOpConverter, 614 ReduceOpConverter, 615 ShapeEqOpConverter, 616 ShapeOfOpConversion, 617 ToExtentTensorOpConversion>(ctx); 618 // clang-format on 619 } 620 621 std::unique_ptr<OperationPass<ModuleOp>> 622 mlir::createConvertShapeToStandardPass() { 623 return std::make_unique<ConvertShapeToStandardPass>(); 624 } 625