1 //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/Shape/IR/Shape.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/STLExtras.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 using namespace mlir::scf; 25 26 /// Conversion patterns. 27 namespace { 28 class AnyOpConversion : public OpConversionPattern<AnyOp> { 29 public: 30 using OpConversionPattern<AnyOp>::OpConversionPattern; 31 32 LogicalResult 33 matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 34 ConversionPatternRewriter &rewriter) const override; 35 }; 36 } // namespace 37 38 LogicalResult 39 AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands, 40 ConversionPatternRewriter &rewriter) const { 41 AnyOp::Adaptor transformed(operands); 42 43 // Replace `any` with its first operand. 44 // Any operand would be a valid substitution. 
45 rewriter.replaceOp(op, {transformed.inputs().front()}); 46 return success(); 47 } 48 49 namespace { 50 template <typename SrcOpTy, typename DstOpTy> 51 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { 52 public: 53 using OpConversionPattern<SrcOpTy>::OpConversionPattern; 54 55 LogicalResult 56 matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands, 57 ConversionPatternRewriter &rewriter) const override { 58 typename SrcOpTy::Adaptor transformed(operands); 59 60 // For now, only error-free types are supported by this lowering. 61 if (op.getType().template isa<SizeType>()) 62 return failure(); 63 64 rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(), 65 transformed.rhs()); 66 return success(); 67 } 68 }; 69 } // namespace 70 71 namespace { 72 struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { 73 using OpConversionPattern<BroadcastOp>::OpConversionPattern; 74 75 LogicalResult 76 matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands, 77 ConversionPatternRewriter &rewriter) const override; 78 }; 79 80 // Get the resulting extent in a given dimension. This is computed with any 81 // number of extent tensors and shifted offsets into them. 
// Note: `lb` is taken by value so call sites can pass a temporary builder
// anchored at their own location.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<ConstantIndexOp>(1);
  // Start from the neutral element: a dimension absent from all operands
  // broadcasts like extent 1.
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    // Operands are right-aligned against the output rank; if the output
    // dimension is before this operand's first dimension, the operand does
    // not constrain it.
    Value outOfBounds =
        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              TypeRange{indexTy}, outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // Dimension out of range for this operand: keep the extent
                // accumulated so far.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                //   extent from the greater-rank operand) is equal to 1,
                //   then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension =
                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lesserRankOperandExtent, one);
                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
                                               lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  BroadcastOp::Adaptor transformed(operands);

  Value zero = lb.create<ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
                       return lb.create<memref::DimOp>(v, zero);
                     }));

  // Find the maximum rank
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<SubIOp>(indexTy, maxRank, v);
                     }));

  // Materialize the result extent tensor element-wise: each output dimension
  // is the broadcast of the corresponding (right-aligned) operand extents.
  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
                              transformed.shapes(), rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  // Cast back to the op's (possibly statically shaped) result type if needed.
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
}
// namespace 183 184 LogicalResult ConstShapeOpConverter::matchAndRewrite( 185 ConstShapeOp op, ArrayRef<Value> operands, 186 ConversionPatternRewriter &rewriter) const { 187 188 // For now, this lowering supports only extent tensors, not `shape.shape` 189 // types. 190 if (op.getType().isa<ShapeType>()) 191 return failure(); 192 193 auto loc = op.getLoc(); 194 SmallVector<Value, 4> extentOperands; 195 for (auto extent : op.shape()) { 196 extentOperands.push_back( 197 rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue())); 198 } 199 Type indexTy = rewriter.getIndexType(); 200 Value tensor = 201 rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands); 202 Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); 203 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); 204 return success(); 205 } 206 207 namespace { 208 class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { 209 public: 210 using OpConversionPattern<ConstSizeOp>::OpConversionPattern; 211 212 LogicalResult 213 matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands, 214 ConversionPatternRewriter &rewriter) const override; 215 }; 216 } // namespace 217 218 LogicalResult ConstSizeOpConversion::matchAndRewrite( 219 ConstSizeOp op, ArrayRef<Value> operands, 220 ConversionPatternRewriter &rewriter) const { 221 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue()); 222 return success(); 223 } 224 225 namespace { 226 struct IsBroadcastableOpConverter 227 : public OpConversionPattern<IsBroadcastableOp> { 228 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; 229 230 LogicalResult 231 matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands, 232 ConversionPatternRewriter &rewriter) const override; 233 }; 234 } // namespace 235 236 LogicalResult IsBroadcastableOpConverter::matchAndRewrite( 237 IsBroadcastableOp op, ArrayRef<Value> operands, 238 ConversionPatternRewriter &rewriter) const { 239 // 
For now, this lowering is only defined on `tensor<?xindex>` operands, not 240 // on shapes. 241 IsBroadcastableOp::Adaptor transformed(operands); 242 if (!llvm::all_of(op.shapes(), 243 [](Value v) { return !v.getType().isa<ShapeType>(); })) 244 return failure(); 245 246 auto loc = op.getLoc(); 247 ImplicitLocOpBuilder lb(loc, rewriter); 248 Value zero = lb.create<ConstantIndexOp>(0); 249 Value one = lb.create<ConstantIndexOp>(1); 250 Type indexTy = lb.getIndexType(); 251 252 // Save all the ranks for bounds checking. Because this is a tensor 253 // representing the shape extents, the rank is the extent of the only 254 // dimension in the tensor. 255 SmallVector<Value> ranks, rankDiffs; 256 llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { 257 return lb.create<memref::DimOp>(v, zero); 258 })); 259 260 // Find the maximum rank 261 Value maxRank = ranks.front(); 262 for (Value v : llvm::drop_begin(ranks, 1)) { 263 Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank); 264 maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank); 265 } 266 267 // Calculate the difference of ranks and the maximum rank for later offsets. 268 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { 269 return lb.create<SubIOp>(indexTy, maxRank, v); 270 })); 271 272 Type i1Ty = rewriter.getI1Type(); 273 Value trueVal = 274 rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); 275 276 auto reduceResult = lb.create<ForOp>( 277 loc, zero, maxRank, one, ValueRange{trueVal}, 278 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { 279 // Find a non-1 dim, if it exists. Note that the first part of this 280 // could reuse the Broadcast lowering entirely, but we redo the work 281 // here to make optimizations easier between the two loops. 
282 Value broadcastedDim = getBroadcastedDim( 283 ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv); 284 285 Value broadcastable = iterArgs[0]; 286 for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) { 287 Value shape, rankDiff; 288 std::tie(shape, rankDiff) = tup; 289 Value outOfBounds = 290 b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff); 291 broadcastable = 292 b.create<IfOp>( 293 loc, TypeRange{i1Ty}, outOfBounds, 294 [&](OpBuilder &b, Location loc) { 295 // Non existent dimensions are always broadcastable 296 b.create<scf::YieldOp>(loc, broadcastable); 297 }, 298 [&](OpBuilder &b, Location loc) { 299 // Every value needs to be either 1, or the same non-1 300 // value to be broadcastable in this dim. 301 Value operandDimension = 302 b.create<SubIOp>(loc, indexTy, iv, rankDiff); 303 Value dimensionExtent = b.create<tensor::ExtractOp>( 304 loc, shape, ValueRange{operandDimension}); 305 306 Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq, 307 dimensionExtent, one); 308 Value equalBroadcasted = 309 b.create<CmpIOp>(loc, CmpIPredicate::eq, 310 dimensionExtent, broadcastedDim); 311 Value result = b.create<AndOp>( 312 loc, broadcastable, 313 b.create<OrOp>(loc, equalOne, equalBroadcasted)); 314 b.create<scf::YieldOp>(loc, result); 315 }) 316 .getResult(0); 317 } 318 319 b.create<scf::YieldOp>(loc, broadcastable); 320 }); 321 322 rewriter.replaceOp(op, reduceResult.results().front()); 323 return success(); 324 } 325 326 namespace { 327 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { 328 using OpConversionPattern<GetExtentOp>::OpConversionPattern; 329 330 LogicalResult 331 matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands, 332 ConversionPatternRewriter &rewriter) const override; 333 }; 334 } // namespace 335 336 LogicalResult GetExtentOpConverter::matchAndRewrite( 337 GetExtentOp op, ArrayRef<Value> operands, 338 ConversionPatternRewriter &rewriter) const { 339 GetExtentOp::Adaptor 
transformed(operands); 340 341 // For now, only error-free types are supported by this lowering. 342 if (op.getType().isa<SizeType>()) 343 return failure(); 344 345 // Derive shape extent directly from shape origin if possible. This 346 // circumvents the necessity to materialize the shape in memory. 347 if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { 348 if (shapeOfOp.arg().getType().isa<ShapedType>()) { 349 rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(), 350 transformed.dim()); 351 return success(); 352 } 353 } 354 355 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), 356 transformed.shape(), 357 ValueRange{transformed.dim()}); 358 return success(); 359 } 360 361 namespace { 362 class RankOpConverter : public OpConversionPattern<shape::RankOp> { 363 public: 364 using OpConversionPattern<shape::RankOp>::OpConversionPattern; 365 366 LogicalResult 367 matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 368 ConversionPatternRewriter &rewriter) const override; 369 }; 370 } // namespace 371 372 LogicalResult 373 RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands, 374 ConversionPatternRewriter &rewriter) const { 375 // For now, this lowering supports only error-free types. 376 if (op.getType().isa<SizeType>()) 377 return failure(); 378 379 shape::RankOp::Adaptor transformed(operands); 380 rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0); 381 return success(); 382 } 383 384 namespace { 385 /// Converts `shape.reduce` to `scf.for`. 
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  // The rank equals the number of elements in the extent tensor.
  Value rank =
      rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);

        // The reduce body's block arguments are (index, extent, accumulators);
        // map them to the induction variable, the current extent, and the
        // loop-carried values.
        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone the reduction body into the loop body under this mapping.
        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        // Yield the remapped operands of the body's terminator as the new
        // accumulator values.
        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.shapes(),
                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  // With zero or one shape operand there is nothing to compare; the result is
  // trivially true.
  if (op.shapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
                                            rewriter.getBoolAttr(true));
    return success();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value firstShape = transformed.shapes().front();
  Value firstRank =
      rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : transformed.shapes().drop_front(1)) {
    Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
    Value eqRank =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, i1Ty, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Ranks match: conjoin extent-wise equality over all dimensions.
          Value one = b.create<ConstantIndexOp>(loc, 1);
          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                  lhsExtent, rhsExtent);
                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Differing ranks: shapes cannot be equal.
          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    // Fold this comparison into the running conjunction (first iteration has
    // no previous result yet).
    result = !result ? same.getResult(0)
                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        // Dynamic extent: query it from the tensor at runtime.
        Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        // Static extent: materialize it as a constant.
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<memref::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented, only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
                   [](Value v) { return v.getType().isa<ShapeType>(); }))
    return failure();

  // NOTE(review): this adaptor is built from `op` rather than from the
  // converted `operands` passed to the pattern, unlike every other pattern in
  // this file — presumably equivalent here because no type converter rewrites
  // the operands; confirm.
  SplitAtOp::Adaptor transformed(op);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<ConstantIndexOp>(0);
  Value rank = b.create<memref::DimOp>(transformed.operand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = transformed.index();
  Value add = b.create<AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);

  // head = operand[0 .. index), tail = operand[index .. rank).
  Value one = b.create<ConstantIndexOp>(1);
  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
  Value tailSize = b.create<SubIOp>(rank, index);
  Value tail =
      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
/// Lowers `shape.to_extent_tensor` to a cast; the converted input is already
/// an extent tensor when it is ranked.
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.input());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}