//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
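/// Converts `shape.broadcast` to loops over the operands' extents. For now,
/// this lowering is only defined on `tensor<?xindex>` operands. The operands
/// are ordered by rank, the leading extents of the greater-ranked operand are
/// copied into stack-allocated memory, and each remaining position takes the
/// greater-ranked operand's extent unless that extent is 1, in which case the
/// lesser-ranked operand's corresponding extent is used. A rough sketch of the
/// generated structure (illustrative, not verbatim output):
///
///   %mem = alloca(%greater_rank) : memref<?xindex>
///   scf.for %i = %c0 to %rank_diff step %c1 {
///     ... copy extent %i of the greater-ranked operand into %mem ...
///   }
///   scf.for %i = %rank_diff to %greater_rank step %c1 {
///     ... select the non-unit extent and store it into %mem ...
///   }
///   %result = tensor_load %mem : memref<?xindex>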
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find the lesser and greater rank and extent tensor.
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  auto erasedRankType =
      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  Value rankErasedLhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
  Value rankErasedRhs =
      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy extents from the greater-ranked operand that have no counterpart in
  // the lesser-ranked operand.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterRankOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine remaining broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load the broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
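/// Converts `shape.const_shape` to a constant extent tensor. Each extent is
/// materialized as an index constant, the extents are assembled with
/// `tensor_from_elements`, and the result is cast to `tensor<?xindex>`.
/// For example (illustrative, not verbatim output):
///
///   %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
///
/// becomes roughly
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %c3 = constant 3 : index
///   %0 = tensor_from_elements %c1, %c2, %c3 : tensor<3xindex>
///   %shape = tensor_cast %0 : tensor<3xindex> to tensor<?xindex>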
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
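/// The body of `shape.reduce` is cloned into an `scf.for` loop that iterates
/// over the extents of the shape and carries the accumulator values as
/// iteration arguments. A rough sketch (illustrative, not verbatim output):
///
///   %result = scf.for %i = %c0 to %rank step %c1 iter_args(%acc = %init) {
///     %extent = extract_element %shape[%i] : tensor<?xindex>
///     ... cloned reduction body with %i, %extent, and %acc mapped in ...
///     scf.yield %next_acc : index
///   }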
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their ranks and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
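/// Converts `shape.shape_of` to an extent tensor. For ranked tensor operands,
/// each extent is materialized individually (`dim` for dynamic dimensions, an
/// index constant for static ones) and the results are assembled with
/// `tensor_from_elements`. Unranked operands are lowered to
/// `dynamic_tensor_from_elements`, whose body derives each extent with `dim`.
/// A rough sketch of the ranked case (illustrative, not verbatim output):
///
///   %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
///
/// becomes roughly
///
///   %c2 = constant 2 : index
///   %c1 = constant 1 : index
///   %d1 = dim %arg, %c1 : tensor<2x?xf32>
///   %0 = tensor_from_elements %c2, %d1 : tensor<2xindex>
///   %shape = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>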
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
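/// Converts `shape.to_extent_tensor` to a `tensor_cast`. The conversion only
/// applies when the input has already been materialized as a ranked tensor of
/// extents; all other inputs are rejected with a match-failure diagnostic.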
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}