//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
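  // Broadcasting aligns the operands' extents from the back, so the result
  // has the rank of the greater-ranked operand; selecting both orderings
  // through a single `scf.if` lets the loops below handle either case.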
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsSmaller =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Type extentTensorTy = op.getType();
  auto ifOp = rewriter.create<IfOp>(
      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
      lhsSmaller,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
                                               rhsRank, transformed.rhs()});
      },
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
                                               lhsRank, transformed.lhs()});
      });
  Value smallerRank = ifOp.getResult(0);
  Value smallerOperand = ifOp.getResult(1);
  Value greaterRank = ifOp.getResult(2);
  Value greaterOperand = ifOp.getResult(3);

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy extents from the greater operand that have no counterpart in the
  // smaller operand.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine the remaining broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value smallerOperandExtent = b.create<ExtractElementOp>(
                  loc, smallerOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, smallerOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load the broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
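  //
  // Schematically, `shape.const_shape [1, 2]` becomes:
  //   %c1 = constant 1 : index
  //   %c2 = constant 2 : index
  //   %0 = tensor_from_elements %c1, %c2 : tensor<2xindex>
  //   %1 = tensor_cast %0 : tensor<2xindex> to tensor<?xindex>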
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin where possible.
  // This avoids having to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
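/// The reduction body is cloned into the loop, with its block arguments
/// replaced by the induction variable, the current extent, and the
/// loop-carried accumulator values. Schematically, computing a shape's number
/// of elements,
///
///   %n = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = muli %acc, %extent : index
///       shape.yield %next : index
///   }
///
/// becomes an `scf.for` from 0 to the rank whose iter_args carry %acc.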
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every pair of extents
/// for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
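  //
  // The outer `scf.if` below short-circuits to `false` when the ranks differ;
  // only for equal ranks does the inner `scf.for` compare extents pairwise.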
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for the individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize the extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
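  // The extent tensor's size is the argument's rank and each element is
  // computed lazily as the corresponding dimension's extent. Schematically:
  //
  //   %rank = rank %arg : tensor<*xf32>
  //   %shape = dynamic_tensor_from_elements %rank {
  //   ^bb0(%i : index):
  //     %extent = dim %arg, %i : tensor<*xf32>
  //     yield %extent : index
  //   } : tensor<?xindex>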
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply the conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}