//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find smaller and greater rank and extent tensor.
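  // Extents are aligned from the right for broadcasting, so the leading
  // `rankDiff` extents of the greater-ranked operand are copied unchanged and
  // only its trailing extents have to be combined with the smaller operand.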
  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
  Value lhsSmaller =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Type extentTensorTy = op.getType();
  auto ifOp = rewriter.create<IfOp>(
      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
      lhsSmaller,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
                                               rhsRank, transformed.rhs()});
      },
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
                                               lhsRank, transformed.lhs()});
      });
  Value smallerRank = ifOp.getResult(0);
  Value smallerOperand = ifOp.getResult(1);
  Value greaterRank = ifOp.getResult(2);
  Value greaterOperand = ifOp.getResult(3);

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy extents from greater operand that are not challenged.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine remaining broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value smallerOperandExtent = b.create<ExtractElementOp>(
                  loc, smallerOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, smallerOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
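  // The constant extents are materialized as `constant` index values, packed
  // into a `tensor_from_elements`, and then cast to the dynamically shaped
  // extent tensor result type.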
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
  Type indexTy = rewriter.getIndexType();
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
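/// The reduction body is cloned into the loop, with its block arguments
/// remapped to the induction variable, the current extent, and the
/// loop-carried iteration values.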
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
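  // Ranks are compared up front; only when they are equal does the generated
  // `scf.for` loop conjoin the element-wise extent comparisons.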
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor =
        rewriter.create<TensorFromElementsOp>(loc, extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
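  // For unranked operands the number of extents is only known at runtime, so
  // the extent tensor is generated element by element from `dim` operations.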
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Setup conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}
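
// Example of scheduling this pass (an illustrative sketch; `context` and
// `module` stand in for a caller's MLIRContext and ModuleOp):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createConvertShapeToStandardPass());
//   if (failed(pm.run(module))) {
//     // handle the failure
//   }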