1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 using namespace mlir; 22 using namespace mlir::shape; 23 using namespace mlir::scf; 24 25 /// Conversion patterns. 26 namespace { 27 class AnyOpConversion : public OpConversionPattern<AnyOp> { 28 public: 29 using OpConversionPattern<AnyOp>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(AnyOp op, OpAdaptor adaptor, 33 ConversionPatternRewriter &rewriter) const override; 34 }; 35 } // namespace 36 37 LogicalResult 38 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, 39 ConversionPatternRewriter &rewriter) const { 40 // Replace `any` with its first operand. 41 // Any operand would be a valid substitution. 42 rewriter.replaceOp(op, {adaptor.inputs().front()}); 43 return success(); 44 } 45 46 namespace { 47 template <typename SrcOpTy, typename DstOpTy> 48 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 49 public: 50 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 51 52 LogicalResult 53 matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, 54 ConversionPatternRewriter &rewriter) const override { 55 // For now, only error-free types are supported by this lowering. 56 if (op.getType().template isa<SizeType>()) 57 return failure(); 58 59 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs()); 60 return success(); 61 } 62 }; 63 } // namespace 64 65 namespace { 66 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 67 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 68 69 LogicalResult 70 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, 71 ConversionPatternRewriter &rewriter) const override; 72 }; 73 74 // Get the resulting extent in a given dimension. This is computed with any 75 // number of extent tensors and shifted offsets into them. 76 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, 77 ValueRange rankDiffs, Value outputDimension) { 78 Value one = lb.create<ConstantIndexOp>(1); 79 Value broadcastedDim = one; 80 for (auto tup : llvm::zip(extentTensors, rankDiffs)) { 81 Value shape = std::get<0>(tup); 82 Value rankDiff = std::get<1>(tup); 83 Value outOfBounds = 84 lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff); 85 Type indexTy = lb.getIndexType(); 86 broadcastedDim = 87 lb.create<IfOp>( 88 TypeRange{indexTy}, outOfBounds, 89 [&](OpBuilder &b, Location loc) { 90 b.create<scf::YieldOp>(loc, broadcastedDim); 91 }, 92 [&](OpBuilder &b, Location loc) { 93 // The broadcasting logic is: 94 // - if one extent (here we arbitrarily choose the 95 // extent from the greater-rank operand) is equal to 1, 96 // then take the extent from the other operand 97 // - otherwise, take the extent as-is. 98 // Note that this logic remains correct in the presence 99 // of dimensions of zero extent. 100 Value lesserRankOperandDimension = 101 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 102 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 103 loc, shape, ValueRange{lesserRankOperandDimension}); 104 105 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 106 lesserRankOperandExtent, one); 107 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim, 108 lesserRankOperandExtent); 109 b.create<scf::YieldOp>(loc, dim); 110 }) 111 .getResult(0); 112 } 113 return broadcastedDim; 114 } 115 } // namespace 116 117 LogicalResult BroadcastOpConverter::matchAndRewrite( 118 BroadcastOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const { 120 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 121 // on shapes. 122 if (op.getType().isa<ShapeType>()) 123 return failure(); 124 125 auto loc = op.getLoc(); 126 ImplicitLocOpBuilder lb(loc, rewriter); 127 128 Value zero = lb.create<ConstantIndexOp>(0); 129 Type indexTy = lb.getIndexType(); 130 131 // Save all the ranks for bounds checking. Because this is a tensor 132 // representing the shape extents, the rank is the extent of the only 133 // dimension in the tensor. 134 SmallVector<Value> ranks, rankDiffs; 135 llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { 136 return lb.create<tensor::DimOp>(v, zero); 137 })); 138 139 // Find the maximum rank 140 Value maxRank = ranks.front(); 141 for (Value v : llvm::drop_begin(ranks, 1)) { 142 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 143 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 144 } 145 146 // Calculate the difference of ranks and the maximum rank for later offsets. 147 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 148 return lb.create<SubIOp>(indexTy, maxRank, v); 149 })); 150 151 Value replacement = lb.create<tensor::GenerateOp>( 152 getExtentTensorType(lb.getContext()), ValueRange{maxRank}, 153 [&](OpBuilder &b, Location loc, ValueRange args) { 154 Value broadcastedDim = getBroadcastedDim( 155 ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]); 156 157 b.create<tensor::YieldOp>(loc, broadcastedDim); 158 }); 159 if (replacement.getType() != op.getType()) 160 replacement = lb.create<tensor::CastOp>(op.getType(), replacement); 161 rewriter.replaceOp(op, replacement); 162 return success(); 163 } 164 165 namespace { 166 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 167 public: 168 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 169 170 LogicalResult 171 matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor, 172 ConversionPatternRewriter &rewriter) const override; 173 }; 174 } // namespace 175 176 LogicalResult ConstShapeOpConverter::matchAndRewrite( 177 ConstShapeOp op, OpAdaptor adaptor, 178 ConversionPatternRewriter &rewriter) const { 179 180 // For now, this lowering supports only extent tensors, not `shape.shape` 181 // types. 182 if (op.getType().isa<ShapeType>()) 183 return failure(); 184 185 auto loc = op.getLoc(); 186 SmallVector<Value, 4> extentOperands; 187 for (auto extent : op.shape()) { 188 extentOperands.push_back( 189 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 190 } 191 Type indexTy = rewriter.getIndexType(); 192 Value tensor = 193 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 194 Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy); 195 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 196 return success(); 197 } 198 199 namespace { 200 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 201 public: 202 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 203 204 LogicalResult 205 matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor, 206 ConversionPatternRewriter &rewriter) const override; 207 }; 208 } // namespace 209 210 LogicalResult ConstSizeOpConversion::matchAndRewrite( 211 ConstSizeOp op, OpAdaptor adaptor, 212 ConversionPatternRewriter &rewriter) const { 213 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 214 return success(); 215 } 216 217 namespace { 218 struct IsBroadcastableOpConverter 219 : public OpConversionPattern<IsBroadcastableOp> { 220 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 221 222 LogicalResult 223 matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor, 224 ConversionPatternRewriter &rewriter) const override; 225 }; 226 } // namespace 227 228 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 229 IsBroadcastableOp op, OpAdaptor adaptor, 230 ConversionPatternRewriter &rewriter) const { 231 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 232 // on shapes. 233 if (!llvm::all_of(op.shapes(), 234 [](Value v) { return !v.getType().isa<ShapeType>(); })) 235 return failure(); 236 237 auto loc = op.getLoc(); 238 ImplicitLocOpBuilder lb(loc, rewriter); 239 Value zero = lb.create<ConstantIndexOp>(0); 240 Value one = lb.create<ConstantIndexOp>(1); 241 Type indexTy = lb.getIndexType(); 242 243 // Save all the ranks for bounds checking. Because this is a tensor 244 // representing the shape extents, the rank is the extent of the only 245 // dimension in the tensor. 246 SmallVector<Value> ranks, rankDiffs; 247 llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { 248 return lb.create<tensor::DimOp>(v, zero); 249 })); 250 251 // Find the maximum rank 252 Value maxRank = ranks.front(); 253 for (Value v : llvm::drop_begin(ranks, 1)) { 254 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 255 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 256 } 257 258 // Calculate the difference of ranks and the maximum rank for later offsets. 259 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 260 return lb.create<SubIOp>(indexTy, maxRank, v); 261 })); 262 263 Type i1Ty = rewriter.getI1Type(); 264 Value trueVal = 265 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 266 267 auto reduceResult = lb.create<ForOp>( 268 loc, zero, maxRank, one, ValueRange{trueVal}, 269 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 270 // Find a non-1 dim, if it exists. Note that the first part of this 271 // could reuse the Broadcast lowering entirely, but we redo the work 272 // here to make optimizations easier between the two loops. 273 Value broadcastedDim = getBroadcastedDim( 274 ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv); 275 276 Value broadcastable = iterArgs[0]; 277 for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { 278 Value shape, rankDiff; 279 std::tie(shape, rankDiff) = tup; 280 Value outOfBounds = 281 b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff); 282 broadcastable = 283 b.create<IfOp>( 284 loc, TypeRange{i1Ty}, outOfBounds, 285 [&](OpBuilder &b, Location loc) { 286 // Non existent dimensions are always broadcastable 287 b.create<scf::YieldOp>(loc, broadcastable); 288 }, 289 [&](OpBuilder &b, Location loc) { 290 // Every value needs to be either 1, or the same non-1 291 // value to be broadcastable in this dim. 292 Value operandDimension = 293 b.create<SubIOp>(loc, indexTy, iv, rankDiff); 294 Value dimensionExtent = b.create<tensor::ExtractOp>( 295 loc, shape, ValueRange{operandDimension}); 296 297 Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 298 dimensionExtent, one); 299 Value equalBroadcasted = 300 b.create<CmpIOp>(loc, CmpIPredicate::eq, 301 dimensionExtent, broadcastedDim); 302 Value result = b.create<AndOp>( 303 loc, broadcastable, 304 b.create<OrOp>(loc, equalOne, equalBroadcasted)); 305 b.create<scf::YieldOp>(loc, result); 306 }) 307 .getResult(0); 308 } 309 310 b.create<scf::YieldOp>(loc, broadcastable); 311 }); 312 313 rewriter.replaceOp(op, reduceResult.results().front()); 314 return success(); 315 } 316 317 namespace { 318 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 319 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 320 321 LogicalResult 322 matchAndRewrite(GetExtentOp op, OpAdaptor adaptor, 323 ConversionPatternRewriter &rewriter) const override; 324 }; 325 } // namespace 326 327 LogicalResult GetExtentOpConverter::matchAndRewrite( 328 GetExtentOp op, OpAdaptor adaptor, 329 ConversionPatternRewriter &rewriter) const { 330 // For now, only error-free types are supported by this lowering. 331 if (op.getType().isa<SizeType>()) 332 return failure(); 333 334 // Derive shape extent directly from shape origin if possible. This 335 // circumvents the necessity to materialize the shape in memory. 336 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 337 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 338 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(), 339 adaptor.dim()); 340 return success(); 341 } 342 } 343 344 rewriter.replaceOpWithNewOp<tensor::ExtractOp>( 345 op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()}); 346 return success(); 347 } 348 349 namespace { 350 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 351 public: 352 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 353 354 LogicalResult 355 matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, 356 ConversionPatternRewriter &rewriter) const override; 357 }; 358 } // namespace 359 360 LogicalResult 361 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, 362 ConversionPatternRewriter &rewriter) const { 363 // For now, this lowering supports only error-free types. 364 if (op.getType().isa<SizeType>()) 365 return failure(); 366 367 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.shape(), 0); 368 return success(); 369 } 370 371 namespace { 372 /// Converts `shape.reduce` to `scf.for`. 373 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 374 public: 375 using OpConversionPattern::OpConversionPattern; 376 377 LogicalResult 378 matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, 379 ConversionPatternRewriter &rewriter) const final; 380 }; 381 } // namespace 382 383 LogicalResult 384 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, 385 ConversionPatternRewriter &rewriter) const { 386 // For now, this lowering is only defined on `tensor<?xindex>` operands. 387 if (op.shape().getType().isa<ShapeType>()) 388 return failure(); 389 390 auto loc = op.getLoc(); 391 392 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 393 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 394 Type indexTy = rewriter.getIndexType(); 395 Value rank = 396 rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.shape(), zero); 397 398 auto loop = rewriter.create<scf::ForOp>( 399 loc, zero, rank, one, op.initVals(), 400 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 401 Value extent = b.create<tensor::ExtractOp>(loc, adaptor.shape(), iv); 402 403 SmallVector<Value, 2> mappedValues{iv, extent}; 404 mappedValues.append(args.begin(), args.end()); 405 406 BlockAndValueMapping mapping; 407 Block *reduceBody = op.getBody(); 408 mapping.map(reduceBody->getArguments(), mappedValues); 409 for (auto &nested : reduceBody->without_terminator()) 410 b.clone(nested, mapping); 411 412 SmallVector<Value, 2> mappedResults; 413 for (auto result : reduceBody->getTerminator()->getOperands()) 414 mappedResults.push_back(mapping.lookup(result)); 415 b.create<scf::YieldOp>(loc, mappedResults); 416 }); 417 418 rewriter.replaceOp(op, loop.getResults()); 419 return success(); 420 } 421 422 namespace { 423 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 424 /// only defined on `tensor<?xindex>` operands. The test for equality first 425 /// compares their size and, if equal, checks every extent for equality. 426 /// 427 /// Example: 428 /// 429 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 430 /// 431 /// becomes 432 /// 433 /// %c0 = constant 0 : index 434 /// %0 = dim %arg0, %c0 : tensor<?xindex> 435 /// %1 = dim %arg1, %c0 : tensor<?xindex> 436 /// %2 = cmpi "eq", %0, %1 : index 437 /// %result = scf.if %2 -> (i1) { 438 /// %c1 = constant 1 : index 439 /// %true = constant true 440 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 441 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 442 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 443 /// %7 = cmpi "eq", %5, %6 : index 444 /// %8 = and %arg3, %7 : i1 445 /// scf.yield %8 : i1 446 /// } 447 /// scf.yield %4 : i1 448 /// } else { 449 /// %false = constant false 450 /// scf.yield %false : i1 451 /// } 452 /// 453 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 454 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 455 456 LogicalResult 457 matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, 458 ConversionPatternRewriter &rewriter) const override; 459 }; 460 } // namespace 461 462 LogicalResult 463 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, 464 ConversionPatternRewriter &rewriter) const { 465 if (!llvm::all_of(op.shapes(), 466 [](Value v) { return !v.getType().isa<ShapeType>(); })) 467 return failure(); 468 469 Type i1Ty = rewriter.getI1Type(); 470 if (op.shapes().size() <= 1) { 471 rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty, 472 rewriter.getBoolAttr(true)); 473 return success(); 474 } 475 476 auto loc = op.getLoc(); 477 Type indexTy = rewriter.getIndexType(); 478 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 479 Value firstShape = adaptor.shapes().front(); 480 Value firstRank = 481 rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero); 482 Value result = nullptr; 483 // Generate a linear sequence of compares, all with firstShape as lhs. 484 for (Value shape : adaptor.shapes().drop_front(1)) { 485 Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero); 486 Value eqRank = 487 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank); 488 auto same = rewriter.create<IfOp>( 489 loc, i1Ty, eqRank, 490 [&](OpBuilder &b, Location loc) { 491 Value one = b.create<ConstantIndexOp>(loc, 1); 492 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 493 auto loop = b.create<scf::ForOp>( 494 loc, zero, firstRank, one, ValueRange{init}, 495 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 496 Value conj = args[0]; 497 Value lhsExtent = 498 b.create<tensor::ExtractOp>(loc, firstShape, iv); 499 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); 500 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 501 lhsExtent, rhsExtent); 502 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 503 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 504 }); 505 b.create<scf::YieldOp>(loc, loop.getResults()); 506 }, 507 [&](OpBuilder &b, Location loc) { 508 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 509 b.create<scf::YieldOp>(loc, result); 510 }); 511 result = !result ? same.getResult(0) 512 : rewriter.create<AndOp>(loc, result, same.getResult(0)); 513 } 514 rewriter.replaceOp(op, result); 515 return success(); 516 } 517 518 namespace { 519 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 520 public: 521 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 522 523 LogicalResult 524 matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor, 525 ConversionPatternRewriter &rewriter) const override; 526 }; 527 } // namespace 528 529 LogicalResult ShapeOfOpConversion::matchAndRewrite( 530 ShapeOfOp op, OpAdaptor adaptor, 531 ConversionPatternRewriter &rewriter) const { 532 533 // For now, only error-free types are supported by this lowering. 534 if (op.getType().isa<ShapeType>()) 535 return failure(); 536 537 // For ranked tensor arguments, lower to `tensor.from_elements`. 538 auto loc = op.getLoc(); 539 Value tensor = adaptor.arg(); 540 Type tensorTy = tensor.getType(); 541 if (tensorTy.isa<RankedTensorType>()) { 542 543 // Build values for individual extents. 544 SmallVector<Value, 8> extentValues; 545 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 546 int64_t rank = rankedTensorTy.getRank(); 547 for (int64_t i = 0; i < rank; i++) { 548 if (rankedTensorTy.isDynamicDim(i)) { 549 Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i); 550 extentValues.push_back(extent); 551 } else { 552 Value extent = 553 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 554 extentValues.push_back(extent); 555 } 556 } 557 558 // Materialize extent tensor. 559 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( 560 loc, rewriter.getIndexType(), extentValues); 561 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 562 staticExtentTensor); 563 return success(); 564 } 565 566 // Lower to `tensor.generate` otherwise. 567 auto *ctx = rewriter.getContext(); 568 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 569 rewriter.replaceOpWithNewOp<tensor::GenerateOp>( 570 op, getExtentTensorType(ctx), ValueRange{rank}, 571 [&](OpBuilder &b, Location loc, ValueRange args) { 572 Value dim = args.front(); 573 Value extent = b.create<tensor::DimOp>(loc, tensor, dim); 574 b.create<tensor::YieldOp>(loc, extent); 575 }); 576 577 return success(); 578 } 579 580 namespace { 581 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { 582 public: 583 using OpConversionPattern<SplitAtOp>::OpConversionPattern; 584 585 LogicalResult 586 matchAndRewrite(SplitAtOp op, OpAdaptor adaptor, 587 ConversionPatternRewriter &rewriter) const override; 588 }; 589 } // namespace 590 591 LogicalResult SplitAtOpConversion::matchAndRewrite( 592 SplitAtOp op, OpAdaptor adaptor, 593 ConversionPatternRewriter &rewriter) const { 594 // Error conditions are not implemented, only lower if all operands and 595 // results are extent tensors. 596 if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()}, 597 [](Value v) { return v.getType().isa<ShapeType>(); })) 598 return failure(); 599 600 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 601 Value zero = b.create<ConstantIndexOp>(0); 602 Value rank = b.create<tensor::DimOp>(adaptor.operand(), zero); 603 604 // index < 0 ? index + rank : index 605 Value originalIndex = adaptor.index(); 606 Value add = b.create<AddIOp>(originalIndex, rank); 607 Value indexIsNegative = 608 b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero); 609 Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex); 610 611 Value one = b.create<ConstantIndexOp>(1); 612 Value head = 613 b.create<tensor::ExtractSliceOp>(adaptor.operand(), zero, index, one); 614 Value tailSize = b.create<SubIOp>(rank, index); 615 Value tail = 616 b.create<tensor::ExtractSliceOp>(adaptor.operand(), index, tailSize, one); 617 rewriter.replaceOp(op, {head, tail}); 618 return success(); 619 } 620 621 namespace { 622 class ToExtentTensorOpConversion 623 : public OpConversionPattern<ToExtentTensorOp> { 624 public: 625 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 626 627 LogicalResult 628 matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, 629 ConversionPatternRewriter &rewriter) const override { 630 if (!adaptor.input().getType().isa<RankedTensorType>()) 631 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 632 633 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 634 adaptor.input()); 635 return success(); 636 } 637 }; 638 } // namespace 639 640 namespace { 641 /// Import the Shape Ops to Std Patterns. 642 #include "ShapeToStandard.cpp.inc" 643 } // namespace 644 645 namespace { 646 /// Conversion pass. 647 class ConvertShapeToStandardPass 648 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 649 650 void runOnOperation() override; 651 }; 652 } // namespace 653 654 void ConvertShapeToStandardPass::runOnOperation() { 655 // Setup target legality. 656 MLIRContext &ctx = getContext(); 657 ConversionTarget target(ctx); 658 target 659 .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>(); 660 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>(); 661 662 // Setup conversion patterns. 663 RewritePatternSet patterns(&ctx); 664 populateShapeToStandardConversionPatterns(patterns); 665 666 // Apply conversion. 667 auto module = getOperation(); 668 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 669 signalPassFailure(); 670 } 671 672 void mlir::populateShapeToStandardConversionPatterns( 673 RewritePatternSet &patterns) { 674 // clang-format off 675 populateWithGenerated(patterns); 676 patterns.add< 677 AnyOpConversion, 678 BinaryOpConversion<AddOp, AddIOp>, 679 BinaryOpConversion<MulOp, MulIOp>, 680 BroadcastOpConverter, 681 ConstShapeOpConverter, 682 ConstSizeOpConversion, 683 IsBroadcastableOpConverter, 684 GetExtentOpConverter, 685 RankOpConverter, 686 ReduceOpConverter, 687 ShapeEqOpConverter, 688 ShapeOfOpConversion, 689 SplitAtOpConversion, 690 ToExtentTensorOpConversion>(patterns.getContext()); 691 // clang-format on 692 } 693 694 std::unique_ptr<OperationPass<ModuleOp>> 695 mlir::createConvertShapeToStandardPass() { 696 return std::make_unique<ConvertShapeToStandardPass>(); 697 } 698