1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SmallString.h" 19 #include "llvm/Support/raw_ostream.h" 20 21 using namespace mlir; 22 using namespace mlir::shape; 23 24 namespace { 25 #include "ShapeCanonicalization.inc" 26 } 27 28 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 29 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 30 } 31 32 static bool isErrorPropagationPossible(TypeRange operandTypes) { 33 for (Type ty : operandTypes) 34 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>()) 35 return true; 36 return false; 37 } 38 39 static LogicalResult verifySizeOrIndexOp(Operation *op) { 40 assert(op != nullptr && op->getNumResults() == 1); 41 Type resultTy = op->getResultTypes().front(); 42 if (isErrorPropagationPossible(op->getOperandTypes())) { 43 if (!resultTy.isa<SizeType>()) 44 return op->emitOpError() 45 << "if at least one of the operands can hold error values then " 46 "the result must be of type `size` to propagate them"; 47 } 48 return success(); 49 } 50 51 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 52 assert(op != nullptr && op->getNumResults() == 1); 53 Type resultTy = op->getResultTypes().front(); 54 if (isErrorPropagationPossible(op->getOperandTypes())) { 55 if (!resultTy.isa<ShapeType>()) 56 return op->emitOpError() 57 << "if at least one of the operands can hold error values then " 58 "the result must be of type `shape` to propagate them"; 59 } 60 return success(); 61 } 62 63 //===----------------------------------------------------------------------===// 64 // InlinerInterface 65 //===----------------------------------------------------------------------===// 66 67 namespace { 68 /// This class defines the interface for inlining shape dialect ops. 69 struct ShapeInlinerInterface : public DialectInlinerInterface { 70 using DialectInlinerInterface::DialectInlinerInterface; 71 72 // Returns true if the given region 'src' can be inlined into the region 73 // 'dest' that is attached to an operation registered to the current dialect. 74 bool isLegalToInline(Region *dest, Region *src, 75 BlockAndValueMapping &) const final { 76 return true; 77 } 78 79 // Returns true if the given operation 'op', that is registered to this 80 // dialect, can be inlined into the region 'dest' that is attached to an 81 // operation registered to the current dialect. 82 bool isLegalToInline(Operation *op, Region *dest, 83 BlockAndValueMapping &) const final { 84 return true; 85 } 86 }; 87 } // namespace 88 89 void ShapeDialect::initialize() { 90 addOperations< 91 #define GET_OP_LIST 92 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 93 >(); 94 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 95 WitnessType>(); 96 addInterfaces<ShapeInlinerInterface>(); 97 // Allow unknown operations during prototyping and testing. As the dialect is 98 // still evolving it makes it simple to start with an unregistered ops and 99 // try different variants before actually defining the op. 100 allowUnknownOperations(); 101 } 102 103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 104 Attribute value, Type type, 105 Location loc) { 106 if (type.isa<ShapeType>() || 107 type == getExtentTensorType(builder.getContext())) 108 return builder.create<ConstShapeOp>(loc, type, 109 value.cast<DenseIntElementsAttr>()); 110 if (type.isa<SizeType>()) 111 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 112 if (type.isa<WitnessType>()) 113 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 114 if (type.isa<IndexType>()) 115 return builder.create<ConstantOp>(loc, type, value); 116 return nullptr; 117 } 118 119 /// Parse a type registered to this dialect. 120 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 121 StringRef keyword; 122 if (parser.parseKeyword(&keyword)) 123 return Type(); 124 125 if (keyword == "component") 126 return ComponentType::get(getContext()); 127 if (keyword == "element") 128 return ElementType::get(getContext()); 129 if (keyword == "shape") 130 return ShapeType::get(getContext()); 131 if (keyword == "size") 132 return SizeType::get(getContext()); 133 if (keyword == "value_shape") 134 return ValueShapeType::get(getContext()); 135 if (keyword == "witness") 136 return WitnessType::get(getContext()); 137 138 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 139 return Type(); 140 } 141 142 /// Print a type registered to this dialect. 143 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 144 switch (type.getKind()) { 145 case ShapeTypes::Component: 146 os << "component"; 147 return; 148 case ShapeTypes::Element: 149 os << "element"; 150 return; 151 case ShapeTypes::Size: 152 os << "size"; 153 return; 154 case ShapeTypes::Shape: 155 os << "shape"; 156 return; 157 case ShapeTypes::ValueShape: 158 os << "value_shape"; 159 return; 160 case ShapeTypes::Witness: 161 os << "witness"; 162 return; 163 default: 164 llvm_unreachable("unexpected 'shape' type kind"); 165 } 166 } 167 168 //===----------------------------------------------------------------------===// 169 // AnyOp 170 //===----------------------------------------------------------------------===// 171 172 // TODO: Canonicalization should be implemented for shapes that can be 173 // determined through mixtures of the known dimensions of the inputs. 174 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 175 // Only the last operand is checked because AnyOp is commutative. 176 if (operands.back()) 177 return operands.back(); 178 179 return nullptr; 180 } 181 182 //===----------------------------------------------------------------------===// 183 // AssumingOp 184 //===----------------------------------------------------------------------===// 185 186 static ParseResult parseAssumingOp(OpAsmParser &parser, 187 OperationState &result) { 188 result.regions.reserve(1); 189 Region *doRegion = result.addRegion(); 190 191 auto &builder = parser.getBuilder(); 192 OpAsmParser::OperandType cond; 193 if (parser.parseOperand(cond) || 194 parser.resolveOperand(cond, builder.getType<WitnessType>(), 195 result.operands)) 196 return failure(); 197 198 // Parse optional results type list. 199 if (parser.parseOptionalArrowTypeList(result.types)) 200 return failure(); 201 202 // Parse the region and add a terminator if elided. 203 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 204 return failure(); 205 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 206 207 // Parse the optional attribute list. 208 if (parser.parseOptionalAttrDict(result.attributes)) 209 return failure(); 210 return success(); 211 } 212 213 static void print(OpAsmPrinter &p, AssumingOp op) { 214 bool yieldsResults = !op.results().empty(); 215 216 p << AssumingOp::getOperationName() << " " << op.witness(); 217 if (yieldsResults) { 218 p << " -> (" << op.getResultTypes() << ")"; 219 } 220 p.printRegion(op.doRegion(), 221 /*printEntryBlockArgs=*/false, 222 /*printBlockTerminators=*/yieldsResults); 223 p.printOptionalAttrDict(op.getAttrs()); 224 } 225 226 namespace { 227 // Removes AssumingOp with a passing witness and inlines the region. 228 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 229 using OpRewritePattern<AssumingOp>::OpRewritePattern; 230 231 LogicalResult matchAndRewrite(AssumingOp op, 232 PatternRewriter &rewriter) const override { 233 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 234 if (!witness || !witness.passingAttr()) 235 return failure(); 236 237 AssumingOp::inlineRegionIntoParent(op, rewriter); 238 return success(); 239 } 240 }; 241 } // namespace 242 243 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 244 MLIRContext *context) { 245 // If taking a passing witness, inline region. 246 patterns.insert<AssumingWithTrue>(context); 247 } 248 249 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 250 PatternRewriter &rewriter) { 251 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 252 auto *assumingBlock = op.getBody(); 253 auto initPosition = rewriter.getInsertionPoint(); 254 auto *blockAfterAssuming = 255 rewriter.splitBlock(blockBeforeAssuming, initPosition); 256 257 // Remove the AssumingOp and AssumingYieldOp. 258 auto &yieldOp = assumingBlock->back(); 259 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 260 rewriter.replaceOp(op, yieldOp.getOperands()); 261 rewriter.eraseOp(&yieldOp); 262 263 // Merge blocks together as there was no branching behavior from the 264 // AssumingOp. 265 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 266 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // AssumingAllOp 271 //===----------------------------------------------------------------------===// 272 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 273 // Iterate in reverse to first handle all constant operands. They are 274 // guaranteed to be the tail of the inputs because this is commutative. 275 for (int idx = operands.size() - 1; idx >= 0; idx--) { 276 Attribute a = operands[idx]; 277 // Cannot fold if any inputs are not constant; 278 if (!a) 279 return nullptr; 280 281 // We do not need to keep statically known values after handling them in 282 // this method. 283 getOperation()->eraseOperand(idx); 284 285 // Always false if any input is statically known false 286 if (!a.cast<BoolAttr>().getValue()) 287 return a; 288 } 289 // If this is reached, all inputs were statically known passing. 290 return BoolAttr::get(true, getContext()); 291 } 292 293 static LogicalResult verify(AssumingAllOp op) { 294 // Ensure that AssumingAllOp contains at least one operand 295 if (op.getNumOperands() == 0) 296 return op.emitOpError("no operands specified"); 297 298 return success(); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // BroadcastOp 303 //===----------------------------------------------------------------------===// 304 305 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 306 if (!operands[1]) 307 return nullptr; 308 309 auto rhsShape = llvm::to_vector<6>( 310 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 311 if (rhsShape.empty()) 312 return lhs(); 313 314 if (!operands[0]) 315 return nullptr; 316 317 auto lhsShape = llvm::to_vector<6>( 318 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 319 if (lhsShape.empty()) 320 return rhs(); 321 322 SmallVector<int64_t, 6> resultShape; 323 // If the shapes are not compatible, we can't fold it. 324 // TODO: Fold to an "error". 325 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 326 return nullptr; 327 Builder builder(getContext()); 328 return builder.getIndexTensorAttr(resultShape); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // ConcatOp 333 //===----------------------------------------------------------------------===// 334 335 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 336 if (!operands[0] || !operands[1]) 337 return nullptr; 338 auto lhsShape = llvm::to_vector<6>( 339 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 340 auto rhsShape = llvm::to_vector<6>( 341 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 342 SmallVector<int64_t, 6> resultShape; 343 resultShape.append(lhsShape.begin(), lhsShape.end()); 344 resultShape.append(rhsShape.begin(), rhsShape.end()); 345 Builder builder(getContext()); 346 return builder.getIndexTensorAttr(resultShape); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // ConstShapeOp 351 //===----------------------------------------------------------------------===// 352 353 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 354 p << "shape.const_shape "; 355 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 356 p << "["; 357 interleaveComma(op.shape().getValues<int64_t>(), p, 358 [&](int64_t i) { p << i; }); 359 p << "] : "; 360 p.printType(op.getType()); 361 } 362 363 static ParseResult parseConstShapeOp(OpAsmParser &parser, 364 OperationState &result) { 365 if (parser.parseOptionalAttrDict(result.attributes)) 366 return failure(); 367 // We piggy-back on ArrayAttr parsing, though we don't internally store the 368 // shape as an ArrayAttr. 369 // TODO: Implement custom parser and maybe make syntax a bit more concise. 370 Attribute extentsRaw; 371 NamedAttrList dummy; 372 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 373 return failure(); 374 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 375 if (!extentsArray) 376 return failure(); 377 SmallVector<int64_t, 6> ints; 378 for (Attribute extent : extentsArray) { 379 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 380 if (!attr) 381 return failure(); 382 ints.push_back(attr.getInt()); 383 } 384 Builder &builder = parser.getBuilder(); 385 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 386 Type resultTy; 387 if (parser.parseColonType(resultTy)) 388 return failure(); 389 result.types.push_back(resultTy); 390 return success(); 391 } 392 393 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 394 395 //===----------------------------------------------------------------------===// 396 // CstrBroadcastableOp 397 //===----------------------------------------------------------------------===// 398 399 namespace { 400 // Given an input shape Value, try to obtain the shape's values. 401 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 402 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 403 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 404 if (!type.hasRank()) 405 return failure(); 406 shapeValues = llvm::to_vector<6>(type.getShape()); 407 return success(); 408 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 409 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 410 return success(); 411 } else { 412 return failure(); 413 } 414 } 415 416 // For shapes that were created by some operations, we can obtain partial 417 // information on the shapes and sometimes determine if they will be 418 // broadcastable with that. 419 struct CstrBroadcastablePartialInfo 420 : public OpRewritePattern<CstrBroadcastableOp> { 421 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 422 423 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 424 PatternRewriter &rewriter) const override { 425 SmallVector<int64_t, 6> lhsShape, rhsShape; 426 if (failed(getShapeVec(op.lhs(), lhsShape))) 427 return failure(); 428 if (failed(getShapeVec(op.rhs(), rhsShape))) 429 return failure(); 430 if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 431 return failure(); 432 433 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 434 return success(); 435 } 436 }; 437 438 // Scalars are always broadcastable. 439 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> { 440 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 441 442 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 443 PatternRewriter &rewriter) const override { 444 SmallVector<int64_t, 6> shape; 445 if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0) 446 return failure(); 447 if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0) 448 return failure(); 449 450 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 451 return success(); 452 } 453 }; 454 455 } // namespace 456 457 void CstrBroadcastableOp::getCanonicalizationPatterns( 458 OwningRewritePatternList &patterns, MLIRContext *context) { 459 // Canonicalization patterns have overlap with the considerations during 460 // folding in case additional shape information is inferred at some point that 461 // does not result in folding. 462 patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo, 463 CstrBroadcastableScalar>(context); 464 } 465 466 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 467 // Both operands are not needed if one is a scalar. 468 if (operands[0] && 469 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 470 return BoolAttr::get(true, getContext()); 471 if (operands[1] && 472 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 473 return BoolAttr::get(true, getContext()); 474 475 if (operands[0] && operands[1]) { 476 auto lhsShape = llvm::to_vector<6>( 477 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 478 auto rhsShape = llvm::to_vector<6>( 479 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 480 SmallVector<int64_t, 6> resultShape; 481 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 482 return BoolAttr::get(true, getContext()); 483 } 484 485 // Lastly, see if folding can be completed based on what constraints are known 486 // on the input shapes. 487 SmallVector<int64_t, 6> lhsShape, rhsShape; 488 if (failed(getShapeVec(lhs(), lhsShape))) 489 return nullptr; 490 if (failed(getShapeVec(rhs(), rhsShape))) 491 return nullptr; 492 493 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 494 return BoolAttr::get(true, getContext()); 495 496 // Because a failing witness result here represents an eventual assertion 497 // failure, we do not replace it with a constant witness. 498 return nullptr; 499 } 500 501 //===----------------------------------------------------------------------===// 502 // CstrEqOp 503 //===----------------------------------------------------------------------===// 504 505 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 506 MLIRContext *context) { 507 // If inputs are equal, return passing witness 508 patterns.insert<CstrEqEqOps>(context); 509 } 510 511 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 512 if (llvm::all_of(operands, 513 [&](Attribute a) { return a && a == operands[0]; })) 514 return BoolAttr::get(true, getContext()); 515 516 // Because a failing witness result here represents an eventual assertion 517 // failure, we do not try to replace it with a constant witness. Similarly, we 518 // cannot if there are any non-const inputs. 519 return nullptr; 520 } 521 522 //===----------------------------------------------------------------------===// 523 // ConstSizeOp 524 //===----------------------------------------------------------------------===// 525 526 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 527 int64_t value) { 528 build(builder, result, builder.getIndexAttr(value)); 529 } 530 531 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 532 533 void ConstSizeOp::getAsmResultNames( 534 llvm::function_ref<void(Value, StringRef)> setNameFn) { 535 SmallString<4> buffer; 536 llvm::raw_svector_ostream os(buffer); 537 os << "c" << value(); 538 setNameFn(getResult(), os.str()); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // ConstWitnessOp 543 //===----------------------------------------------------------------------===// 544 545 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 546 547 //===----------------------------------------------------------------------===// 548 // ShapeEqOp 549 //===----------------------------------------------------------------------===// 550 551 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 552 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 553 if (lhs == nullptr) 554 return {}; 555 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 556 if (rhs == nullptr) 557 return {}; 558 return BoolAttr::get(lhs == rhs, getContext()); 559 } 560 561 //===----------------------------------------------------------------------===// 562 // IndexToSizeOp 563 //===----------------------------------------------------------------------===// 564 565 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 566 // Constant values of both types, `shape.size` and `index`, are represented as 567 // `IntegerAttr`s which makes constant folding simple. 568 if (Attribute arg = operands[0]) 569 return arg; 570 return {}; 571 } 572 573 void IndexToSizeOp::getCanonicalizationPatterns( 574 OwningRewritePatternList &patterns, MLIRContext *context) { 575 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 576 } 577 578 //===----------------------------------------------------------------------===// 579 // FromExtentsOp 580 //===----------------------------------------------------------------------===// 581 582 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 583 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 584 return nullptr; 585 SmallVector<int64_t, 6> extents; 586 for (auto attr : operands) 587 extents.push_back(attr.cast<IntegerAttr>().getInt()); 588 Builder builder(getContext()); 589 return builder.getIndexTensorAttr(extents); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // GetExtentOp 594 //===----------------------------------------------------------------------===// 595 596 Optional<int64_t> GetExtentOp::getConstantDim() { 597 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 598 return constSizeOp.value().getLimitedValue(); 599 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 600 return constantOp.value().cast<IntegerAttr>().getInt(); 601 return llvm::None; 602 } 603 604 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 605 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 606 if (!elements) 607 return nullptr; 608 Optional<int64_t> dim = getConstantDim(); 609 if (!dim.hasValue()) 610 return nullptr; 611 if (dim.getValue() >= elements.getNumElements()) 612 return nullptr; 613 return elements.getValue({(uint64_t)dim.getValue()}); 614 } 615 616 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 617 int64_t dim) { 618 auto loc = result.location; 619 auto dimAttr = builder.getIndexAttr(dim); 620 if (shape.getType().isa<ShapeType>()) { 621 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 622 build(builder, result, builder.getType<SizeType>(), shape, dim); 623 } else { 624 Value dim = 625 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 626 build(builder, result, builder.getIndexType(), shape, dim); 627 } 628 } 629 630 //===----------------------------------------------------------------------===// 631 // RankOp 632 //===----------------------------------------------------------------------===// 633 634 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 635 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 636 if (!shape) 637 return {}; 638 int64_t rank = shape.getNumElements(); 639 Builder builder(getContext()); 640 return builder.getIndexAttr(rank); 641 } 642 643 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 644 /// Constant folding fails in cases where only the rank is constant, not the 645 /// shape itself. 646 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 647 /// 648 /// Example: 649 /// 650 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 651 /// %rank = shape.rank %shape 652 /// 653 /// becomes 654 /// 655 /// %rank = shape.const_size 3 656 657 namespace { 658 struct RankShapeOfCanonicalizationPattern 659 : public OpRewritePattern<shape::RankOp> { 660 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 661 662 LogicalResult matchAndRewrite(shape::RankOp op, 663 PatternRewriter &rewriter) const override { 664 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 665 if (!shapeOfOp) 666 return failure(); 667 auto rankedTensorType = 668 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 669 if (!rankedTensorType) 670 return failure(); 671 int64_t rank = rankedTensorType.getRank(); 672 if (op.getType().isa<IndexType>()) { 673 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 674 } else if (op.getType().isa<shape::SizeType>()) { 675 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 676 } else { 677 return failure(); 678 } 679 return success(); 680 } 681 }; 682 } // namespace 683 684 void shape::RankOp::getCanonicalizationPatterns( 685 OwningRewritePatternList &patterns, MLIRContext *context) { 686 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // NumElementsOp 691 //===----------------------------------------------------------------------===// 692 693 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 694 695 // Fold only when argument constant. 696 Attribute shape = operands[0]; 697 if (!shape) 698 return {}; 699 700 APInt product(64, 1); 701 for (auto value : shape.cast<DenseIntElementsAttr>()) 702 product *= value; 703 Builder builder(getContext()); 704 return builder.getIndexAttr(product.getLimitedValue()); 705 } 706 707 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 708 Value shape) { 709 if (shape.getType().isa<ShapedType>()) { 710 auto type = builder.getIndexType(); 711 return build(builder, result, type, shape); 712 } 713 auto type = SizeType::get(builder.getContext()); 714 return build(builder, result, type, shape); 715 } 716 717 //===----------------------------------------------------------------------===// 718 // MulOp 719 //===----------------------------------------------------------------------===// 720 721 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 722 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 723 if (!lhs) 724 return nullptr; 725 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 726 if (!rhs) 727 return nullptr; 728 APInt folded = lhs.getValue() * rhs.getValue(); 729 Type indexTy = IndexType::get(getContext()); 730 return IntegerAttr::get(indexTy, folded); 731 } 732 733 //===----------------------------------------------------------------------===// 734 // ShapeOfOp 735 //===----------------------------------------------------------------------===// 736 737 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 738 auto type = getOperand().getType().dyn_cast<ShapedType>(); 739 if (!type || !type.hasStaticShape()) 740 return nullptr; 741 Builder builder(getContext()); 742 return builder.getIndexTensorAttr(type.getShape()); 743 } 744 745 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 746 Type type = arg.getType().isa<ShapedType>() 747 ? (Type)getExtentTensorType(builder.getContext()) 748 : (Type)builder.getType<ShapeType>(); 749 return ShapeOfOp::build(builder, result, type, arg); 750 } 751 752 namespace { 753 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 754 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 755 756 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 757 PatternRewriter &rewriter) const override { 758 if (!op.arg().getType().isa<ShapedType>()) 759 return failure(); 760 if (op.getType().isa<ShapedType>()) 761 return failure(); 762 763 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 764 return success(); 765 } 766 }; 767 } // namespace 768 769 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 770 MLIRContext *context) { 771 patterns.insert<ShapeOfWithTensor>(context); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // SizeToIndexOp 776 //===----------------------------------------------------------------------===// 777 778 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 779 // Constant values of both types, `shape.size` and `index`, are represented as 780 // `IntegerAttr`s which makes constant folding simple. 781 if (Attribute arg = operands[0]) 782 return arg; 783 return impl::foldCastOp(*this); 784 } 785 786 void SizeToIndexOp::getCanonicalizationPatterns( 787 OwningRewritePatternList &patterns, MLIRContext *context) { 788 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 789 } 790 791 //===----------------------------------------------------------------------===// 792 // YieldOp 793 //===----------------------------------------------------------------------===// 794 795 static LogicalResult verify(YieldOp op) { 796 auto *parentOp = op.getParentOp(); 797 auto results = parentOp->getResults(); 798 auto operands = op.getOperands(); 799 800 if (parentOp->getNumResults() != op.getNumOperands()) 801 return op.emitOpError() << "number of operands does not match number of " 802 "results of its parent"; 803 for (auto e : llvm::zip(results, operands)) 804 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 805 return op.emitOpError() 806 << "types mismatch between yield op and its parent"; 807 808 return success(); 809 } 810 811 //===----------------------------------------------------------------------===// 812 // SplitAtOp 813 //===----------------------------------------------------------------------===// 814 815 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 816 SmallVectorImpl<OpFoldResult> &results) { 817 if (!operands[0] || !operands[1]) 818 return failure(); 819 auto shapeVec = llvm::to_vector<6>( 820 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 821 auto shape = llvm::makeArrayRef(shapeVec); 822 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 823 // Verify that the split point is in the correct range. 824 // TODO: Constant fold to an "error". 825 int64_t rank = shape.size(); 826 if (!(-rank <= splitPoint && splitPoint <= rank)) 827 return failure(); 828 if (splitPoint < 0) 829 splitPoint += shape.size(); 830 Builder builder(operands[0].getContext()); 831 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 832 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 833 return success(); 834 } 835 836 //===----------------------------------------------------------------------===// 837 // ToExtentTensorOp 838 //===----------------------------------------------------------------------===// 839 840 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 841 if (!operands[0]) 842 return impl::foldCastOp(*this); 843 Builder builder(getContext()); 844 auto shape = llvm::to_vector<6>( 845 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 846 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 847 builder.getIndexType()); 848 return DenseIntElementsAttr::get(type, shape); 849 } 850 851 //===----------------------------------------------------------------------===// 852 // ReduceOp 853 //===----------------------------------------------------------------------===// 854 855 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 856 ValueRange initVals) { 857 result.addOperands(shape); 858 result.addOperands(initVals); 859 860 Region *bodyRegion = result.addRegion(); 861 bodyRegion->push_back(new Block); 862 Block &bodyBlock = bodyRegion->front(); 863 bodyBlock.addArgument(builder.getIndexType()); 864 865 Type elementType; 866 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 867 elementType = tensorType.getElementType(); 868 else 869 elementType = SizeType::get(builder.getContext()); 870 bodyBlock.addArgument(elementType); 871 872 for (Type initValType : initVals.getTypes()) { 873 bodyBlock.addArgument(initValType); 874 result.addTypes(initValType); 875 } 876 } 877 878 static LogicalResult verify(ReduceOp op) { 879 // Verify block arg types. 880 Block &block = op.region().front(); 881 882 // The block takes index, extent, and aggregated values as arguments. 883 auto blockArgsCount = op.initVals().size() + 2; 884 if (block.getNumArguments() != blockArgsCount) 885 return op.emitOpError() << "ReduceOp body is expected to have " 886 << blockArgsCount << " arguments"; 887 888 // The first block argument is the index and must always be of type `index`. 889 if (!block.getArgument(0).getType().isa<IndexType>()) 890 return op.emitOpError( 891 "argument 0 of ReduceOp body is expected to be of IndexType"); 892 893 // The second block argument is the extent and must be of type `size` or 894 // `index`, depending on whether the reduce operation is applied to a shape or 895 // to an extent tensor. 896 Type extentTy = block.getArgument(1).getType(); 897 if (op.shape().getType().isa<ShapeType>()) { 898 if (!extentTy.isa<SizeType>()) 899 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 900 "SizeType if the ReduceOp operates on a ShapeType"); 901 } else { 902 if (!extentTy.isa<IndexType>()) 903 return op.emitOpError( 904 "argument 1 of ReduceOp body is expected to be of IndexType if the " 905 "ReduceOp operates on an extent tensor"); 906 } 907 908 for (auto type : llvm::enumerate(op.initVals())) 909 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 910 return op.emitOpError() 911 << "type mismatch between argument " << type.index() + 2 912 << " of ReduceOp body and initial value " << type.index(); 913 return success(); 914 } 915 916 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 917 // Parse operands. 918 SmallVector<OpAsmParser::OperandType, 3> operands; 919 Type shapeOrExtentTensorType; 920 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 921 OpAsmParser::Delimiter::Paren) || 922 parser.parseColonType(shapeOrExtentTensorType) || 923 parser.parseOptionalArrowTypeList(result.types)) 924 return failure(); 925 926 // Resolve operands. 927 auto initVals = llvm::makeArrayRef(operands).drop_front(); 928 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 929 result.operands) || 930 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 931 result.operands)) 932 return failure(); 933 934 // Parse the body. 935 Region *body = result.addRegion(); 936 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 937 return failure(); 938 939 // Parse attributes. 940 if (parser.parseOptionalAttrDict(result.attributes)) 941 return failure(); 942 943 return success(); 944 } 945 946 static void print(OpAsmPrinter &p, ReduceOp op) { 947 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 948 << ") : " << op.shape().getType(); 949 p.printOptionalArrowTypeList(op.getResultTypes()); 950 p.printRegion(op.region()); 951 p.printOptionalAttrDict(op.getAttrs()); 952 } 953 954 namespace mlir { 955 namespace shape { 956 957 #define GET_OP_CLASSES 958 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 959 960 } // namespace shape 961 } // namespace mlir 962