1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 using namespace mlir; 22 using namespace mlir::shape; 23 using namespace mlir::scf; 24 25 /// Conversion patterns. 26 namespace { 27 class AnyOpConversion : public OpConversionPattern<AnyOp> { 28 public: 29 using OpConversionPattern<AnyOp>::OpConversionPattern; 30 31 LogicalResult 32 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override; 34 }; 35 } // namespace 36 37 LogicalResult 38 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 39 ConversionPatternRewriter &rewriter) const { 40 AnyOp::Adaptor transformed(operands); 41 42 // Replace `any` with its first operand. 43 // Any operand would be a valid substitution. 44 rewriter.replaceOp(op, {transformed.inputs().front()}); 45 return success(); 46 } 47 48 namespace { 49 template <typename SrcOpTy, typename DstOpTy> 50 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 51 public: 52 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 53 54 LogicalResult 55 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 56 ConversionPatternRewriter &rewriter) const override { 57 typename SrcOpTy::Adaptor transformed(operands); 58 59 // For now, only error-free types are supported by this lowering. 60 if (op.getType().template isa<SizeType>()) 61 return failure(); 62 63 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 64 transformed.rhs()); 65 return success(); 66 } 67 }; 68 } // namespace 69 70 namespace { 71 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 72 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 73 74 LogicalResult 75 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 76 ConversionPatternRewriter &rewriter) const override; 77 }; 78 79 // Get the resulting extent in a given dimension. This is computed with any 80 // number of extent tensors and shifted offsets into them. 81 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, 82 ValueRange rankDiffs, Value outputDimension) { 83 Value one = lb.create<ConstantIndexOp>(1); 84 Value broadcastedDim = one; 85 for (auto tup : llvm::zip(extentTensors, rankDiffs)) { 86 Value shape = std::get<0>(tup); 87 Value rankDiff = std::get<1>(tup); 88 Value outOfBounds = 89 lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff); 90 Type indexTy = lb.getIndexType(); 91 broadcastedDim = 92 lb.create<IfOp>( 93 TypeRange{indexTy}, outOfBounds, 94 [&](OpBuilder &b, Location loc) { 95 b.create<scf::YieldOp>(loc, broadcastedDim); 96 }, 97 [&](OpBuilder &b, Location loc) { 98 // The broadcasting logic is: 99 // - if one extent (here we arbitrarily choose the 100 // extent from the greater-rank operand) is equal to 1, 101 // then take the extent from the other operand 102 // - otherwise, take the extent as-is. 103 // Note that this logic remains correct in the presence 104 // of dimensions of zero extent. 105 Value lesserRankOperandDimension = 106 b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff); 107 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 108 loc, shape, ValueRange{lesserRankOperandDimension}); 109 110 Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 111 lesserRankOperandExtent, one); 112 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim, 113 lesserRankOperandExtent); 114 b.create<scf::YieldOp>(loc, dim); 115 }) 116 .getResult(0); 117 } 118 return broadcastedDim; 119 } 120 } // namespace 121 122 LogicalResult BroadcastOpConverter::matchAndRewrite( 123 BroadcastOp op, ArrayRef<Value> operands, 124 ConversionPatternRewriter &rewriter) const { 125 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 126 // on shapes. 127 if (op.getType().isa<ShapeType>()) 128 return failure(); 129 130 auto loc = op.getLoc(); 131 ImplicitLocOpBuilder lb(loc, rewriter); 132 BroadcastOp::Adaptor transformed(operands); 133 134 Value zero = lb.create<ConstantIndexOp>(0); 135 Type indexTy = lb.getIndexType(); 136 137 // Save all the ranks for bounds checking. Because this is a tensor 138 // representing the shape extents, the rank is the extent of the only 139 // dimension in the tensor. 140 SmallVector<Value> ranks, rankDiffs; 141 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 142 return lb.create<tensor::DimOp>(v, zero); 143 })); 144 145 // Find the maximum rank 146 Value maxRank = ranks.front(); 147 for (Value v : llvm::drop_begin(ranks, 1)) { 148 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 149 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 150 } 151 152 // Calculate the difference of ranks and the maximum rank for later offsets. 153 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 154 return lb.create<SubIOp>(indexTy, maxRank, v); 155 })); 156 157 Value replacement = lb.create<tensor::GenerateOp>( 158 getExtentTensorType(lb.getContext()), ValueRange{maxRank}, 159 [&](OpBuilder &b, Location loc, ValueRange args) { 160 Value broadcastedDim = 161 getBroadcastedDim(ImplicitLocOpBuilder(loc, b), 162 transformed.shapes(), rankDiffs, args[0]); 163 164 b.create<tensor::YieldOp>(loc, broadcastedDim); 165 }); 166 if (replacement.getType() != op.getType()) 167 replacement = lb.create<tensor::CastOp>(op.getType(), replacement); 168 rewriter.replaceOp(op, replacement); 169 return success(); 170 } 171 172 namespace { 173 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 174 public: 175 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 176 177 LogicalResult 178 matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands, 179 ConversionPatternRewriter &rewriter) const override; 180 }; 181 } // namespace 182 183 LogicalResult ConstShapeOpConverter::matchAndRewrite( 184 ConstShapeOp op, ArrayRef<Value> operands, 185 ConversionPatternRewriter &rewriter) const { 186 187 // For now, this lowering supports only extent tensors, not `shape.shape` 188 // types. 189 if (op.getType().isa<ShapeType>()) 190 return failure(); 191 192 auto loc = op.getLoc(); 193 SmallVector<Value, 4> extentOperands; 194 for (auto extent : op.shape()) { 195 extentOperands.push_back( 196 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 197 } 198 Type indexTy = rewriter.getIndexType(); 199 Value tensor = 200 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 201 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 202 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 203 return success(); 204 } 205 206 namespace { 207 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 208 public: 209 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 210 211 LogicalResult 212 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 213 ConversionPatternRewriter &rewriter) const override; 214 }; 215 } // namespace 216 217 LogicalResult ConstSizeOpConversion::matchAndRewrite( 218 ConstSizeOp op, ArrayRef<Value> operands, 219 ConversionPatternRewriter &rewriter) const { 220 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 221 return success(); 222 } 223 224 namespace { 225 struct IsBroadcastableOpConverter 226 : public OpConversionPattern<IsBroadcastableOp> { 227 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 228 229 LogicalResult 230 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 231 ConversionPatternRewriter &rewriter) const override; 232 }; 233 } // namespace 234 235 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 236 IsBroadcastableOp op, ArrayRef<Value> operands, 237 ConversionPatternRewriter &rewriter) const { 238 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 239 // on shapes. 240 IsBroadcastableOp::Adaptor transformed(operands); 241 if (!llvm::all_of(op.shapes(), 242 [](Value v) { return !v.getType().isa<ShapeType>(); })) 243 return failure(); 244 245 auto loc = op.getLoc(); 246 ImplicitLocOpBuilder lb(loc, rewriter); 247 Value zero = lb.create<ConstantIndexOp>(0); 248 Value one = lb.create<ConstantIndexOp>(1); 249 Type indexTy = lb.getIndexType(); 250 251 // Save all the ranks for bounds checking. Because this is a tensor 252 // representing the shape extents, the rank is the extent of the only 253 // dimension in the tensor. 254 SmallVector<Value> ranks, rankDiffs; 255 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 256 return lb.create<tensor::DimOp>(v, zero); 257 })); 258 259 // Find the maximum rank 260 Value maxRank = ranks.front(); 261 for (Value v : llvm::drop_begin(ranks, 1)) { 262 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 263 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 264 } 265 266 // Calculate the difference of ranks and the maximum rank for later offsets. 267 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 268 return lb.create<SubIOp>(indexTy, maxRank, v); 269 })); 270 271 Type i1Ty = rewriter.getI1Type(); 272 Value trueVal = 273 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 274 275 auto reduceResult = lb.create<ForOp>( 276 loc, zero, maxRank, one, ValueRange{trueVal}, 277 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 278 // Find a non-1 dim, if it exists. Note that the first part of this 279 // could reuse the Broadcast lowering entirely, but we redo the work 280 // here to make optimizations easier between the two loops. 281 Value broadcastedDim = getBroadcastedDim( 282 ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv); 283 284 Value broadcastable = iterArgs[0]; 285 for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) { 286 Value shape, rankDiff; 287 std::tie(shape, rankDiff) = tup; 288 Value outOfBounds = 289 b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff); 290 broadcastable = 291 b.create<IfOp>( 292 loc, TypeRange{i1Ty}, outOfBounds, 293 [&](OpBuilder &b, Location loc) { 294 // Non existent dimensions are always broadcastable 295 b.create<scf::YieldOp>(loc, broadcastable); 296 }, 297 [&](OpBuilder &b, Location loc) { 298 // Every value needs to be either 1, or the same non-1 299 // value to be broadcastable in this dim. 300 Value operandDimension = 301 b.create<SubIOp>(loc, indexTy, iv, rankDiff); 302 Value dimensionExtent = b.create<tensor::ExtractOp>( 303 loc, shape, ValueRange{operandDimension}); 304 305 Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 306 dimensionExtent, one); 307 Value equalBroadcasted = 308 b.create<CmpIOp>(loc, CmpIPredicate::eq, 309 dimensionExtent, broadcastedDim); 310 Value result = b.create<AndOp>( 311 loc, broadcastable, 312 b.create<OrOp>(loc, equalOne, equalBroadcasted)); 313 b.create<scf::YieldOp>(loc, result); 314 }) 315 .getResult(0); 316 } 317 318 b.create<scf::YieldOp>(loc, broadcastable); 319 }); 320 321 rewriter.replaceOp(op, reduceResult.results().front()); 322 return success(); 323 } 324 325 namespace { 326 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 327 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 331 ConversionPatternRewriter &rewriter) const override; 332 }; 333 } // namespace 334 335 LogicalResult GetExtentOpConverter::matchAndRewrite( 336 GetExtentOp op, ArrayRef<Value> operands, 337 ConversionPatternRewriter &rewriter) const { 338 GetExtentOp::Adaptor transformed(operands); 339 340 // For now, only error-free types are supported by this lowering. 341 if (op.getType().isa<SizeType>()) 342 return failure(); 343 344 // Derive shape extent directly from shape origin if possible. This 345 // circumvents the necessity to materialize the shape in memory. 346 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 347 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 348 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(), 349 transformed.dim()); 350 return success(); 351 } 352 } 353 354 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), 355 transformed.shape(), 356 ValueRange{transformed.dim()}); 357 return success(); 358 } 359 360 namespace { 361 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 362 public: 363 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 364 365 LogicalResult 366 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 367 ConversionPatternRewriter &rewriter) const override; 368 }; 369 } // namespace 370 371 LogicalResult 372 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 373 ConversionPatternRewriter &rewriter) const { 374 // For now, this lowering supports only error-free types. 375 if (op.getType().isa<SizeType>()) 376 return failure(); 377 378 shape::RankOp::Adaptor transformed(operands); 379 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, transformed.shape(), 0); 380 return success(); 381 } 382 383 namespace { 384 /// Converts `shape.reduce` to `scf.for`. 385 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 386 public: 387 using OpConversionPattern::OpConversionPattern; 388 389 LogicalResult 390 matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 391 ConversionPatternRewriter &rewriter) const final; 392 }; 393 } // namespace 394 395 LogicalResult 396 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands, 397 ConversionPatternRewriter &rewriter) const { 398 // For now, this lowering is only defined on `tensor<?xindex>` operands. 399 if (op.shape().getType().isa<ShapeType>()) 400 return failure(); 401 402 auto loc = op.getLoc(); 403 shape::ReduceOp::Adaptor transformed(operands); 404 405 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 406 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 407 Type indexTy = rewriter.getIndexType(); 408 Value rank = 409 rewriter.create<tensor::DimOp>(loc, indexTy, transformed.shape(), zero); 410 411 auto loop = rewriter.create<scf::ForOp>( 412 loc, zero, rank, one, op.initVals(), 413 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 414 Value extent = 415 b.create<tensor::ExtractOp>(loc, transformed.shape(), iv); 416 417 SmallVector<Value, 2> mappedValues{iv, extent}; 418 mappedValues.append(args.begin(), args.end()); 419 420 BlockAndValueMapping mapping; 421 Block *reduceBody = op.getBody(); 422 mapping.map(reduceBody->getArguments(), mappedValues); 423 for (auto &nested : reduceBody->without_terminator()) 424 b.clone(nested, mapping); 425 426 SmallVector<Value, 2> mappedResults; 427 for (auto result : reduceBody->getTerminator()->getOperands()) 428 mappedResults.push_back(mapping.lookup(result)); 429 b.create<scf::YieldOp>(loc, mappedResults); 430 }); 431 432 rewriter.replaceOp(op, loop.getResults()); 433 return success(); 434 } 435 436 namespace { 437 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 438 /// only defined on `tensor<?xindex>` operands. The test for equality first 439 /// compares their size and, if equal, checks every extent for equality. 440 /// 441 /// Example: 442 /// 443 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 444 /// 445 /// becomes 446 /// 447 /// %c0 = constant 0 : index 448 /// %0 = dim %arg0, %c0 : tensor<?xindex> 449 /// %1 = dim %arg1, %c0 : tensor<?xindex> 450 /// %2 = cmpi "eq", %0, %1 : index 451 /// %result = scf.if %2 -> (i1) { 452 /// %c1 = constant 1 : index 453 /// %true = constant true 454 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 455 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 456 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 457 /// %7 = cmpi "eq", %5, %6 : index 458 /// %8 = and %arg3, %7 : i1 459 /// scf.yield %8 : i1 460 /// } 461 /// scf.yield %4 : i1 462 /// } else { 463 /// %false = constant false 464 /// scf.yield %false : i1 465 /// } 466 /// 467 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 468 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 469 470 LogicalResult 471 matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 472 ConversionPatternRewriter &rewriter) const override; 473 }; 474 } // namespace 475 476 LogicalResult 477 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands, 478 ConversionPatternRewriter &rewriter) const { 479 if (!llvm::all_of(op.shapes(), 480 [](Value v) { return !v.getType().isa<ShapeType>(); })) 481 return failure(); 482 483 Type i1Ty = rewriter.getI1Type(); 484 if (op.shapes().size() <= 1) { 485 rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty, 486 rewriter.getBoolAttr(true)); 487 return success(); 488 } 489 490 ShapeEqOp::Adaptor transformed(operands); 491 auto loc = op.getLoc(); 492 Type indexTy = rewriter.getIndexType(); 493 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 494 Value firstShape = transformed.shapes().front(); 495 Value firstRank = 496 rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero); 497 Value result = nullptr; 498 // Generate a linear sequence of compares, all with firstShape as lhs. 499 for (Value shape : transformed.shapes().drop_front(1)) { 500 Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero); 501 Value eqRank = 502 rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank); 503 auto same = rewriter.create<IfOp>( 504 loc, i1Ty, eqRank, 505 [&](OpBuilder &b, Location loc) { 506 Value one = b.create<ConstantIndexOp>(loc, 1); 507 Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 508 auto loop = b.create<scf::ForOp>( 509 loc, zero, firstRank, one, ValueRange{init}, 510 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 511 Value conj = args[0]; 512 Value lhsExtent = 513 b.create<tensor::ExtractOp>(loc, firstShape, iv); 514 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); 515 Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq, 516 lhsExtent, rhsExtent); 517 Value conjNext = b.create<AndOp>(loc, conj, eqExtent); 518 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 519 }); 520 b.create<scf::YieldOp>(loc, loop.getResults()); 521 }, 522 [&](OpBuilder &b, Location loc) { 523 Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 524 b.create<scf::YieldOp>(loc, result); 525 }); 526 result = !result ? same.getResult(0) 527 : rewriter.create<AndOp>(loc, result, same.getResult(0)); 528 } 529 rewriter.replaceOp(op, result); 530 return success(); 531 } 532 533 namespace { 534 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 535 public: 536 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 537 538 LogicalResult 539 matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands, 540 ConversionPatternRewriter &rewriter) const override; 541 }; 542 } // namespace 543 544 LogicalResult ShapeOfOpConversion::matchAndRewrite( 545 ShapeOfOp op, ArrayRef<Value> operands, 546 ConversionPatternRewriter &rewriter) const { 547 548 // For now, only error-free types are supported by this lowering. 549 if (op.getType().isa<ShapeType>()) 550 return failure(); 551 552 // For ranked tensor arguments, lower to `tensor.from_elements`. 553 auto loc = op.getLoc(); 554 ShapeOfOp::Adaptor transformed(operands); 555 Value tensor = transformed.arg(); 556 Type tensorTy = tensor.getType(); 557 if (tensorTy.isa<RankedTensorType>()) { 558 559 // Build values for individual extents. 560 SmallVector<Value, 8> extentValues; 561 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 562 int64_t rank = rankedTensorTy.getRank(); 563 for (int64_t i = 0; i < rank; i++) { 564 if (rankedTensorTy.isDynamicDim(i)) { 565 Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i); 566 extentValues.push_back(extent); 567 } else { 568 Value extent = 569 rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i)); 570 extentValues.push_back(extent); 571 } 572 } 573 574 // Materialize extent tensor. 575 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( 576 loc, rewriter.getIndexType(), extentValues); 577 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 578 staticExtentTensor); 579 return success(); 580 } 581 582 // Lower to `tensor.generate` otherwise. 583 auto *ctx = rewriter.getContext(); 584 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 585 rewriter.replaceOpWithNewOp<tensor::GenerateOp>( 586 op, getExtentTensorType(ctx), ValueRange{rank}, 587 [&](OpBuilder &b, Location loc, ValueRange args) { 588 Value dim = args.front(); 589 Value extent = b.create<tensor::DimOp>(loc, tensor, dim); 590 b.create<tensor::YieldOp>(loc, extent); 591 }); 592 593 return success(); 594 } 595 596 namespace { 597 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { 598 public: 599 using OpConversionPattern<SplitAtOp>::OpConversionPattern; 600 601 LogicalResult 602 matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands, 603 ConversionPatternRewriter &rewriter) const override; 604 }; 605 } // namespace 606 607 LogicalResult SplitAtOpConversion::matchAndRewrite( 608 SplitAtOp op, ArrayRef<Value> operands, 609 ConversionPatternRewriter &rewriter) const { 610 // Error conditions are not implemented, only lower if all operands and 611 // results are extent tensors. 612 if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()}, 613 [](Value v) { return v.getType().isa<ShapeType>(); })) 614 return failure(); 615 616 SplitAtOp::Adaptor transformed(op); 617 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 618 Value zero = b.create<ConstantIndexOp>(0); 619 Value rank = b.create<tensor::DimOp>(transformed.operand(), zero); 620 621 // index < 0 ? index + rank : index 622 Value originalIndex = transformed.index(); 623 Value add = b.create<AddIOp>(originalIndex, rank); 624 Value indexIsNegative = 625 b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero); 626 Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex); 627 628 Value one = b.create<ConstantIndexOp>(1); 629 Value head = 630 b.create<tensor::ExtractSliceOp>(transformed.operand(), zero, index, one); 631 Value tailSize = b.create<SubIOp>(rank, index); 632 Value tail = b.create<tensor::ExtractSliceOp>(transformed.operand(), index, 633 tailSize, one); 634 rewriter.replaceOp(op, {head, tail}); 635 return success(); 636 } 637 638 namespace { 639 class ToExtentTensorOpConversion 640 : public OpConversionPattern<ToExtentTensorOp> { 641 public: 642 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 643 644 LogicalResult 645 matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands, 646 ConversionPatternRewriter &rewriter) const override { 647 ToExtentTensorOpAdaptor adaptor(operands); 648 649 if (!adaptor.input().getType().isa<RankedTensorType>()) 650 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 651 652 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 653 adaptor.input()); 654 return success(); 655 } 656 }; 657 } // namespace 658 659 namespace { 660 /// Import the Shape Ops to Std Patterns. 661 #include "ShapeToStandard.cpp.inc" 662 } // namespace 663 664 namespace { 665 /// Conversion pass. 666 class ConvertShapeToStandardPass 667 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 668 669 void runOnOperation() override; 670 }; 671 } // namespace 672 673 void ConvertShapeToStandardPass::runOnOperation() { 674 // Setup target legality. 675 MLIRContext &ctx = getContext(); 676 ConversionTarget target(ctx); 677 target 678 .addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>(); 679 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>(); 680 681 // Setup conversion patterns. 682 RewritePatternSet patterns(&ctx); 683 populateShapeToStandardConversionPatterns(patterns); 684 685 // Apply conversion. 686 auto module = getOperation(); 687 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 688 signalPassFailure(); 689 } 690 691 void mlir::populateShapeToStandardConversionPatterns( 692 RewritePatternSet &patterns) { 693 // clang-format off 694 populateWithGenerated(patterns); 695 patterns.add< 696 AnyOpConversion, 697 BinaryOpConversion<AddOp, AddIOp>, 698 BinaryOpConversion<MulOp, MulIOp>, 699 BroadcastOpConverter, 700 ConstShapeOpConverter, 701 ConstSizeOpConversion, 702 IsBroadcastableOpConverter, 703 GetExtentOpConverter, 704 RankOpConverter, 705 ReduceOpConverter, 706 ShapeEqOpConverter, 707 ShapeOfOpConversion, 708 SplitAtOpConversion, 709 ToExtentTensorOpConversion>(patterns.getContext()); 710 // clang-format on 711 } 712 713 std::unique_ptr<OperationPass<ModuleOp>> 714 mlir::createConvertShapeToStandardPass() { 715 return std::make_unique<ConvertShapeToStandardPass>(); 716 } 717