1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 using namespace mlir; 22 using namespace mlir::shape; 23 using namespace mlir::scf; 24 25 /// Conversion patterns. 26 namespace { 27 class AnyOpConversion : public OpConversionPattern<AnyOp> { 28 public: 29 using OpConversionPattern<AnyOp>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override; 34 }; 35 } // namespace 36 37 LogicalResult 38 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 39 ConversionPatternRewriter &rewriter) const { 40 AnyOp::Adaptor transformed(operands); 41 42 // Replace `any` with its first operand. 43 // Any operand would be a valid substitution. 44 rewriter.replaceOp(op, {transformed.inputs().front()}); 45 return success(); 46 } 47 48 namespace { 49 template <typename SrcOpTy, typename DstOpTy> 50 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 51 public: 52 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 53 54 LogicalResult 55 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 56 ConversionPatternRewriter &rewriter) const override { 57 typename SrcOpTy::Adaptor transformed(operands); 58 59 // For now, only error-free types are supported by this lowering. 60 if (op.getType().template isa<SizeType>()) 61 return failure(); 62 63 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 64 transformed.rhs()); 65 return success(); 66 } 67 }; 68 } // namespace 69 70 namespace { 71 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 72 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 73 74 LogicalResult 75 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 76 ConversionPatternRewriter &rewriter) const override; 77 }; 78 79 // Get the resulting extent in a given dimension. This is computed with any 80 // number of extent tensors and shifted offsets into them. 81 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, 82 ValueRange rankDiffs, Value outputDimension) { 83 Value one = lb.create<ConstantIndexOp>(1); 84 Value broadcastedDim = one; 85 for (auto tup : llvm::zip(extentTensors, rankDiffs)) { 86 Value shape = std::get<0>(tup); 87 Value rankDiff = std::get<1>(tup); 88 Value outOfBounds = 89 lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff); 90 Type indexTy = lb.getIndexType(); 91 broadcastedDim = 92 lb.create<IfOp>( 93 TypeRange{indexTy}, outOfBounds, 94 [&](OpBuilder &b, Location loc) { 95 b.create<scf::YieldOp>(loc, broadcastedDim); 96 }, 97 [&](OpBuilder &b, Location loc) { 98 // The broadcasting logic is: 99 // - if one extent (here we arbitrarily choose the 100 // extent from the greater-rank operand) is equal to 1, 101 // then take the extent from the other operand 102 // - otherwise, take the extent as-is. 103 // Note that this logic remains correct in the presence 104 // of dimensions of zero extent. 105 Value lesserRankOperandDimension = 106 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 107 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 108 loc, shape, ValueRange{lesserRankOperandDimension}); 109 110 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 111 lesserRankOperandExtent, one); 112 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim, 113 lesserRankOperandExtent); 114 b.create<scf::YieldOp>(loc, dim); 115 }) 116 .getResult(0); 117 } 118 return broadcastedDim; 119 } 120 } // namespace 121 122 LogicalResult BroadcastOpConverter::matchAndRewrite( 123 BroadcastOp op, ArrayRef<Value> operands, 124 ConversionPatternRewriter &rewriter) const { 125 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 126 // on shapes. 127 if (op.getType().isa<ShapeType>()) 128 return failure(); 129 130 auto loc = op.getLoc(); 131 ImplicitLocOpBuilder lb(loc, rewriter); 132 BroadcastOp::Adaptor transformed(operands); 133 134 Value zero = lb.create<ConstantIndexOp>(0); 135 Type indexTy = lb.getIndexType(); 136 137 // Save all the ranks for bounds checking. Because this is a tensor 138 // representing the shape extents, the rank is the extent of the only 139 // dimension in the tensor. 140 SmallVector<Value> ranks, rankDiffs; 141 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 142 return lb.create<DimOp>(v, zero); 143 })); 144 145 // Find the maximum rank 146 Value maxRank = ranks.front(); 147 for (Value v : llvm::drop_begin(ranks, 1)) { 148 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 149 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 150 } 151 152 // Calculate the difference of ranks and the maximum rank for later offsets. 153 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 154 return lb.create<SubIOp>(indexTy, maxRank, v); 155 })); 156 157 rewriter.replaceOp( 158 op, lb.create<tensor::GenerateOp>( 159 getExtentTensorType(lb.getContext()), ValueRange{maxRank}, 160 [&](OpBuilder &b, Location loc, ValueRange args) { 161 Value broadcastedDim = getBroadcastedDim( 162 ImplicitLocOpBuilder(loc, b), transformed.shapes(), 163 rankDiffs, args[0]); 164 165 b.create<tensor::YieldOp>(loc, broadcastedDim); 166 }) 167 ->getResults()); 168 return success(); 169 } 170 171 namespace { 172 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 173 public: 174 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 175 176 LogicalResult 177 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 178 ConversionPatternRewriter &rewriter) const override; 179 }; 180 } // namespace 181 182 LogicalResult ConstShapeOpConverter::matchAndRewrite( 183 ConstShapeOp op, ArrayRef<Value> operands, 184 ConversionPatternRewriter &rewriter) const { 185 186 // For now, this lowering supports only extent tensors, not `shape.shape` 187 // types. 188 if (op.getType().isa<ShapeType>()) 189 return failure(); 190 191 auto loc = op.getLoc(); 192 SmallVector<Value, 4> extentOperands; 193 for (auto extent : op.shape()) { 194 extentOperands.push_back( 195 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 196 } 197 Type indexTy = rewriter.getIndexType(); 198 Value tensor = 199 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 200 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 201 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 202 return success(); 203 } 204 205 namespace { 206 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 207 public: 208 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 209 210 LogicalResult 211 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 212 ConversionPatternRewriter &rewriter) const override; 213 }; 214 } // namespace 215 216 LogicalResult ConstSizeOpConversion::matchAndRewrite( 217 ConstSizeOp op, ArrayRef<Value> operands, 218 ConversionPatternRewriter &rewriter) const { 219 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 220 return success(); 221 } 222 223 namespace { 224 struct IsBroadcastableOpConverter 225 : public OpConversionPattern<IsBroadcastableOp> { 226 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 227 228 LogicalResult 229 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 230 ConversionPatternRewriter &rewriter) const override; 231 }; 232 } // namespace 233 234 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 235 IsBroadcastableOp op, ArrayRef<Value> operands, 236 ConversionPatternRewriter &rewriter) const { 237 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 238 // on shapes. 239 IsBroadcastableOp::Adaptor transformed(operands); 240 if (transformed.lhs().getType().isa<ShapeType>() || 241 transformed.rhs().getType().isa<ShapeType>()) 242 return failure(); 243 244 auto loc = op.getLoc(); 245 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 246 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 247 248 // Find smaller and greater rank and extent tensor. 249 Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero); 250 Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero); 251 Value lhsRankULE = 252 rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); 253 Type indexTy = rewriter.getIndexType(); 254 Value lesserRank = 255 rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); 256 Value greaterRank = 257 rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); 258 auto erasedRankType = 259 RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 260 Value rankErasedLhs = 261 rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs()); 262 Value rankErasedRhs = 263 rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs()); 264 Value lesserRankOperand = 265 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); 266 Value greaterRankOperand = 267 rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); 268 Value rankDiff = 269 rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); 270 Type i1Ty = rewriter.getI1Type(); 271 Value init = 272 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 273 274 // Determine if all overlapping extents are broadcastable. 275 auto reduceResult = rewriter.create<ForOp>( 276 loc, rankDiff, greaterRank, one, ValueRange{init}, 277 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 278 Value greaterRankOperandExtent = b.create<tensor::ExtractOp>( 279 loc, greaterRankOperand, ValueRange{iv}); 280 Value greaterRankOperandExtentIsOne = b.create<CmpIOp>( 281 loc, CmpIPredicate::eq, greaterRankOperandExtent, one); 282 Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); 283 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 284 loc, lesserRankOperand, ValueRange{ivShifted}); 285 Value lesserRankOperandExtentIsOne = b.create<CmpIOp>( 286 loc, CmpIPredicate::eq, lesserRankOperandExtent, one); 287 Value extentsAreEqual = 288 b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent, 289 lesserRankOperandExtent); 290 Value broadcastableExtents = b.create<AndOp>( 291 loc, iterArgs[0], 292 b.create<OrOp>(loc, 293 b.create<OrOp>(loc, greaterRankOperandExtentIsOne, 294 lesserRankOperandExtentIsOne), 295 extentsAreEqual)); 296 b.create<scf::YieldOp>(loc, broadcastableExtents); 297 }); 298 299 rewriter.replaceOp(op, reduceResult.results().front()); 300 return success(); 301 } 302 303 namespace { 304 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 305 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 306 307 LogicalResult 308 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 309 ConversionPatternRewriter &rewriter) const override; 310 }; 311 } // namespace 312 313 LogicalResult GetExtentOpConverter::matchAndRewrite( 314 GetExtentOp op, ArrayRef<Value> operands, 315 ConversionPatternRewriter &rewriter) const { 316 GetExtentOp::Adaptor transformed(operands); 317 318 // For now, only error-free types are supported by this lowering. 319 if (op.getType().isa<SizeType>()) 320 return failure(); 321 322 // Derive shape extent directly from shape origin if possible. This 323 // circumvents the necessity to materialize the shape in memory. 324 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 325 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 326 rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), 327 transformed.dim()); 328 return success(); 329 } 330 } 331 332 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), 333 transformed.shape(), 334 ValueRange{transformed.dim()}); 335 return success(); 336 } 337 338 namespace { 339 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 340 public: 341 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 342 343 LogicalResult 344 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 345 ConversionPatternRewriter &rewriter) const override; 346 }; 347 } // namespace 348 349 LogicalResult 350 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 351 ConversionPatternRewriter &rewriter) const { 352 // For now, this lowering supports only error-free types. 353 if (op.getType().isa<SizeType>()) 354 return failure(); 355 356 shape::RankOp::Adaptor transformed(operands); 357 rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0); 358 return success(); 359 } 360 361 namespace { 362 /// Converts `shape.reduce` to `scf.for`. 363 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 364 public: 365 using OpConversionPattern::OpConversionPattern; 366 367 LogicalResult 368 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 369 ConversionPatternRewriter &rewriter) const final; 370 }; 371 } // namespace 372 373 LogicalResult 374 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 375 ConversionPatternRewriter &rewriter) const { 376 // For now, this lowering is only defined on `tensor<?xindex>` operands. 377 if (op.shape().getType().isa<ShapeType>()) 378 return failure(); 379 380 auto loc = op.getLoc(); 381 shape::ReduceOp::Adaptor transformed(operands); 382 383 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 384 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 385 Type indexTy = rewriter.getIndexType(); 386 Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero); 387 388 auto loop = rewriter.create<scf::ForOp>( 389 loc, zero, rank, one, op.initVals(), 390 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 391 Value extent = 392 b.create<tensor::ExtractOp>(loc, transformed.shape(), iv); 393 394 SmallVector<Value, 2> mappedValues{iv, extent}; 395 mappedValues.append(args.begin(), args.end()); 396 397 BlockAndValueMapping mapping; 398 Block *reduceBody = op.getBody(); 399 mapping.map(reduceBody->getArguments(), mappedValues); 400 for (auto &nested : reduceBody->without_terminator()) 401 b.clone(nested, mapping); 402 403 SmallVector<Value, 2> mappedResults; 404 for (auto result : reduceBody->getTerminator()->getOperands()) 405 mappedResults.push_back(mapping.lookup(result)); 406 b.create<scf::YieldOp>(loc, mappedResults); 407 }); 408 409 rewriter.replaceOp(op, loop.getResults()); 410 return success(); 411 } 412 413 namespace { 414 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 415 /// only defined on `tensor<?xindex>` operands. The test for equality first 416 /// compares their size and, if equal, checks every extent for equality. 417 /// 418 /// Example: 419 /// 420 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 421 /// 422 /// becomes 423 /// 424 /// %c0 = constant 0 : index 425 /// %0 = dim %arg0, %c0 : tensor<?xindex> 426 /// %1 = dim %arg1, %c0 : tensor<?xindex> 427 /// %2 = cmpi "eq", %0, %1 : index 428 /// %result = scf.if %2 -> (i1) { 429 /// %c1 = constant 1 : index 430 /// %true = constant true 431 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 432 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 433 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 434 /// %7 = cmpi "eq", %5, %6 : index 435 /// %8 = and %arg3, %7 : i1 436 /// scf.yield %8 : i1 437 /// } 438 /// scf.yield %4 : i1 439 /// } else { 440 /// %false = constant false 441 /// scf.yield %false : i1 442 /// } 443 /// 444 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 445 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 446 447 LogicalResult 448 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 449 ConversionPatternRewriter &rewriter) const override; 450 }; 451 } // namespace 452 453 LogicalResult 454 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 455 ConversionPatternRewriter &rewriter) const { 456 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 457 // on shapes. 458 if (op.lhs().getType().isa<ShapeType>() || 459 op.rhs().getType().isa<ShapeType>()) { 460 return failure(); 461 } 462 463 ShapeEqOp::Adaptor transformed(operands); 464 auto loc = op.getLoc(); 465 Type indexTy = rewriter.getIndexType(); 466 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 467 Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero); 468 Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero); 469 Value eqRank = 470 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank); 471 Type i1Ty = rewriter.getI1Type(); 472 rewriter.replaceOpWithNewOp<IfOp>( 473 op, i1Ty, eqRank, 474 [&](OpBuilder &b, Location loc) { 475 Value one = b.create<ConstantIndexOp>(loc, 1); 476 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 477 auto loop = b.create<scf::ForOp>( 478 loc, zero, lhsRank, one, ValueRange{init}, 479 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 480 Value conj = args[0]; 481 Value lhsExtent = 482 b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv); 483 Value rhsExtent = 484 b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv); 485 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 486 lhsExtent, rhsExtent); 487 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 488 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 489 }); 490 b.create<scf::YieldOp>(loc, loop.getResults()); 491 }, 492 [&](OpBuilder &b, Location loc) { 493 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 494 b.create<scf::YieldOp>(loc, result); 495 }); 496 return success(); 497 } 498 499 namespace { 500 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 501 public: 502 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 503 504 LogicalResult 505 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 506 ConversionPatternRewriter &rewriter) const override; 507 }; 508 } // namespace 509 510 LogicalResult ShapeOfOpConversion::matchAndRewrite( 511 ShapeOfOp op, ArrayRef<Value> operands, 512 ConversionPatternRewriter &rewriter) const { 513 514 // For now, only error-free types are supported by this lowering. 515 if (op.getType().isa<ShapeType>()) 516 return failure(); 517 518 // For ranked tensor arguments, lower to `tensor.from_elements`. 519 auto loc = op.getLoc(); 520 ShapeOfOp::Adaptor transformed(operands); 521 Value tensor = transformed.arg(); 522 Type tensorTy = tensor.getType(); 523 if (tensorTy.isa<RankedTensorType>()) { 524 525 // Build values for individual extents. 526 SmallVector<Value, 8> extentValues; 527 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 528 int64_t rank = rankedTensorTy.getRank(); 529 for (int64_t i = 0; i < rank; i++) { 530 if (rankedTensorTy.isDynamicDim(i)) { 531 Value extent = rewriter.create<DimOp>(loc, tensor, i); 532 extentValues.push_back(extent); 533 } else { 534 Value extent = 535 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 536 extentValues.push_back(extent); 537 } 538 } 539 540 // Materialize extent tensor. 541 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( 542 loc, rewriter.getIndexType(), extentValues); 543 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 544 staticExtentTensor); 545 return success(); 546 } 547 548 // Lower to `tensor.generate` otherwise. 549 auto *ctx = rewriter.getContext(); 550 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 551 rewriter.replaceOpWithNewOp<tensor::GenerateOp>( 552 op, getExtentTensorType(ctx), ValueRange{rank}, 553 [&](OpBuilder &b, Location loc, ValueRange args) { 554 Value dim = args.front(); 555 Value extent = b.create<DimOp>(loc, tensor, dim); 556 b.create<tensor::YieldOp>(loc, extent); 557 }); 558 559 return success(); 560 } 561 562 namespace { 563 class ToExtentTensorOpConversion 564 : public OpConversionPattern<ToExtentTensorOp> { 565 public: 566 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 567 568 LogicalResult 569 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 570 ConversionPatternRewriter &rewriter) const override { 571 ToExtentTensorOpAdaptor adaptor(operands); 572 573 if (!adaptor.input().getType().isa<RankedTensorType>()) 574 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 575 576 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 577 adaptor.input()); 578 return success(); 579 } 580 }; 581 } // namespace 582 583 namespace { 584 /// Import the Shape Ops to Std Patterns. 585 #include "ShapeToStandard.cpp.inc" 586 } // namespace 587 588 namespace { 589 /// Conversion pass. 590 class ConvertShapeToStandardPass 591 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 592 593 void runOnOperation() override; 594 }; 595 } // namespace 596 597 void ConvertShapeToStandardPass::runOnOperation() { 598 // Setup target legality. 599 MLIRContext &ctx = getContext(); 600 ConversionTarget target(ctx); 601 target 602 .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>(); 603 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>(); 604 605 // Setup conversion patterns. 606 OwningRewritePatternList patterns; 607 populateShapeToStandardConversionPatterns(patterns, &ctx); 608 609 // Apply conversion. 610 auto module = getOperation(); 611 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 612 signalPassFailure(); 613 } 614 615 void mlir::populateShapeToStandardConversionPatterns( 616 OwningRewritePatternList &patterns, MLIRContext *ctx) { 617 // clang-format off 618 populateWithGenerated(ctx, patterns); 619 patterns.insert< 620 AnyOpConversion, 621 BinaryOpConversion<AddOp, AddIOp>, 622 BinaryOpConversion<MulOp, MulIOp>, 623 BroadcastOpConverter, 624 ConstShapeOpConverter, 625 ConstSizeOpConversion, 626 IsBroadcastableOpConverter, 627 GetExtentOpConverter, 628 RankOpConverter, 629 ReduceOpConverter, 630 ShapeEqOpConverter, 631 ShapeOfOpConversion, 632 ToExtentTensorOpConversion>(ctx); 633 // clang-format on 634 } 635 636 std::unique_ptr<OperationPass<ModuleOp>> 637 mlir::createConvertShapeToStandardPass() { 638 return std::make_unique<ConvertShapeToStandardPass>(); 639 } 640