//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  AnyOp::Adaptor transformed(operands);

  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {transformed.inputs().front()});
  return success();
}

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename SrcOpTy::Adaptor transformed(operands);

    // For now, only error-free types are supported by this lowering.
    if (op.getType().template isa<SizeType>())
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
                                         transformed.rhs());
    return success();
  }
};
} // namespace

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.getType().isa<ShapeType>())
    return failure();

  assert(!op.lhs().getType().isa<ShapeType>() &&
         !op.rhs().getType().isa<ShapeType>());
  auto loc = op.getLoc();
  BroadcastOp::Adaptor transformed(operands);
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Find the lesser and greater rank, and the corresponding operands.
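  // Broadcasting aligns the operands' extents at the back, so the result has
  // the rank of the greater-ranked operand. Illustratively, for the extent
  // tensors [3, 1] and [4, 2, 3, 5], the leading extents [4, 2] are copied
  // from the greater-ranked operand as-is, and the remaining positions
  // broadcast pairwise ([3, 1] with [3, 5]) to yield [4, 2, 3, 5].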
  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
  Value lhsRankULE =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  Type indexTy = rewriter.getIndexType();
  Value lesserRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  Value greaterRank =
      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  Value lesserRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
  Value greaterRankOperand =
      rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());

  // Allocate stack memory for the broadcasted extent tensor.
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});

  // Copy the extents of the greater-ranked operand that have no counterpart
  // in the lesser-ranked operand; these are never broadcasted.
  Value rankDiff =
      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                           Value extent = b.create<ExtractElementOp>(
                               loc, greaterRankOperand, ValueRange{iv});
                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                           b.create<scf::YieldOp>(loc);
                         });

  // Determine the remaining, broadcasted extents.
  rewriter.create<ForOp>(
      loc, rankDiff, greaterRank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
        Value greaterOperandExtent =
            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
        Value greaterOperandExtentIsOne =
            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
        auto ifOp = b.create<IfOp>(
            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
            [&](OpBuilder &b, Location loc) {
              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
                  loc, lesserRankOperand, ValueRange{ivShifted});
              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
            },
            [&](OpBuilder &b, Location loc) {
              b.create<scf::YieldOp>(loc, greaterOperandExtent);
            });
        Value extent = ifOp.getResult(0);
        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
        b.create<scf::YieldOp>(loc);
      });

  // Load the broadcasted shape as an extent tensor.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
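  // Illustratively (IR syntax approximate), `shape.const_shape [1, 2]` lowers
  // to two `constant ... : index` ops feeding a `tensor_from_elements`, which
  // is then `tensor_cast` to the `tensor<?xindex>` result type below.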
  if (op.getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.shape()) {
    extentOperands.push_back(
        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type indexTy = rewriter.getIndexType();
  Value tensor =
      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  GetExtentOp::Adaptor transformed(operands);

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<SizeType>())
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This avoids materializing the shape in memory.
  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
    if (shapeOfOp.arg().getType().isa<ShapedType>()) {
      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
                                         transformed.dim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
                                                transformed.shape(),
                                                ValueRange{transformed.dim()});
  return success();
}

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (op.getType().isa<SizeType>())
    return failure();

  shape::RankOp::Adaptor transformed(operands);
  rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
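/// The loop iterates over the rank, threading the accumulated values through
/// `iter_args` and cloning the ops of the reduction body into the loop body.
/// Illustratively (IR syntax approximate), a reduction that computes the
/// number of elements
///
///   %num = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = muli %acc, %extent : index
///       shape.yield %next : index
///   }
///
/// becomes an `scf.for` over the rank that multiplies one extent into the
/// accumulator per iteration.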
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the operands' ranks and, if they are equal, checks every extent
/// pairwise for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
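  // A `shape.shape` operand could additionally carry an error value, for
  // which this lowering has no standard-dialect representation.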
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensor arguments, lower to `tensor_from_elements`.
  auto loc = op.getLoc();
  ShapeOfOp::Adaptor transformed(operands);
  Value tensor = transformed.arg();
  Type tensorTy = tensor.getType();
  if (tensorTy.isa<RankedTensorType>()) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent =
            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);
    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                              op.getType());
    return success();
  }

  // Lower to `dynamic_tensor_from_elements` otherwise.
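  // For unranked operands, the rank is only known at runtime, so the extent
  // tensor is materialized element by element: the body below yields
  // `dim %tensor, %i` for every index %i up to the runtime rank.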
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<DimOp>(loc, tensor, dim);
        b.create<mlir::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    ToExtentTensorOpAdaptor adaptor(operands);

    if (!adaptor.input().getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
                                              op.getType());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Set up target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<StandardOpsDialect, SCFDialect>();
  target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

  // Set up conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, &ctx);

  // Apply the conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      AnyOpConversion,
      BinaryOpConversion<AddOp, AddIOp>,
      BinaryOpConversion<MulOp, MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      ToExtentTensorOpConversion>(ctx);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}