1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/STLExtras.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 using namespace mlir::scf; 25 26 /// Conversion patterns. 27 namespace { 28 class AnyOpConversion : public OpConversionPattern<AnyOp> { 29 public: 30 using OpConversionPattern<AnyOp>::OpConversionPattern; 31 32 LogicalResult 33 matchAndRewrite(AnyOp op, OpAdaptor adaptor, 34 ConversionPatternRewriter &rewriter) const override; 35 }; 36 } // namespace 37 38 LogicalResult 39 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, 40 ConversionPatternRewriter &rewriter) const { 41 // Replace `any` with its first operand. 42 // Any operand would be a valid substitution. 43 rewriter.replaceOp(op, {adaptor.inputs().front()}); 44 return success(); 45 } 46 47 namespace { 48 template <typename SrcOpTy, typename DstOpTy> 49 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 50 public: 51 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 52 53 LogicalResult 54 matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 // For now, only error-free types are supported by this lowering. 57 if (op.getType().template isa<SizeType>()) 58 return failure(); 59 60 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs()); 61 return success(); 62 } 63 }; 64 } // namespace 65 66 namespace { 67 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 68 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 69 70 LogicalResult 71 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, 72 ConversionPatternRewriter &rewriter) const override; 73 }; 74 75 // Get the resulting extent in a given dimension. This is computed with any 76 // number of extent tensors and shifted offsets into them. 77 Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, 78 ValueRange rankDiffs, Value outputDimension) { 79 Value one = lb.create<arith::ConstantIndexOp>(1); 80 Value broadcastedDim = one; 81 for (auto tup : llvm::zip(extentTensors, rankDiffs)) { 82 Value shape = std::get<0>(tup); 83 Value rankDiff = std::get<1>(tup); 84 Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult, 85 outputDimension, rankDiff); 86 Type indexTy = lb.getIndexType(); 87 broadcastedDim = 88 lb.create<IfOp>( 89 TypeRange{indexTy}, outOfBounds, 90 [&](OpBuilder &b, Location loc) { 91 b.create<scf::YieldOp>(loc, broadcastedDim); 92 }, 93 [&](OpBuilder &b, Location loc) { 94 // The broadcasting logic is: 95 // - if one extent (here we arbitrarily choose the 96 // extent from the greater-rank operand) is equal to 1, 97 // then take the extent from the other operand 98 // - otherwise, take the extent as-is. 99 // Note that this logic remains correct in the presence 100 // of dimensions of zero extent. 101 Value lesserRankOperandDimension = b.create<arith::SubIOp>( 102 loc, indexTy, outputDimension, rankDiff); 103 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( 104 loc, shape, ValueRange{lesserRankOperandDimension}); 105 106 Value dimIsOne = 107 b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 108 lesserRankOperandExtent, one); 109 Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim, 110 lesserRankOperandExtent); 111 b.create<scf::YieldOp>(loc, dim); 112 }) 113 .getResult(0); 114 } 115 return broadcastedDim; 116 } 117 } // namespace 118 119 LogicalResult BroadcastOpConverter::matchAndRewrite( 120 BroadcastOp op, OpAdaptor adaptor, 121 ConversionPatternRewriter &rewriter) const { 122 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 123 // on shapes. 124 if (op.getType().isa<ShapeType>()) 125 return failure(); 126 127 auto loc = op.getLoc(); 128 ImplicitLocOpBuilder lb(loc, rewriter); 129 130 Value zero = lb.create<arith::ConstantIndexOp>(0); 131 Type indexTy = lb.getIndexType(); 132 133 // Save all the ranks for bounds checking. Because this is a tensor 134 // representing the shape extents, the rank is the extent of the only 135 // dimension in the tensor. 136 SmallVector<Value> ranks, rankDiffs; 137 llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { 138 return lb.create<tensor::DimOp>(v, zero); 139 })); 140 141 // Find the maximum rank 142 Value maxRank = ranks.front(); 143 for (Value v : llvm::drop_begin(ranks, 1)) { 144 Value rankIsGreater = 145 lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank); 146 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 147 } 148 149 // Calculate the difference of ranks and the maximum rank for later offsets. 150 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 151 return lb.create<arith::SubIOp>(indexTy, maxRank, v); 152 })); 153 154 Value replacement = lb.create<tensor::GenerateOp>( 155 getExtentTensorType(lb.getContext()), ValueRange{maxRank}, 156 [&](OpBuilder &b, Location loc, ValueRange args) { 157 Value broadcastedDim = getBroadcastedDim( 158 ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]); 159 160 b.create<tensor::YieldOp>(loc, broadcastedDim); 161 }); 162 if (replacement.getType() != op.getType()) 163 replacement = lb.create<tensor::CastOp>(op.getType(), replacement); 164 rewriter.replaceOp(op, replacement); 165 return success(); 166 } 167 168 namespace { 169 class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { 170 public: 171 using OpConversionPattern<ConstShapeOp>::OpConversionPattern; 172 173 LogicalResult 174 matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override; 176 }; 177 } // namespace 178 179 LogicalResult ConstShapeOpConverter::matchAndRewrite( 180 ConstShapeOp op, OpAdaptor adaptor, 181 ConversionPatternRewriter &rewriter) const { 182 183 // For now, this lowering supports only extent tensors, not `shape.shape` 184 // types. 185 if (op.getType().isa<ShapeType>()) 186 return failure(); 187 188 auto loc = op.getLoc(); 189 SmallVector<Value, 4> extentOperands; 190 for (auto extent : op.shape()) { 191 extentOperands.push_back( 192 rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue())); 193 } 194 Type indexTy = rewriter.getIndexType(); 195 Value tensor = 196 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 197 Type resultTy = RankedTensorType::get({op.shape().size()}, indexTy); 198 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 199 return success(); 200 } 201 202 namespace { 203 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 204 public: 205 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 206 207 LogicalResult 208 matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor, 209 ConversionPatternRewriter &rewriter) const override; 210 }; 211 } // namespace 212 213 LogicalResult ConstSizeOpConversion::matchAndRewrite( 214 ConstSizeOp op, OpAdaptor adaptor, 215 ConversionPatternRewriter &rewriter) const { 216 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>( 217 op, op.value().getSExtValue()); 218 return success(); 219 } 220 221 namespace { 222 struct IsBroadcastableOpConverter 223 : public OpConversionPattern<IsBroadcastableOp> { 224 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 225 226 LogicalResult 227 matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor, 228 ConversionPatternRewriter &rewriter) const override; 229 }; 230 } // namespace 231 232 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 233 IsBroadcastableOp op, OpAdaptor adaptor, 234 ConversionPatternRewriter &rewriter) const { 235 // For now, this lowering is only defined on `tensor<?xindex>` operands, not 236 // on shapes. 237 if (!llvm::all_of(op.shapes(), 238 [](Value v) { return !v.getType().isa<ShapeType>(); })) 239 return failure(); 240 241 auto loc = op.getLoc(); 242 ImplicitLocOpBuilder lb(loc, rewriter); 243 Value zero = lb.create<arith::ConstantIndexOp>(0); 244 Value one = lb.create<arith::ConstantIndexOp>(1); 245 Type indexTy = lb.getIndexType(); 246 247 // Save all the ranks for bounds checking. Because this is a tensor 248 // representing the shape extents, the rank is the extent of the only 249 // dimension in the tensor. 250 SmallVector<Value> ranks, rankDiffs; 251 llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { 252 return lb.create<tensor::DimOp>(v, zero); 253 })); 254 255 // Find the maximum rank 256 Value maxRank = ranks.front(); 257 for (Value v : llvm::drop_begin(ranks, 1)) { 258 Value rankIsGreater = 259 lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank); 260 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 261 } 262 263 // Calculate the difference of ranks and the maximum rank for later offsets. 264 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 265 return lb.create<arith::SubIOp>(indexTy, maxRank, v); 266 })); 267 268 Type i1Ty = rewriter.getI1Type(); 269 Value trueVal = 270 rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 271 272 auto reduceResult = lb.create<ForOp>( 273 loc, zero, maxRank, one, ValueRange{trueVal}, 274 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 275 // Find a non-1 dim, if it exists. Note that the first part of this 276 // could reuse the Broadcast lowering entirely, but we redo the work 277 // here to make optimizations easier between the two loops. 278 Value broadcastedDim = getBroadcastedDim( 279 ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv); 280 281 Value broadcastable = iterArgs[0]; 282 for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { 283 Value shape, rankDiff; 284 std::tie(shape, rankDiff) = tup; 285 Value outOfBounds = b.create<arith::CmpIOp>( 286 loc, arith::CmpIPredicate::ult, iv, rankDiff); 287 broadcastable = 288 b.create<IfOp>( 289 loc, TypeRange{i1Ty}, outOfBounds, 290 [&](OpBuilder &b, Location loc) { 291 // Non existent dimensions are always broadcastable 292 b.create<scf::YieldOp>(loc, broadcastable); 293 }, 294 [&](OpBuilder &b, Location loc) { 295 // Every value needs to be either 1, or the same non-1 296 // value to be broadcastable in this dim. 297 Value operandDimension = 298 b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff); 299 Value dimensionExtent = b.create<tensor::ExtractOp>( 300 loc, shape, ValueRange{operandDimension}); 301 302 Value equalOne = b.create<arith::CmpIOp>( 303 loc, arith::CmpIPredicate::eq, dimensionExtent, one); 304 Value equalBroadcasted = b.create<arith::CmpIOp>( 305 loc, arith::CmpIPredicate::eq, dimensionExtent, 306 broadcastedDim); 307 Value result = b.create<arith::AndIOp>( 308 loc, broadcastable, 309 b.create<arith::OrIOp>(loc, equalOne, 310 equalBroadcasted)); 311 b.create<scf::YieldOp>(loc, result); 312 }) 313 .getResult(0); 314 } 315 316 b.create<scf::YieldOp>(loc, broadcastable); 317 }); 318 319 rewriter.replaceOp(op, reduceResult.results().front()); 320 return success(); 321 } 322 323 namespace { 324 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 325 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 326 327 LogicalResult 328 matchAndRewrite(GetExtentOp op, OpAdaptor adaptor, 329 ConversionPatternRewriter &rewriter) const override; 330 }; 331 } // namespace 332 333 LogicalResult GetExtentOpConverter::matchAndRewrite( 334 GetExtentOp op, OpAdaptor adaptor, 335 ConversionPatternRewriter &rewriter) const { 336 // For now, only error-free types are supported by this lowering. 337 if (op.getType().isa<SizeType>()) 338 return failure(); 339 340 // Derive shape extent directly from shape origin if possible. This 341 // circumvents the necessity to materialize the shape in memory. 342 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 343 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 344 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(), 345 adaptor.dim()); 346 return success(); 347 } 348 } 349 350 rewriter.replaceOpWithNewOp<tensor::ExtractOp>( 351 op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()}); 352 return success(); 353 } 354 355 namespace { 356 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 357 public: 358 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 359 360 LogicalResult 361 matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, 362 ConversionPatternRewriter &rewriter) const override; 363 }; 364 } // namespace 365 366 LogicalResult 367 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, 368 ConversionPatternRewriter &rewriter) const { 369 // For now, this lowering supports only error-free types. 370 if (op.getType().isa<SizeType>()) 371 return failure(); 372 373 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.shape(), 0); 374 return success(); 375 } 376 377 namespace { 378 /// Converts `shape.reduce` to `scf.for`. 379 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { 380 public: 381 using OpConversionPattern::OpConversionPattern; 382 383 LogicalResult 384 matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, 385 ConversionPatternRewriter &rewriter) const final; 386 }; 387 } // namespace 388 389 LogicalResult 390 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, 391 ConversionPatternRewriter &rewriter) const { 392 // For now, this lowering is only defined on `tensor<?xindex>` operands. 393 if (op.shape().getType().isa<ShapeType>()) 394 return failure(); 395 396 auto loc = op.getLoc(); 397 398 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 399 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 400 Type indexTy = rewriter.getIndexType(); 401 Value rank = 402 rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.shape(), zero); 403 404 auto loop = rewriter.create<scf::ForOp>( 405 loc, zero, rank, one, op.initVals(), 406 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 407 Value extent = b.create<tensor::ExtractOp>(loc, adaptor.shape(), iv); 408 409 SmallVector<Value, 2> mappedValues{iv, extent}; 410 mappedValues.append(args.begin(), args.end()); 411 412 BlockAndValueMapping mapping; 413 Block *reduceBody = op.getBody(); 414 mapping.map(reduceBody->getArguments(), mappedValues); 415 for (auto &nested : reduceBody->without_terminator()) 416 b.clone(nested, mapping); 417 418 SmallVector<Value, 2> mappedResults; 419 for (auto result : reduceBody->getTerminator()->getOperands()) 420 mappedResults.push_back(mapping.lookup(result)); 421 b.create<scf::YieldOp>(loc, mappedResults); 422 }); 423 424 rewriter.replaceOp(op, loop.getResults()); 425 return success(); 426 } 427 428 namespace { 429 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is 430 /// only defined on `tensor<?xindex>` operands. The test for equality first 431 /// compares their size and, if equal, checks every extent for equality. 432 /// 433 /// Example: 434 /// 435 /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> 436 /// 437 /// becomes 438 /// 439 /// %c0 = constant 0 : index 440 /// %0 = dim %arg0, %c0 : tensor<?xindex> 441 /// %1 = dim %arg1, %c0 : tensor<?xindex> 442 /// %2 = arith.cmpi "eq", %0, %1 : index 443 /// %result = scf.if %2 -> (i1) { 444 /// %c1 = arith.constant 1 : index 445 /// %true = arith.constant true 446 /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { 447 /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> 448 /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> 449 /// %7 = arith.cmpi "eq", %5, %6 : index 450 /// %8 = arith.andi %arg3, %7 : i1 451 /// scf.yield %8 : i1 452 /// } 453 /// scf.yield %4 : i1 454 /// } else { 455 /// %false = arith.constant false 456 /// scf.yield %false : i1 457 /// } 458 /// 459 struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { 460 using OpConversionPattern<ShapeEqOp>::OpConversionPattern; 461 462 LogicalResult 463 matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, 464 ConversionPatternRewriter &rewriter) const override; 465 }; 466 } // namespace 467 468 LogicalResult 469 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, 470 ConversionPatternRewriter &rewriter) const { 471 if (!llvm::all_of(op.shapes(), 472 [](Value v) { return !v.getType().isa<ShapeType>(); })) 473 return failure(); 474 475 Type i1Ty = rewriter.getI1Type(); 476 if (op.shapes().size() <= 1) { 477 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty, 478 rewriter.getBoolAttr(true)); 479 return success(); 480 } 481 482 auto loc = op.getLoc(); 483 Type indexTy = rewriter.getIndexType(); 484 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 485 Value firstShape = adaptor.shapes().front(); 486 Value firstRank = 487 rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero); 488 Value result = nullptr; 489 // Generate a linear sequence of compares, all with firstShape as lhs. 490 for (Value shape : adaptor.shapes().drop_front(1)) { 491 Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero); 492 Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 493 firstRank, rank); 494 auto same = rewriter.create<IfOp>( 495 loc, i1Ty, eqRank, 496 [&](OpBuilder &b, Location loc) { 497 Value one = b.create<arith::ConstantIndexOp>(loc, 1); 498 Value init = 499 b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); 500 auto loop = b.create<scf::ForOp>( 501 loc, zero, firstRank, one, ValueRange{init}, 502 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { 503 Value conj = args[0]; 504 Value lhsExtent = 505 b.create<tensor::ExtractOp>(loc, firstShape, iv); 506 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); 507 Value eqExtent = b.create<arith::CmpIOp>( 508 loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); 509 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent); 510 b.create<scf::YieldOp>(loc, ValueRange({conjNext})); 511 }); 512 b.create<scf::YieldOp>(loc, loop.getResults()); 513 }, 514 [&](OpBuilder &b, Location loc) { 515 Value result = 516 b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); 517 b.create<scf::YieldOp>(loc, result); 518 }); 519 result = !result ? same.getResult(0) 520 : rewriter.create<arith::AndIOp>(loc, result, 521 same.getResult(0)); 522 } 523 rewriter.replaceOp(op, result); 524 return success(); 525 } 526 527 namespace { 528 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { 529 public: 530 using OpConversionPattern<ShapeOfOp>::OpConversionPattern; 531 532 LogicalResult 533 matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor, 534 ConversionPatternRewriter &rewriter) const override; 535 }; 536 } // namespace 537 538 LogicalResult ShapeOfOpConversion::matchAndRewrite( 539 ShapeOfOp op, OpAdaptor adaptor, 540 ConversionPatternRewriter &rewriter) const { 541 542 // For now, only error-free types are supported by this lowering. 543 if (op.getType().isa<ShapeType>()) 544 return failure(); 545 546 // For ranked tensor arguments, lower to `tensor.from_elements`. 547 auto loc = op.getLoc(); 548 Value tensor = adaptor.arg(); 549 Type tensorTy = tensor.getType(); 550 if (tensorTy.isa<RankedTensorType>()) { 551 552 // Build values for individual extents. 553 SmallVector<Value, 8> extentValues; 554 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); 555 int64_t rank = rankedTensorTy.getRank(); 556 for (int64_t i = 0; i < rank; i++) { 557 if (rankedTensorTy.isDynamicDim(i)) { 558 Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i); 559 extentValues.push_back(extent); 560 } else { 561 Value extent = rewriter.create<arith::ConstantIndexOp>( 562 loc, rankedTensorTy.getDimSize(i)); 563 extentValues.push_back(extent); 564 } 565 } 566 567 // Materialize extent tensor. 568 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( 569 loc, rewriter.getIndexType(), extentValues); 570 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 571 staticExtentTensor); 572 return success(); 573 } 574 575 // Lower to `tensor.generate` otherwise. 576 auto *ctx = rewriter.getContext(); 577 Value rank = rewriter.create<mlir::RankOp>(loc, tensor); 578 rewriter.replaceOpWithNewOp<tensor::GenerateOp>( 579 op, getExtentTensorType(ctx), ValueRange{rank}, 580 [&](OpBuilder &b, Location loc, ValueRange args) { 581 Value dim = args.front(); 582 Value extent = b.create<tensor::DimOp>(loc, tensor, dim); 583 b.create<tensor::YieldOp>(loc, extent); 584 }); 585 586 return success(); 587 } 588 589 namespace { 590 class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { 591 public: 592 using OpConversionPattern<SplitAtOp>::OpConversionPattern; 593 594 LogicalResult 595 matchAndRewrite(SplitAtOp op, OpAdaptor adaptor, 596 ConversionPatternRewriter &rewriter) const override; 597 }; 598 } // namespace 599 600 LogicalResult SplitAtOpConversion::matchAndRewrite( 601 SplitAtOp op, OpAdaptor adaptor, 602 ConversionPatternRewriter &rewriter) const { 603 // Error conditions are not implemented, only lower if all operands and 604 // results are extent tensors. 605 if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()}, 606 [](Value v) { return v.getType().isa<ShapeType>(); })) 607 return failure(); 608 609 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 610 Value zero = b.create<arith::ConstantIndexOp>(0); 611 Value rank = b.create<tensor::DimOp>(adaptor.operand(), zero); 612 613 // index < 0 ? index + rank : index 614 Value originalIndex = adaptor.index(); 615 Value add = b.create<arith::AddIOp>(originalIndex, rank); 616 Value indexIsNegative = 617 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero); 618 Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex); 619 620 Value one = b.create<arith::ConstantIndexOp>(1); 621 Value head = 622 b.create<tensor::ExtractSliceOp>(adaptor.operand(), zero, index, one); 623 Value tailSize = b.create<arith::SubIOp>(rank, index); 624 Value tail = 625 b.create<tensor::ExtractSliceOp>(adaptor.operand(), index, tailSize, one); 626 rewriter.replaceOp(op, {head, tail}); 627 return success(); 628 } 629 630 namespace { 631 class ToExtentTensorOpConversion 632 : public OpConversionPattern<ToExtentTensorOp> { 633 public: 634 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; 635 636 LogicalResult 637 matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, 638 ConversionPatternRewriter &rewriter) const override { 639 if (!adaptor.input().getType().isa<RankedTensorType>()) 640 return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); 641 642 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), 643 adaptor.input()); 644 return success(); 645 } 646 }; 647 } // namespace 648 649 namespace { 650 /// Import the Shape Ops to Std Patterns. 651 #include "ShapeToStandard.cpp.inc" 652 } // namespace 653 654 namespace { 655 /// Conversion pass. 656 class ConvertShapeToStandardPass 657 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { 658 659 void runOnOperation() override; 660 }; 661 } // namespace 662 663 void ConvertShapeToStandardPass::runOnOperation() { 664 // Setup target legality. 665 MLIRContext &ctx = getContext(); 666 ConversionTarget target(ctx); 667 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect, 668 SCFDialect, tensor::TensorDialect>(); 669 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>(); 670 671 // Setup conversion patterns. 672 RewritePatternSet patterns(&ctx); 673 populateShapeToStandardConversionPatterns(patterns); 674 675 // Apply conversion. 676 auto module = getOperation(); 677 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 678 signalPassFailure(); 679 } 680 681 void mlir::populateShapeToStandardConversionPatterns( 682 RewritePatternSet &patterns) { 683 // clang-format off 684 populateWithGenerated(patterns); 685 patterns.add< 686 AnyOpConversion, 687 BinaryOpConversion<AddOp, arith::AddIOp>, 688 BinaryOpConversion<MulOp, arith::MulIOp>, 689 BroadcastOpConverter, 690 ConstShapeOpConverter, 691 ConstSizeOpConversion, 692 IsBroadcastableOpConverter, 693 GetExtentOpConverter, 694 RankOpConverter, 695 ReduceOpConverter, 696 ShapeEqOpConverter, 697 ShapeOfOpConversion, 698 SplitAtOpConversion, 699 ToExtentTensorOpConversion>(patterns.getContext()); 700 // clang-format on 701 } 702 703 std::unique_ptr<OperationPass<ModuleOp>> 704 mlir::createConvertShapeToStandardPass() { 705 return std::make_unique<ConvertShapeToStandardPass>(); 706 } 707