1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "llvm/ADT/SmallString.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 using namespace mlir::shape; 22 23 namespace { 24 #include "ShapeCanonicalization.inc" 25 } 26 27 static RankedTensorType getExtentTensorType(MLIRContext *ctx) { 28 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 29 } 30 31 static bool isErrorPropagationPossible(TypeRange operandTypes) { 32 for (Type ty : operandTypes) 33 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>()) 34 return true; 35 return false; 36 } 37 38 static LogicalResult verifySizeOrIndexOp(Operation *op) { 39 assert(op != nullptr && op->getNumResults() == 1); 40 Type resultTy = op->getResultTypes().front(); 41 if (isErrorPropagationPossible(op->getOperandTypes())) { 42 if (!resultTy.isa<SizeType>()) 43 return op->emitOpError() 44 << "if at least one of the operands can hold error values then " 45 "the result must be of type `size` to propagate them"; 46 } 47 return success(); 48 } 49 50 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 51 assert(op != nullptr && op->getNumResults() == 1); 52 Type resultTy = op->getResultTypes().front(); 53 if (isErrorPropagationPossible(op->getOperandTypes())) { 54 if (!resultTy.isa<ShapeType>()) 55 return op->emitOpError() 56 << "if at least one of the operands can hold error values then " 57 "the result must be of type `shape` to propagate them"; 58 } 59 return success(); 60 } 61 62 ShapeDialect::ShapeDialect(MLIRContext *context) 63 : Dialect(getDialectNamespace(), context) { 64 addOperations< 65 #define GET_OP_LIST 66 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 67 >(); 68 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 69 WitnessType>(); 70 // Allow unknown operations during prototyping and testing. As the dialect is 71 // still evolving it makes it simple to start with an unregistered ops and 72 // try different variants before actually defining the op. 73 allowUnknownOperations(); 74 } 75 76 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 77 Attribute value, Type type, 78 Location loc) { 79 if (type.isa<ShapeType>() || 80 type == getExtentTensorType(builder.getContext())) 81 return builder.create<ConstShapeOp>(loc, type, 82 value.cast<DenseIntElementsAttr>()); 83 if (type.isa<SizeType>()) 84 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 85 if (type.isa<WitnessType>()) 86 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 87 if (type.isa<IndexType>()) 88 return builder.create<ConstantOp>(loc, type, value); 89 return nullptr; 90 } 91 92 /// Parse a type registered to this dialect. 93 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 94 StringRef keyword; 95 if (parser.parseKeyword(&keyword)) 96 return Type(); 97 98 if (keyword == "component") 99 return ComponentType::get(getContext()); 100 if (keyword == "element") 101 return ElementType::get(getContext()); 102 if (keyword == "shape") 103 return ShapeType::get(getContext()); 104 if (keyword == "size") 105 return SizeType::get(getContext()); 106 if (keyword == "value_shape") 107 return ValueShapeType::get(getContext()); 108 if (keyword == "witness") 109 return WitnessType::get(getContext()); 110 111 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 112 return Type(); 113 } 114 115 /// Print a type registered to this dialect. 116 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 117 switch (type.getKind()) { 118 case ShapeTypes::Component: 119 os << "component"; 120 return; 121 case ShapeTypes::Element: 122 os << "element"; 123 return; 124 case ShapeTypes::Size: 125 os << "size"; 126 return; 127 case ShapeTypes::Shape: 128 os << "shape"; 129 return; 130 case ShapeTypes::ValueShape: 131 os << "value_shape"; 132 return; 133 case ShapeTypes::Witness: 134 os << "witness"; 135 return; 136 default: 137 llvm_unreachable("unexpected 'shape' type kind"); 138 } 139 } 140 141 //===----------------------------------------------------------------------===// 142 // AnyOp 143 //===----------------------------------------------------------------------===// 144 145 // TODO: Canonicalization should be implemented for shapes that can be 146 // determined through mixtures of the known dimensions of the inputs. 147 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 148 // Only the last operand is checked because AnyOp is commutative. 149 if (operands.back()) 150 return operands.back(); 151 152 return nullptr; 153 } 154 155 //===----------------------------------------------------------------------===// 156 // AssumingOp 157 //===----------------------------------------------------------------------===// 158 159 static ParseResult parseAssumingOp(OpAsmParser &parser, 160 OperationState &result) { 161 result.regions.reserve(1); 162 Region *doRegion = result.addRegion(); 163 164 auto &builder = parser.getBuilder(); 165 OpAsmParser::OperandType cond; 166 if (parser.parseOperand(cond) || 167 parser.resolveOperand(cond, builder.getType<WitnessType>(), 168 result.operands)) 169 return failure(); 170 171 // Parse optional results type list. 172 if (parser.parseOptionalArrowTypeList(result.types)) 173 return failure(); 174 175 // Parse the region and add a terminator if elided. 176 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 177 return failure(); 178 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 179 180 // Parse the optional attribute list. 181 if (parser.parseOptionalAttrDict(result.attributes)) 182 return failure(); 183 return success(); 184 } 185 186 static void print(OpAsmPrinter &p, AssumingOp op) { 187 bool yieldsResults = !op.results().empty(); 188 189 p << AssumingOp::getOperationName() << " " << op.witness(); 190 if (yieldsResults) { 191 p << " -> (" << op.getResultTypes() << ")"; 192 } 193 p.printRegion(op.doRegion(), 194 /*printEntryBlockArgs=*/false, 195 /*printBlockTerminators=*/yieldsResults); 196 p.printOptionalAttrDict(op.getAttrs()); 197 } 198 199 namespace { 200 // Removes AssumingOp with a passing witness and inlines the region. 201 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 202 using OpRewritePattern<AssumingOp>::OpRewritePattern; 203 204 LogicalResult matchAndRewrite(AssumingOp op, 205 PatternRewriter &rewriter) const override { 206 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 207 if (!witness || !witness.passingAttr()) 208 return failure(); 209 210 AssumingOp::inlineRegionIntoParent(op, rewriter); 211 return success(); 212 } 213 }; 214 } // namespace 215 216 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 217 MLIRContext *context) { 218 // If taking a passing witness, inline region. 219 patterns.insert<AssumingWithTrue>(context); 220 } 221 222 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 223 PatternRewriter &rewriter) { 224 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 225 auto *assumingBlock = op.getBody(); 226 auto initPosition = rewriter.getInsertionPoint(); 227 auto *blockAfterAssuming = 228 rewriter.splitBlock(blockBeforeAssuming, initPosition); 229 230 // Remove the AssumingOp and AssumingYieldOp. 231 auto &yieldOp = assumingBlock->back(); 232 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 233 rewriter.replaceOp(op, yieldOp.getOperands()); 234 rewriter.eraseOp(&yieldOp); 235 236 // Merge blocks together as there was no branching behavior from the 237 // AssumingOp. 238 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 239 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // AssumingAllOp 244 //===----------------------------------------------------------------------===// 245 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 246 // Iterate in reverse to first handle all constant operands. They are 247 // guaranteed to be the tail of the inputs because this is commutative. 248 for (int idx = operands.size() - 1; idx >= 0; idx--) { 249 Attribute a = operands[idx]; 250 // Cannot fold if any inputs are not constant; 251 if (!a) 252 return nullptr; 253 254 // We do not need to keep statically known values after handling them in 255 // this method. 256 getOperation()->eraseOperand(idx); 257 258 // Always false if any input is statically known false 259 if (!a.cast<BoolAttr>().getValue()) 260 return a; 261 } 262 // If this is reached, all inputs were statically known passing. 263 return BoolAttr::get(true, getContext()); 264 } 265 266 static LogicalResult verify(AssumingAllOp op) { 267 // Ensure that AssumingAllOp contains at least one operand 268 if (op.getNumOperands() == 0) 269 return op.emitOpError("no operands specified"); 270 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // BroadcastOp 276 //===----------------------------------------------------------------------===// 277 278 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 279 if (!operands[1]) 280 return nullptr; 281 282 auto rhsShape = llvm::to_vector<6>( 283 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 284 if (rhsShape.empty()) 285 return lhs(); 286 287 if (!operands[0]) 288 return nullptr; 289 290 auto lhsShape = llvm::to_vector<6>( 291 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 292 if (lhsShape.empty()) 293 return rhs(); 294 295 SmallVector<int64_t, 6> resultShape; 296 // If the shapes are not compatible, we can't fold it. 297 // TODO: Fold to an "error". 298 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 299 return nullptr; 300 Builder builder(getContext()); 301 return builder.getIndexTensorAttr(resultShape); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // ConcatOp 306 //===----------------------------------------------------------------------===// 307 308 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 309 if (!operands[0] || !operands[1]) 310 return nullptr; 311 auto lhsShape = llvm::to_vector<6>( 312 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 313 auto rhsShape = llvm::to_vector<6>( 314 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 315 SmallVector<int64_t, 6> resultShape; 316 resultShape.append(lhsShape.begin(), lhsShape.end()); 317 resultShape.append(rhsShape.begin(), rhsShape.end()); 318 Builder builder(getContext()); 319 return builder.getIndexTensorAttr(resultShape); 320 } 321 322 //===----------------------------------------------------------------------===// 323 // ConstShapeOp 324 //===----------------------------------------------------------------------===// 325 326 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 327 p << "shape.const_shape "; 328 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 329 p << "["; 330 interleaveComma(op.shape().getValues<int64_t>(), p, 331 [&](int64_t i) { p << i; }); 332 p << "] : "; 333 p.printType(op.getType()); 334 } 335 336 static ParseResult parseConstShapeOp(OpAsmParser &parser, 337 OperationState &result) { 338 if (parser.parseOptionalAttrDict(result.attributes)) 339 return failure(); 340 // We piggy-back on ArrayAttr parsing, though we don't internally store the 341 // shape as an ArrayAttr. 342 // TODO: Implement custom parser and maybe make syntax a bit more concise. 343 Attribute extentsRaw; 344 NamedAttrList dummy; 345 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 346 return failure(); 347 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 348 if (!extentsArray) 349 return failure(); 350 SmallVector<int64_t, 6> ints; 351 for (Attribute extent : extentsArray) { 352 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 353 if (!attr) 354 return failure(); 355 ints.push_back(attr.getInt()); 356 } 357 Builder &builder = parser.getBuilder(); 358 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 359 Type resultTy; 360 if (parser.parseColonType(resultTy)) 361 return failure(); 362 result.types.push_back(resultTy); 363 return success(); 364 } 365 366 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 367 368 //===----------------------------------------------------------------------===// 369 // CstrBroadcastableOp 370 //===----------------------------------------------------------------------===// 371 372 namespace { 373 // Given an input shape Value, try to obtain the shape's values. 374 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 375 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 376 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 377 if (!type.hasRank()) 378 return failure(); 379 shapeValues = llvm::to_vector<6>(type.getShape()); 380 return success(); 381 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 382 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 383 return success(); 384 } else { 385 return failure(); 386 } 387 } 388 389 // For shapes that were created by some operations, we can obtain partial 390 // information on the shapes and sometimes determine if they will be 391 // broadcastable with that. 392 struct CstrBroadcastablePartialInfo 393 : public OpRewritePattern<CstrBroadcastableOp> { 394 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 395 396 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 397 PatternRewriter &rewriter) const override { 398 SmallVector<int64_t, 6> lhsShape, rhsShape; 399 if (failed(getShapeVec(op.lhs(), lhsShape))) 400 return failure(); 401 if (failed(getShapeVec(op.rhs(), rhsShape))) 402 return failure(); 403 if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 404 return failure(); 405 406 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 407 return success(); 408 } 409 }; 410 411 // Scalars are always broadcastable. 412 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> { 413 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 414 415 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 416 PatternRewriter &rewriter) const override { 417 SmallVector<int64_t, 6> shape; 418 if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0) 419 return failure(); 420 if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0) 421 return failure(); 422 423 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 424 return success(); 425 } 426 }; 427 428 } // namespace 429 430 void CstrBroadcastableOp::getCanonicalizationPatterns( 431 OwningRewritePatternList &patterns, MLIRContext *context) { 432 // Canonicalization patterns have overlap with the considerations during 433 // folding in case additional shape information is inferred at some point that 434 // does not result in folding. 435 patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo, 436 CstrBroadcastableScalar>(context); 437 } 438 439 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 440 // Both operands are not needed if one is a scalar. 441 if (operands[0] && 442 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 443 return BoolAttr::get(true, getContext()); 444 if (operands[1] && 445 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 446 return BoolAttr::get(true, getContext()); 447 448 if (operands[0] && operands[1]) { 449 auto lhsShape = llvm::to_vector<6>( 450 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 451 auto rhsShape = llvm::to_vector<6>( 452 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 453 SmallVector<int64_t, 6> resultShape; 454 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 455 return BoolAttr::get(true, getContext()); 456 } 457 458 // Lastly, see if folding can be completed based on what constraints are known 459 // on the input shapes. 460 SmallVector<int64_t, 6> lhsShape, rhsShape; 461 if (failed(getShapeVec(lhs(), lhsShape))) 462 return nullptr; 463 if (failed(getShapeVec(rhs(), rhsShape))) 464 return nullptr; 465 466 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 467 return BoolAttr::get(true, getContext()); 468 469 // Because a failing witness result here represents an eventual assertion 470 // failure, we do not replace it with a constant witness. 471 return nullptr; 472 } 473 474 //===----------------------------------------------------------------------===// 475 // CstrEqOp 476 //===----------------------------------------------------------------------===// 477 478 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 479 MLIRContext *context) { 480 // If inputs are equal, return passing witness 481 patterns.insert<CstrEqEqOps>(context); 482 } 483 484 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 485 if (llvm::all_of(operands, 486 [&](Attribute a) { return a && a == operands[0]; })) 487 return BoolAttr::get(true, getContext()); 488 489 // Because a failing witness result here represents an eventual assertion 490 // failure, we do not try to replace it with a constant witness. Similarly, we 491 // cannot if there are any non-const inputs. 492 return nullptr; 493 } 494 495 //===----------------------------------------------------------------------===// 496 // ConstSizeOp 497 //===----------------------------------------------------------------------===// 498 499 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 500 int64_t value) { 501 build(builder, result, builder.getIndexAttr(value)); 502 } 503 504 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 505 506 void ConstSizeOp::getAsmResultNames( 507 llvm::function_ref<void(Value, StringRef)> setNameFn) { 508 SmallString<4> buffer; 509 llvm::raw_svector_ostream os(buffer); 510 os << "c" << value(); 511 setNameFn(getResult(), os.str()); 512 } 513 514 //===----------------------------------------------------------------------===// 515 // ConstWitnessOp 516 //===----------------------------------------------------------------------===// 517 518 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 519 520 //===----------------------------------------------------------------------===// 521 // ShapeEqOp 522 //===----------------------------------------------------------------------===// 523 524 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 525 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 526 if (lhs == nullptr) 527 return {}; 528 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 529 if (rhs == nullptr) 530 return {}; 531 return BoolAttr::get(lhs == rhs, getContext()); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // IndexToSizeOp 536 //===----------------------------------------------------------------------===// 537 538 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 539 // Constant values of both types, `shape.size` and `index`, are represented as 540 // `IntegerAttr`s which makes constant folding simple. 541 if (Attribute arg = operands[0]) 542 return arg; 543 return {}; 544 } 545 546 void IndexToSizeOp::getCanonicalizationPatterns( 547 OwningRewritePatternList &patterns, MLIRContext *context) { 548 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // FromExtentsOp 553 //===----------------------------------------------------------------------===// 554 555 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 556 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 557 return nullptr; 558 SmallVector<int64_t, 6> extents; 559 for (auto attr : operands) 560 extents.push_back(attr.cast<IntegerAttr>().getInt()); 561 Builder builder(getContext()); 562 return builder.getIndexTensorAttr(extents); 563 } 564 565 //===----------------------------------------------------------------------===// 566 // GetExtentOp 567 //===----------------------------------------------------------------------===// 568 569 Optional<int64_t> GetExtentOp::getConstantDim() { 570 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 571 return constSizeOp.value().getLimitedValue(); 572 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 573 return constantOp.value().cast<IntegerAttr>().getInt(); 574 return llvm::None; 575 } 576 577 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 578 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 579 if (!elements) 580 return nullptr; 581 Optional<int64_t> dim = getConstantDim(); 582 if (!dim.hasValue()) 583 return nullptr; 584 if (dim.getValue() >= elements.getNumElements()) 585 return nullptr; 586 return elements.getValue({(uint64_t)dim.getValue()}); 587 } 588 589 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 590 int64_t dim) { 591 auto loc = result.location; 592 auto dimAttr = builder.getIndexAttr(dim); 593 if (shape.getType().isa<ShapeType>()) { 594 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 595 build(builder, result, builder.getType<SizeType>(), shape, dim); 596 } else { 597 Value dim = 598 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 599 build(builder, result, builder.getIndexType(), shape, dim); 600 } 601 } 602 603 //===----------------------------------------------------------------------===// 604 // RankOp 605 //===----------------------------------------------------------------------===// 606 607 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 608 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 609 if (!shape) 610 return {}; 611 int64_t rank = shape.getNumElements(); 612 Builder builder(getContext()); 613 return builder.getIndexAttr(rank); 614 } 615 616 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 617 /// Constant folding fails in cases where only the rank is constant, not the 618 /// shape itself. 619 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 620 /// 621 /// Example: 622 /// 623 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 624 /// %rank = shape.rank %shape 625 /// 626 /// becomes 627 /// 628 /// %rank = shape.const_size 3 629 630 namespace { 631 struct RankShapeOfCanonicalizationPattern 632 : public OpRewritePattern<shape::RankOp> { 633 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 634 635 LogicalResult matchAndRewrite(shape::RankOp op, 636 PatternRewriter &rewriter) const override { 637 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 638 if (!shapeOfOp) 639 return failure(); 640 auto rankedTensorType = 641 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 642 if (!rankedTensorType) 643 return failure(); 644 assert(op.getType().isa<IndexType>() && 645 "expected `rank(shape_of( ... )]` based on a shaped argument to " 646 "yield an index type"); 647 int64_t rank = rankedTensorType.getRank(); 648 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 649 return success(); 650 } 651 }; 652 } // namespace 653 654 void shape::RankOp::getCanonicalizationPatterns( 655 OwningRewritePatternList &patterns, MLIRContext *context) { 656 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // NumElementsOp 661 //===----------------------------------------------------------------------===// 662 663 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 664 665 // Fold only when argument constant. 666 Attribute shape = operands[0]; 667 if (!shape) 668 return {}; 669 670 APInt product(64, 1); 671 for (auto value : shape.cast<DenseIntElementsAttr>()) 672 product *= value; 673 Builder builder(getContext()); 674 return builder.getIndexAttr(product.getLimitedValue()); 675 } 676 677 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 678 Value shape) { 679 if (shape.getType().isa<ShapedType>()) { 680 auto type = builder.getIndexType(); 681 return build(builder, result, type, shape); 682 } 683 auto type = SizeType::get(builder.getContext()); 684 return build(builder, result, type, shape); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // MulOp 689 //===----------------------------------------------------------------------===// 690 691 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 692 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 693 if (!lhs) 694 return nullptr; 695 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 696 if (!rhs) 697 return nullptr; 698 APInt folded = lhs.getValue() * rhs.getValue(); 699 Type indexTy = IndexType::get(getContext()); 700 return IntegerAttr::get(indexTy, folded); 701 } 702 703 //===----------------------------------------------------------------------===// 704 // ShapeOfOp 705 //===----------------------------------------------------------------------===// 706 707 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 708 auto type = getOperand().getType().dyn_cast<ShapedType>(); 709 if (!type || !type.hasStaticShape()) 710 return nullptr; 711 Builder builder(getContext()); 712 return builder.getIndexTensorAttr(type.getShape()); 713 } 714 715 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 716 if (arg.getType().isa<ShapedType>()) { 717 auto type = RankedTensorType::get({ShapedType::kDynamicSize}, 718 builder.getIndexType()); 719 return ShapeOfOp::build(builder, result, type, arg); 720 } 721 auto type = ShapeType::get(builder.getContext()); 722 return ShapeOfOp::build(builder, result, type, arg); 723 } 724 725 namespace { 726 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 727 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 728 729 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 730 PatternRewriter &rewriter) const override { 731 if (!op.arg().getType().isa<ShapedType>()) 732 return failure(); 733 if (op.getType().isa<ShapedType>()) 734 return failure(); 735 736 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 737 return success(); 738 } 739 }; 740 } // namespace 741 742 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 743 MLIRContext *context) { 744 patterns.insert<ShapeOfWithTensor>(context); 745 } 746 747 //===----------------------------------------------------------------------===// 748 // SizeToIndexOp 749 //===----------------------------------------------------------------------===// 750 751 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 752 // Constant values of both types, `shape.size` and `index`, are represented as 753 // `IntegerAttr`s which makes constant folding simple. 754 if (Attribute arg = operands[0]) 755 return arg; 756 return impl::foldCastOp(*this); 757 } 758 759 void SizeToIndexOp::getCanonicalizationPatterns( 760 OwningRewritePatternList &patterns, MLIRContext *context) { 761 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // YieldOp 766 //===----------------------------------------------------------------------===// 767 768 static LogicalResult verify(YieldOp op) { 769 auto *parentOp = op.getParentOp(); 770 auto results = parentOp->getResults(); 771 auto operands = op.getOperands(); 772 773 if (parentOp->getNumResults() != op.getNumOperands()) 774 return op.emitOpError() << "number of operands does not match number of " 775 "results of its parent"; 776 for (auto e : llvm::zip(results, operands)) 777 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 778 return op.emitOpError() 779 << "types mismatch between yield op and its parent"; 780 781 return success(); 782 } 783 784 //===----------------------------------------------------------------------===// 785 // SplitAtOp 786 //===----------------------------------------------------------------------===// 787 788 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 789 SmallVectorImpl<OpFoldResult> &results) { 790 if (!operands[0] || !operands[1]) 791 return failure(); 792 auto shapeVec = llvm::to_vector<6>( 793 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 794 auto shape = llvm::makeArrayRef(shapeVec); 795 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 796 // Verify that the split point is in the correct range. 797 // TODO: Constant fold to an "error". 798 int64_t rank = shape.size(); 799 if (!(-rank <= splitPoint && splitPoint <= rank)) 800 return failure(); 801 if (splitPoint < 0) 802 splitPoint += shape.size(); 803 Builder builder(operands[0].getContext()); 804 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 805 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 806 return success(); 807 } 808 809 //===----------------------------------------------------------------------===// 810 // ToExtentTensorOp 811 //===----------------------------------------------------------------------===// 812 813 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 814 if (!operands[0]) 815 return impl::foldCastOp(*this); 816 Builder builder(getContext()); 817 auto shape = llvm::to_vector<6>( 818 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 819 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 820 builder.getIndexType()); 821 return DenseIntElementsAttr::get(type, shape); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // ReduceOp 826 //===----------------------------------------------------------------------===// 827 828 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 829 ValueRange initVals) { 830 result.addOperands(shape); 831 result.addOperands(initVals); 832 833 Region *bodyRegion = result.addRegion(); 834 bodyRegion->push_back(new Block); 835 Block &bodyBlock = bodyRegion->front(); 836 bodyBlock.addArgument(builder.getIndexType()); 837 838 Type elementType; 839 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 840 elementType = tensorType.getElementType(); 841 else 842 elementType = SizeType::get(builder.getContext()); 843 bodyBlock.addArgument(elementType); 844 845 for (Type initValType : initVals.getTypes()) { 846 bodyBlock.addArgument(initValType); 847 result.addTypes(initValType); 848 } 849 } 850 851 static LogicalResult verify(ReduceOp op) { 852 // Verify block arg types. 853 Block &block = op.region().front(); 854 855 // The block takes index, extent, and aggregated values as arguments. 856 auto blockArgsCount = op.initVals().size() + 2; 857 if (block.getNumArguments() != blockArgsCount) 858 return op.emitOpError() << "ReduceOp body is expected to have " 859 << blockArgsCount << " arguments"; 860 861 // The first block argument is the index and must always be of type `index`. 862 if (!block.getArgument(0).getType().isa<IndexType>()) 863 return op.emitOpError( 864 "argument 0 of ReduceOp body is expected to be of IndexType"); 865 866 // The second block argument is the extent and must be of type `size` or 867 // `index`, depending on whether the reduce operation is applied to a shape or 868 // to an extent tensor. 869 Type extentTy = block.getArgument(1).getType(); 870 if (op.shape().getType().isa<ShapeType>()) { 871 if (!extentTy.isa<SizeType>()) 872 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 873 "SizeType if the ReduceOp operates on a ShapeType"); 874 } else { 875 if (!extentTy.isa<IndexType>()) 876 return op.emitOpError( 877 "argument 1 of ReduceOp body is expected to be of IndexType if the " 878 "ReduceOp operates on an extent tensor"); 879 } 880 881 for (auto type : llvm::enumerate(op.initVals())) 882 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 883 return op.emitOpError() 884 << "type mismatch between argument " << type.index() + 2 885 << " of ReduceOp body and initial value " << type.index(); 886 return success(); 887 } 888 889 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 890 // Parse operands. 891 SmallVector<OpAsmParser::OperandType, 3> operands; 892 Type shapeOrExtentTensorType; 893 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 894 OpAsmParser::Delimiter::Paren) || 895 parser.parseColonType(shapeOrExtentTensorType) || 896 parser.parseOptionalArrowTypeList(result.types)) 897 return failure(); 898 899 // Resolve operands. 900 auto initVals = llvm::makeArrayRef(operands).drop_front(); 901 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 902 result.operands) || 903 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 904 result.operands)) 905 return failure(); 906 907 // Parse the body. 908 Region *body = result.addRegion(); 909 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 910 return failure(); 911 912 // Parse attributes. 913 if (parser.parseOptionalAttrDict(result.attributes)) 914 return failure(); 915 916 return success(); 917 } 918 919 static void print(OpAsmPrinter &p, ReduceOp op) { 920 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 921 << ") : " << op.shape().getType(); 922 p.printOptionalArrowTypeList(op.getResultTypes()); 923 p.printRegion(op.region()); 924 p.printOptionalAttrDict(op.getAttrs()); 925 } 926 927 namespace mlir { 928 namespace shape { 929 930 #define GET_OP_CLASSES 931 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 932 933 } // namespace shape 934 } // namespace mlir 935