1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "llvm/ADT/SmallString.h" 18 #include "llvm/Support/raw_ostream.h" 19 20 using namespace mlir; 21 using namespace mlir::shape; 22 23 namespace { 24 #include "ShapeCanonicalization.inc" 25 } 26 27 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 28 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 29 } 30 31 static bool isErrorPropagationPossible(TypeRange operandTypes) { 32 for (Type ty : operandTypes) 33 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>()) 34 return true; 35 return false; 36 } 37 38 static LogicalResult verifySizeOrIndexOp(Operation *op) { 39 assert(op != nullptr && op->getNumResults() == 1); 40 Type resultTy = op->getResultTypes().front(); 41 if (isErrorPropagationPossible(op->getOperandTypes())) { 42 if (!resultTy.isa<SizeType>()) 43 return op->emitOpError() 44 << "if at least one of the operands can hold error values then " 45 "the result must be of type `size` to propagate them"; 46 } 47 return success(); 48 } 49 50 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 51 assert(op != nullptr && op->getNumResults() == 1); 52 Type resultTy = op->getResultTypes().front(); 53 if (isErrorPropagationPossible(op->getOperandTypes())) { 54 if (!resultTy.isa<ShapeType>()) 55 return op->emitOpError() 56 << "if at least one of the operands can hold error values then " 57 "the result must be of type `shape` to propagate them"; 58 } 59 return success(); 60 } 61 62 void ShapeDialect::initialize() { 63 addOperations< 64 #define GET_OP_LIST 65 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 66 >(); 67 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 68 WitnessType>(); 69 // Allow unknown operations during prototyping and testing. As the dialect is 70 // still evolving it makes it simple to start with an unregistered ops and 71 // try different variants before actually defining the op. 72 allowUnknownOperations(); 73 } 74 75 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 76 Attribute value, Type type, 77 Location loc) { 78 if (type.isa<ShapeType>() || 79 type == getExtentTensorType(builder.getContext())) 80 return builder.create<ConstShapeOp>(loc, type, 81 value.cast<DenseIntElementsAttr>()); 82 if (type.isa<SizeType>()) 83 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 84 if (type.isa<WitnessType>()) 85 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 86 if (type.isa<IndexType>()) 87 return builder.create<ConstantOp>(loc, type, value); 88 return nullptr; 89 } 90 91 /// Parse a type registered to this dialect. 92 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 93 StringRef keyword; 94 if (parser.parseKeyword(&keyword)) 95 return Type(); 96 97 if (keyword == "component") 98 return ComponentType::get(getContext()); 99 if (keyword == "element") 100 return ElementType::get(getContext()); 101 if (keyword == "shape") 102 return ShapeType::get(getContext()); 103 if (keyword == "size") 104 return SizeType::get(getContext()); 105 if (keyword == "value_shape") 106 return ValueShapeType::get(getContext()); 107 if (keyword == "witness") 108 return WitnessType::get(getContext()); 109 110 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 111 return Type(); 112 } 113 114 /// Print a type registered to this dialect. 115 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 116 switch (type.getKind()) { 117 case ShapeTypes::Component: 118 os << "component"; 119 return; 120 case ShapeTypes::Element: 121 os << "element"; 122 return; 123 case ShapeTypes::Size: 124 os << "size"; 125 return; 126 case ShapeTypes::Shape: 127 os << "shape"; 128 return; 129 case ShapeTypes::ValueShape: 130 os << "value_shape"; 131 return; 132 case ShapeTypes::Witness: 133 os << "witness"; 134 return; 135 default: 136 llvm_unreachable("unexpected 'shape' type kind"); 137 } 138 } 139 140 //===----------------------------------------------------------------------===// 141 // AnyOp 142 //===----------------------------------------------------------------------===// 143 144 // TODO: Canonicalization should be implemented for shapes that can be 145 // determined through mixtures of the known dimensions of the inputs. 146 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 147 // Only the last operand is checked because AnyOp is commutative. 148 if (operands.back()) 149 return operands.back(); 150 151 return nullptr; 152 } 153 154 //===----------------------------------------------------------------------===// 155 // AssumingOp 156 //===----------------------------------------------------------------------===// 157 158 static ParseResult parseAssumingOp(OpAsmParser &parser, 159 OperationState &result) { 160 result.regions.reserve(1); 161 Region *doRegion = result.addRegion(); 162 163 auto &builder = parser.getBuilder(); 164 OpAsmParser::OperandType cond; 165 if (parser.parseOperand(cond) || 166 parser.resolveOperand(cond, builder.getType<WitnessType>(), 167 result.operands)) 168 return failure(); 169 170 // Parse optional results type list. 171 if (parser.parseOptionalArrowTypeList(result.types)) 172 return failure(); 173 174 // Parse the region and add a terminator if elided. 175 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 176 return failure(); 177 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 178 179 // Parse the optional attribute list. 180 if (parser.parseOptionalAttrDict(result.attributes)) 181 return failure(); 182 return success(); 183 } 184 185 static void print(OpAsmPrinter &p, AssumingOp op) { 186 bool yieldsResults = !op.results().empty(); 187 188 p << AssumingOp::getOperationName() << " " << op.witness(); 189 if (yieldsResults) { 190 p << " -> (" << op.getResultTypes() << ")"; 191 } 192 p.printRegion(op.doRegion(), 193 /*printEntryBlockArgs=*/false, 194 /*printBlockTerminators=*/yieldsResults); 195 p.printOptionalAttrDict(op.getAttrs()); 196 } 197 198 namespace { 199 // Removes AssumingOp with a passing witness and inlines the region. 200 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 201 using OpRewritePattern<AssumingOp>::OpRewritePattern; 202 203 LogicalResult matchAndRewrite(AssumingOp op, 204 PatternRewriter &rewriter) const override { 205 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 206 if (!witness || !witness.passingAttr()) 207 return failure(); 208 209 AssumingOp::inlineRegionIntoParent(op, rewriter); 210 return success(); 211 } 212 }; 213 } // namespace 214 215 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 216 MLIRContext *context) { 217 // If taking a passing witness, inline region. 218 patterns.insert<AssumingWithTrue>(context); 219 } 220 221 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 222 PatternRewriter &rewriter) { 223 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 224 auto *assumingBlock = op.getBody(); 225 auto initPosition = rewriter.getInsertionPoint(); 226 auto *blockAfterAssuming = 227 rewriter.splitBlock(blockBeforeAssuming, initPosition); 228 229 // Remove the AssumingOp and AssumingYieldOp. 230 auto &yieldOp = assumingBlock->back(); 231 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 232 rewriter.replaceOp(op, yieldOp.getOperands()); 233 rewriter.eraseOp(&yieldOp); 234 235 // Merge blocks together as there was no branching behavior from the 236 // AssumingOp. 237 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 238 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // AssumingAllOp 243 //===----------------------------------------------------------------------===// 244 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 245 // Iterate in reverse to first handle all constant operands. They are 246 // guaranteed to be the tail of the inputs because this is commutative. 247 for (int idx = operands.size() - 1; idx >= 0; idx--) { 248 Attribute a = operands[idx]; 249 // Cannot fold if any inputs are not constant; 250 if (!a) 251 return nullptr; 252 253 // We do not need to keep statically known values after handling them in 254 // this method. 255 getOperation()->eraseOperand(idx); 256 257 // Always false if any input is statically known false 258 if (!a.cast<BoolAttr>().getValue()) 259 return a; 260 } 261 // If this is reached, all inputs were statically known passing. 262 return BoolAttr::get(true, getContext()); 263 } 264 265 static LogicalResult verify(AssumingAllOp op) { 266 // Ensure that AssumingAllOp contains at least one operand 267 if (op.getNumOperands() == 0) 268 return op.emitOpError("no operands specified"); 269 270 return success(); 271 } 272 273 //===----------------------------------------------------------------------===// 274 // BroadcastOp 275 //===----------------------------------------------------------------------===// 276 277 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 278 if (!operands[1]) 279 return nullptr; 280 281 auto rhsShape = llvm::to_vector<6>( 282 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 283 if (rhsShape.empty()) 284 return lhs(); 285 286 if (!operands[0]) 287 return nullptr; 288 289 auto lhsShape = llvm::to_vector<6>( 290 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 291 if (lhsShape.empty()) 292 return rhs(); 293 294 SmallVector<int64_t, 6> resultShape; 295 // If the shapes are not compatible, we can't fold it. 296 // TODO: Fold to an "error". 297 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 298 return nullptr; 299 Builder builder(getContext()); 300 return builder.getIndexTensorAttr(resultShape); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // ConcatOp 305 //===----------------------------------------------------------------------===// 306 307 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 308 if (!operands[0] || !operands[1]) 309 return nullptr; 310 auto lhsShape = llvm::to_vector<6>( 311 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 312 auto rhsShape = llvm::to_vector<6>( 313 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 314 SmallVector<int64_t, 6> resultShape; 315 resultShape.append(lhsShape.begin(), lhsShape.end()); 316 resultShape.append(rhsShape.begin(), rhsShape.end()); 317 Builder builder(getContext()); 318 return builder.getIndexTensorAttr(resultShape); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // ConstShapeOp 323 //===----------------------------------------------------------------------===// 324 325 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 326 p << "shape.const_shape "; 327 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 328 p << "["; 329 interleaveComma(op.shape().getValues<int64_t>(), p, 330 [&](int64_t i) { p << i; }); 331 p << "] : "; 332 p.printType(op.getType()); 333 } 334 335 static ParseResult parseConstShapeOp(OpAsmParser &parser, 336 OperationState &result) { 337 if (parser.parseOptionalAttrDict(result.attributes)) 338 return failure(); 339 // We piggy-back on ArrayAttr parsing, though we don't internally store the 340 // shape as an ArrayAttr. 341 // TODO: Implement custom parser and maybe make syntax a bit more concise. 342 Attribute extentsRaw; 343 NamedAttrList dummy; 344 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 345 return failure(); 346 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 347 if (!extentsArray) 348 return failure(); 349 SmallVector<int64_t, 6> ints; 350 for (Attribute extent : extentsArray) { 351 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 352 if (!attr) 353 return failure(); 354 ints.push_back(attr.getInt()); 355 } 356 Builder &builder = parser.getBuilder(); 357 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 358 Type resultTy; 359 if (parser.parseColonType(resultTy)) 360 return failure(); 361 result.types.push_back(resultTy); 362 return success(); 363 } 364 365 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 366 367 //===----------------------------------------------------------------------===// 368 // CstrBroadcastableOp 369 //===----------------------------------------------------------------------===// 370 371 namespace { 372 // Given an input shape Value, try to obtain the shape's values. 373 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 374 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 375 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 376 if (!type.hasRank()) 377 return failure(); 378 shapeValues = llvm::to_vector<6>(type.getShape()); 379 return success(); 380 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 381 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 382 return success(); 383 } else { 384 return failure(); 385 } 386 } 387 388 // For shapes that were created by some operations, we can obtain partial 389 // information on the shapes and sometimes determine if they will be 390 // broadcastable with that. 391 struct CstrBroadcastablePartialInfo 392 : public OpRewritePattern<CstrBroadcastableOp> { 393 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 394 395 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 396 PatternRewriter &rewriter) const override { 397 SmallVector<int64_t, 6> lhsShape, rhsShape; 398 if (failed(getShapeVec(op.lhs(), lhsShape))) 399 return failure(); 400 if (failed(getShapeVec(op.rhs(), rhsShape))) 401 return failure(); 402 if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 403 return failure(); 404 405 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 406 return success(); 407 } 408 }; 409 410 // Scalars are always broadcastable. 411 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> { 412 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 413 414 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 415 PatternRewriter &rewriter) const override { 416 SmallVector<int64_t, 6> shape; 417 if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0) 418 return failure(); 419 if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0) 420 return failure(); 421 422 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 423 return success(); 424 } 425 }; 426 427 } // namespace 428 429 void CstrBroadcastableOp::getCanonicalizationPatterns( 430 OwningRewritePatternList &patterns, MLIRContext *context) { 431 // Canonicalization patterns have overlap with the considerations during 432 // folding in case additional shape information is inferred at some point that 433 // does not result in folding. 434 patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo, 435 CstrBroadcastableScalar>(context); 436 } 437 438 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 439 // Both operands are not needed if one is a scalar. 440 if (operands[0] && 441 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 442 return BoolAttr::get(true, getContext()); 443 if (operands[1] && 444 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 445 return BoolAttr::get(true, getContext()); 446 447 if (operands[0] && operands[1]) { 448 auto lhsShape = llvm::to_vector<6>( 449 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 450 auto rhsShape = llvm::to_vector<6>( 451 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 452 SmallVector<int64_t, 6> resultShape; 453 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 454 return BoolAttr::get(true, getContext()); 455 } 456 457 // Lastly, see if folding can be completed based on what constraints are known 458 // on the input shapes. 459 SmallVector<int64_t, 6> lhsShape, rhsShape; 460 if (failed(getShapeVec(lhs(), lhsShape))) 461 return nullptr; 462 if (failed(getShapeVec(rhs(), rhsShape))) 463 return nullptr; 464 465 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 466 return BoolAttr::get(true, getContext()); 467 468 // Because a failing witness result here represents an eventual assertion 469 // failure, we do not replace it with a constant witness. 470 return nullptr; 471 } 472 473 //===----------------------------------------------------------------------===// 474 // CstrEqOp 475 //===----------------------------------------------------------------------===// 476 477 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 478 MLIRContext *context) { 479 // If inputs are equal, return passing witness 480 patterns.insert<CstrEqEqOps>(context); 481 } 482 483 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 484 if (llvm::all_of(operands, 485 [&](Attribute a) { return a && a == operands[0]; })) 486 return BoolAttr::get(true, getContext()); 487 488 // Because a failing witness result here represents an eventual assertion 489 // failure, we do not try to replace it with a constant witness. Similarly, we 490 // cannot if there are any non-const inputs. 491 return nullptr; 492 } 493 494 //===----------------------------------------------------------------------===// 495 // ConstSizeOp 496 //===----------------------------------------------------------------------===// 497 498 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 499 int64_t value) { 500 build(builder, result, builder.getIndexAttr(value)); 501 } 502 503 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 504 505 void ConstSizeOp::getAsmResultNames( 506 llvm::function_ref<void(Value, StringRef)> setNameFn) { 507 SmallString<4> buffer; 508 llvm::raw_svector_ostream os(buffer); 509 os << "c" << value(); 510 setNameFn(getResult(), os.str()); 511 } 512 513 //===----------------------------------------------------------------------===// 514 // ConstWitnessOp 515 //===----------------------------------------------------------------------===// 516 517 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 518 519 //===----------------------------------------------------------------------===// 520 // ShapeEqOp 521 //===----------------------------------------------------------------------===// 522 523 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 524 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 525 if (lhs == nullptr) 526 return {}; 527 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 528 if (rhs == nullptr) 529 return {}; 530 return BoolAttr::get(lhs == rhs, getContext()); 531 } 532 533 //===----------------------------------------------------------------------===// 534 // IndexToSizeOp 535 //===----------------------------------------------------------------------===// 536 537 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 538 // Constant values of both types, `shape.size` and `index`, are represented as 539 // `IntegerAttr`s which makes constant folding simple. 540 if (Attribute arg = operands[0]) 541 return arg; 542 return {}; 543 } 544 545 void IndexToSizeOp::getCanonicalizationPatterns( 546 OwningRewritePatternList &patterns, MLIRContext *context) { 547 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // FromExtentsOp 552 //===----------------------------------------------------------------------===// 553 554 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 555 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 556 return nullptr; 557 SmallVector<int64_t, 6> extents; 558 for (auto attr : operands) 559 extents.push_back(attr.cast<IntegerAttr>().getInt()); 560 Builder builder(getContext()); 561 return builder.getIndexTensorAttr(extents); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // GetExtentOp 566 //===----------------------------------------------------------------------===// 567 568 Optional<int64_t> GetExtentOp::getConstantDim() { 569 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 570 return constSizeOp.value().getLimitedValue(); 571 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 572 return constantOp.value().cast<IntegerAttr>().getInt(); 573 return llvm::None; 574 } 575 576 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 577 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 578 if (!elements) 579 return nullptr; 580 Optional<int64_t> dim = getConstantDim(); 581 if (!dim.hasValue()) 582 return nullptr; 583 if (dim.getValue() >= elements.getNumElements()) 584 return nullptr; 585 return elements.getValue({(uint64_t)dim.getValue()}); 586 } 587 588 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 589 int64_t dim) { 590 auto loc = result.location; 591 auto dimAttr = builder.getIndexAttr(dim); 592 if (shape.getType().isa<ShapeType>()) { 593 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 594 build(builder, result, builder.getType<SizeType>(), shape, dim); 595 } else { 596 Value dim = 597 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 598 build(builder, result, builder.getIndexType(), shape, dim); 599 } 600 } 601 602 //===----------------------------------------------------------------------===// 603 // RankOp 604 //===----------------------------------------------------------------------===// 605 606 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 607 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 608 if (!shape) 609 return {}; 610 int64_t rank = shape.getNumElements(); 611 Builder builder(getContext()); 612 return builder.getIndexAttr(rank); 613 } 614 615 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 616 /// Constant folding fails in cases where only the rank is constant, not the 617 /// shape itself. 618 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 619 /// 620 /// Example: 621 /// 622 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 623 /// %rank = shape.rank %shape 624 /// 625 /// becomes 626 /// 627 /// %rank = shape.const_size 3 628 629 namespace { 630 struct RankShapeOfCanonicalizationPattern 631 : public OpRewritePattern<shape::RankOp> { 632 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 633 634 LogicalResult matchAndRewrite(shape::RankOp op, 635 PatternRewriter &rewriter) const override { 636 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 637 if (!shapeOfOp) 638 return failure(); 639 auto rankedTensorType = 640 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 641 if (!rankedTensorType) 642 return failure(); 643 assert(op.getType().isa<IndexType>() && 644 "expected `rank(shape_of( ... )]` based on a shaped argument to " 645 "yield an index type"); 646 int64_t rank = rankedTensorType.getRank(); 647 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 648 return success(); 649 } 650 }; 651 } // namespace 652 653 void shape::RankOp::getCanonicalizationPatterns( 654 OwningRewritePatternList &patterns, MLIRContext *context) { 655 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // NumElementsOp 660 //===----------------------------------------------------------------------===// 661 662 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 663 664 // Fold only when argument constant. 665 Attribute shape = operands[0]; 666 if (!shape) 667 return {}; 668 669 APInt product(64, 1); 670 for (auto value : shape.cast<DenseIntElementsAttr>()) 671 product *= value; 672 Builder builder(getContext()); 673 return builder.getIndexAttr(product.getLimitedValue()); 674 } 675 676 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 677 Value shape) { 678 if (shape.getType().isa<ShapedType>()) { 679 auto type = builder.getIndexType(); 680 return build(builder, result, type, shape); 681 } 682 auto type = SizeType::get(builder.getContext()); 683 return build(builder, result, type, shape); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // MulOp 688 //===----------------------------------------------------------------------===// 689 690 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 691 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 692 if (!lhs) 693 return nullptr; 694 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 695 if (!rhs) 696 return nullptr; 697 APInt folded = lhs.getValue() * rhs.getValue(); 698 Type indexTy = IndexType::get(getContext()); 699 return IntegerAttr::get(indexTy, folded); 700 } 701 702 //===----------------------------------------------------------------------===// 703 // ShapeOfOp 704 //===----------------------------------------------------------------------===// 705 706 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 707 auto type = getOperand().getType().dyn_cast<ShapedType>(); 708 if (!type || !type.hasStaticShape()) 709 return nullptr; 710 Builder builder(getContext()); 711 return builder.getIndexTensorAttr(type.getShape()); 712 } 713 714 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 715 Type type = arg.getType().isa<ShapedType>() 716 ? (Type)getExtentTensorType(builder.getContext()) 717 : (Type)builder.getType<ShapeType>(); 718 return ShapeOfOp::build(builder, result, type, arg); 719 } 720 721 namespace { 722 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 723 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 724 725 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 726 PatternRewriter &rewriter) const override { 727 if (!op.arg().getType().isa<ShapedType>()) 728 return failure(); 729 if (op.getType().isa<ShapedType>()) 730 return failure(); 731 732 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 733 return success(); 734 } 735 }; 736 } // namespace 737 738 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 739 MLIRContext *context) { 740 patterns.insert<ShapeOfWithTensor>(context); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // SizeToIndexOp 745 //===----------------------------------------------------------------------===// 746 747 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 748 // Constant values of both types, `shape.size` and `index`, are represented as 749 // `IntegerAttr`s which makes constant folding simple. 750 if (Attribute arg = operands[0]) 751 return arg; 752 return impl::foldCastOp(*this); 753 } 754 755 void SizeToIndexOp::getCanonicalizationPatterns( 756 OwningRewritePatternList &patterns, MLIRContext *context) { 757 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 758 } 759 760 //===----------------------------------------------------------------------===// 761 // YieldOp 762 //===----------------------------------------------------------------------===// 763 764 static LogicalResult verify(YieldOp op) { 765 auto *parentOp = op.getParentOp(); 766 auto results = parentOp->getResults(); 767 auto operands = op.getOperands(); 768 769 if (parentOp->getNumResults() != op.getNumOperands()) 770 return op.emitOpError() << "number of operands does not match number of " 771 "results of its parent"; 772 for (auto e : llvm::zip(results, operands)) 773 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 774 return op.emitOpError() 775 << "types mismatch between yield op and its parent"; 776 777 return success(); 778 } 779 780 //===----------------------------------------------------------------------===// 781 // SplitAtOp 782 //===----------------------------------------------------------------------===// 783 784 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 785 SmallVectorImpl<OpFoldResult> &results) { 786 if (!operands[0] || !operands[1]) 787 return failure(); 788 auto shapeVec = llvm::to_vector<6>( 789 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 790 auto shape = llvm::makeArrayRef(shapeVec); 791 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 792 // Verify that the split point is in the correct range. 793 // TODO: Constant fold to an "error". 794 int64_t rank = shape.size(); 795 if (!(-rank <= splitPoint && splitPoint <= rank)) 796 return failure(); 797 if (splitPoint < 0) 798 splitPoint += shape.size(); 799 Builder builder(operands[0].getContext()); 800 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 801 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 802 return success(); 803 } 804 805 //===----------------------------------------------------------------------===// 806 // ToExtentTensorOp 807 //===----------------------------------------------------------------------===// 808 809 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 810 if (!operands[0]) 811 return impl::foldCastOp(*this); 812 Builder builder(getContext()); 813 auto shape = llvm::to_vector<6>( 814 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 815 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 816 builder.getIndexType()); 817 return DenseIntElementsAttr::get(type, shape); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // ReduceOp 822 //===----------------------------------------------------------------------===// 823 824 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 825 ValueRange initVals) { 826 result.addOperands(shape); 827 result.addOperands(initVals); 828 829 Region *bodyRegion = result.addRegion(); 830 bodyRegion->push_back(new Block); 831 Block &bodyBlock = bodyRegion->front(); 832 bodyBlock.addArgument(builder.getIndexType()); 833 834 Type elementType; 835 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 836 elementType = tensorType.getElementType(); 837 else 838 elementType = SizeType::get(builder.getContext()); 839 bodyBlock.addArgument(elementType); 840 841 for (Type initValType : initVals.getTypes()) { 842 bodyBlock.addArgument(initValType); 843 result.addTypes(initValType); 844 } 845 } 846 847 static LogicalResult verify(ReduceOp op) { 848 // Verify block arg types. 849 Block &block = op.region().front(); 850 851 // The block takes index, extent, and aggregated values as arguments. 852 auto blockArgsCount = op.initVals().size() + 2; 853 if (block.getNumArguments() != blockArgsCount) 854 return op.emitOpError() << "ReduceOp body is expected to have " 855 << blockArgsCount << " arguments"; 856 857 // The first block argument is the index and must always be of type `index`. 858 if (!block.getArgument(0).getType().isa<IndexType>()) 859 return op.emitOpError( 860 "argument 0 of ReduceOp body is expected to be of IndexType"); 861 862 // The second block argument is the extent and must be of type `size` or 863 // `index`, depending on whether the reduce operation is applied to a shape or 864 // to an extent tensor. 865 Type extentTy = block.getArgument(1).getType(); 866 if (op.shape().getType().isa<ShapeType>()) { 867 if (!extentTy.isa<SizeType>()) 868 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 869 "SizeType if the ReduceOp operates on a ShapeType"); 870 } else { 871 if (!extentTy.isa<IndexType>()) 872 return op.emitOpError( 873 "argument 1 of ReduceOp body is expected to be of IndexType if the " 874 "ReduceOp operates on an extent tensor"); 875 } 876 877 for (auto type : llvm::enumerate(op.initVals())) 878 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 879 return op.emitOpError() 880 << "type mismatch between argument " << type.index() + 2 881 << " of ReduceOp body and initial value " << type.index(); 882 return success(); 883 } 884 885 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 886 // Parse operands. 887 SmallVector<OpAsmParser::OperandType, 3> operands; 888 Type shapeOrExtentTensorType; 889 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 890 OpAsmParser::Delimiter::Paren) || 891 parser.parseColonType(shapeOrExtentTensorType) || 892 parser.parseOptionalArrowTypeList(result.types)) 893 return failure(); 894 895 // Resolve operands. 896 auto initVals = llvm::makeArrayRef(operands).drop_front(); 897 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 898 result.operands) || 899 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 900 result.operands)) 901 return failure(); 902 903 // Parse the body. 904 Region *body = result.addRegion(); 905 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 906 return failure(); 907 908 // Parse attributes. 909 if (parser.parseOptionalAttrDict(result.attributes)) 910 return failure(); 911 912 return success(); 913 } 914 915 static void print(OpAsmPrinter &p, ReduceOp op) { 916 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 917 << ") : " << op.shape().getType(); 918 p.printOptionalArrowTypeList(op.getResultTypes()); 919 p.printRegion(op.region()); 920 p.printOptionalAttrDict(op.getAttrs()); 921 } 922 923 namespace mlir { 924 namespace shape { 925 926 #define GET_OP_CLASSES 927 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 928 929 } // namespace shape 930 } // namespace mlir 931