//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
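  // The rank of a shape that is materialized as `tensor<?xindex>` is the size
  // of the extent tensor's only dimension.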
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);

  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value outputDimension = args[0];
        Value isUnchallengedDimension = b.create<CmpIOp>(
            loc, CmpIPredicate::ult, outputDimension, rankDiff);
        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
            loc, greaterRankOperand, outputDimension);
        // The initial dimensions of the greater-rank operand are unchallenged,
        // so we can take them as-is. Otherwise, we need to do a comparison.
        // We need an actual branch here (instead of a select) because the
        // lesser-rank operand might be rank 0, so any extract_element would be
        // invalid.
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, isUnchallengedDimension,
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              // The broadcasting logic is:
              // - if one extent (here we arbitrarily choose the extent from
              //   the greater-rank operand) is equal to 1, then take the
              //   extent from the other operand
              // - otherwise, take the extent as-is.
              // Note that this logic remains correct in the presence of
              // dimensions of zero extent.
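              // The lesser-rank operand's dimensions align with the suffix of
              // the output shape, so its extent is looked up at an index
              // shifted by `rankDiff`.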
              Value lesserRankOperandDimension =
                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand,
                  ValueRange{lesserRankOperandDimension});
              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
              Value broadcastedExtent = b.create<SelectOp>(
                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
                  greaterRankOperandExtent);
              b.create<scf::YieldOp>(loc, broadcastedExtent);
            });
        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
      });
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
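  // That is, `shape.get_extent` applied to the result of a `shape.shape_of`
  // is lowered to a plain `dim` on the original shaped value.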
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
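/// The per-extent comparison accumulates a conjunction in the loop's
/// iteration argument, so a single mismatching extent makes the final result
/// false.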
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
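  // A `!shape.shape` result may carry an error value, which has no
  // representation as an extent tensor.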
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
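  // Partial conversion leaves `shape` ops for which no lowering exists yet
  // (those on error-carrying types) in place rather than failing the pass.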
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}