1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SmallString.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 #include "llvm/Support/raw_ostream.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 25 namespace { 26 #include "ShapeCanonicalization.inc" 27 } 28 29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 30 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 31 } 32 33 static bool isErrorPropagationPossible(TypeRange operandTypes) { 34 for (Type ty : operandTypes) 35 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>()) 36 return true; 37 return false; 38 } 39 40 static LogicalResult verifySizeOrIndexOp(Operation *op) { 41 assert(op != nullptr && op->getNumResults() == 1); 42 Type resultTy = op->getResultTypes().front(); 43 if (isErrorPropagationPossible(op->getOperandTypes())) { 44 if (!resultTy.isa<SizeType>()) 45 return op->emitOpError() 46 << "if at least one of the operands can hold error values then " 47 "the result must be of type `size` to propagate them"; 48 } 49 return success(); 50 } 51 52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 53 assert(op != nullptr && op->getNumResults() == 1); 54 Type resultTy = op->getResultTypes().front(); 55 if (isErrorPropagationPossible(op->getOperandTypes())) { 56 if (!resultTy.isa<ShapeType>()) 57 return op->emitOpError() 58 << "if at least one of the operands can hold error values then " 59 "the result must be of type `shape` to propagate them"; 60 } 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 /// This class defines the interface for inlining shape dialect ops. 70 struct ShapeInlinerInterface : public DialectInlinerInterface { 71 using DialectInlinerInterface::DialectInlinerInterface; 72 73 // Returns true if the given region 'src' can be inlined into the region 74 // 'dest' that is attached to an operation registered to the current dialect. 75 bool isLegalToInline(Region *dest, Region *src, 76 BlockAndValueMapping &) const final { 77 return true; 78 } 79 80 // Returns true if the given operation 'op', that is registered to this 81 // dialect, can be inlined into the region 'dest' that is attached to an 82 // operation registered to the current dialect. 83 bool isLegalToInline(Operation *op, Region *dest, 84 BlockAndValueMapping &) const final { 85 return true; 86 } 87 }; 88 } // namespace 89 90 void ShapeDialect::initialize() { 91 addOperations< 92 #define GET_OP_LIST 93 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 94 >(); 95 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 96 WitnessType>(); 97 addInterfaces<ShapeInlinerInterface>(); 98 // Allow unknown operations during prototyping and testing. As the dialect is 99 // still evolving it makes it simple to start with an unregistered ops and 100 // try different variants before actually defining the op. 101 allowUnknownOperations(); 102 } 103 104 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 105 Attribute value, Type type, 106 Location loc) { 107 if (type.isa<ShapeType>() || 108 type == getExtentTensorType(builder.getContext())) 109 return builder.create<ConstShapeOp>(loc, type, 110 value.cast<DenseIntElementsAttr>()); 111 if (type.isa<SizeType>()) 112 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 113 if (type.isa<WitnessType>()) 114 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 115 if (type.isa<IndexType>()) 116 return builder.create<ConstantOp>(loc, type, value); 117 return nullptr; 118 } 119 120 /// Parse a type registered to this dialect. 121 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 122 StringRef keyword; 123 if (parser.parseKeyword(&keyword)) 124 return Type(); 125 126 if (keyword == "component") 127 return ComponentType::get(getContext()); 128 if (keyword == "element") 129 return ElementType::get(getContext()); 130 if (keyword == "shape") 131 return ShapeType::get(getContext()); 132 if (keyword == "size") 133 return SizeType::get(getContext()); 134 if (keyword == "value_shape") 135 return ValueShapeType::get(getContext()); 136 if (keyword == "witness") 137 return WitnessType::get(getContext()); 138 139 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 140 return Type(); 141 } 142 143 /// Print a type registered to this dialect. 144 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 145 TypeSwitch<Type>(type) 146 .Case<ComponentType>([&](Type) { os << "component"; }) 147 .Case<ElementType>([&](Type) { os << "element"; }) 148 .Case<ShapeType>([&](Type) { os << "shape"; }) 149 .Case<SizeType>([&](Type) { os << "size"; }) 150 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 151 .Case<WitnessType>([&](Type) { os << "witness"; }) 152 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 153 } 154 155 //===----------------------------------------------------------------------===// 156 // AnyOp 157 //===----------------------------------------------------------------------===// 158 159 // TODO: Canonicalization should be implemented for shapes that can be 160 // determined through mixtures of the known dimensions of the inputs. 161 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 162 // Only the last operand is checked because AnyOp is commutative. 163 if (operands.back()) 164 return operands.back(); 165 166 return nullptr; 167 } 168 169 //===----------------------------------------------------------------------===// 170 // AssumingOp 171 //===----------------------------------------------------------------------===// 172 173 static ParseResult parseAssumingOp(OpAsmParser &parser, 174 OperationState &result) { 175 result.regions.reserve(1); 176 Region *doRegion = result.addRegion(); 177 178 auto &builder = parser.getBuilder(); 179 OpAsmParser::OperandType cond; 180 if (parser.parseOperand(cond) || 181 parser.resolveOperand(cond, builder.getType<WitnessType>(), 182 result.operands)) 183 return failure(); 184 185 // Parse optional results type list. 186 if (parser.parseOptionalArrowTypeList(result.types)) 187 return failure(); 188 189 // Parse the region and add a terminator if elided. 190 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 191 return failure(); 192 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 193 194 // Parse the optional attribute list. 195 if (parser.parseOptionalAttrDict(result.attributes)) 196 return failure(); 197 return success(); 198 } 199 200 static void print(OpAsmPrinter &p, AssumingOp op) { 201 bool yieldsResults = !op.results().empty(); 202 203 p << AssumingOp::getOperationName() << " " << op.witness(); 204 if (yieldsResults) { 205 p << " -> (" << op.getResultTypes() << ")"; 206 } 207 p.printRegion(op.doRegion(), 208 /*printEntryBlockArgs=*/false, 209 /*printBlockTerminators=*/yieldsResults); 210 p.printOptionalAttrDict(op.getAttrs()); 211 } 212 213 namespace { 214 // Removes AssumingOp with a passing witness and inlines the region. 215 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 216 using OpRewritePattern<AssumingOp>::OpRewritePattern; 217 218 LogicalResult matchAndRewrite(AssumingOp op, 219 PatternRewriter &rewriter) const override { 220 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 221 if (!witness || !witness.passingAttr()) 222 return failure(); 223 224 AssumingOp::inlineRegionIntoParent(op, rewriter); 225 return success(); 226 } 227 }; 228 } // namespace 229 230 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 231 MLIRContext *context) { 232 // If taking a passing witness, inline region. 233 patterns.insert<AssumingWithTrue>(context); 234 } 235 236 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 237 PatternRewriter &rewriter) { 238 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 239 auto *assumingBlock = op.getBody(); 240 auto initPosition = rewriter.getInsertionPoint(); 241 auto *blockAfterAssuming = 242 rewriter.splitBlock(blockBeforeAssuming, initPosition); 243 244 // Remove the AssumingOp and AssumingYieldOp. 245 auto &yieldOp = assumingBlock->back(); 246 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 247 rewriter.replaceOp(op, yieldOp.getOperands()); 248 rewriter.eraseOp(&yieldOp); 249 250 // Merge blocks together as there was no branching behavior from the 251 // AssumingOp. 252 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 253 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 254 } 255 256 //===----------------------------------------------------------------------===// 257 // AssumingAllOp 258 //===----------------------------------------------------------------------===// 259 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 260 // Iterate in reverse to first handle all constant operands. They are 261 // guaranteed to be the tail of the inputs because this is commutative. 262 for (int idx = operands.size() - 1; idx >= 0; idx--) { 263 Attribute a = operands[idx]; 264 // Cannot fold if any inputs are not constant; 265 if (!a) 266 return nullptr; 267 268 // We do not need to keep statically known values after handling them in 269 // this method. 270 getOperation()->eraseOperand(idx); 271 272 // Always false if any input is statically known false 273 if (!a.cast<BoolAttr>().getValue()) 274 return a; 275 } 276 // If this is reached, all inputs were statically known passing. 277 return BoolAttr::get(true, getContext()); 278 } 279 280 static LogicalResult verify(AssumingAllOp op) { 281 // Ensure that AssumingAllOp contains at least one operand 282 if (op.getNumOperands() == 0) 283 return op.emitOpError("no operands specified"); 284 285 return success(); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // BroadcastOp 290 //===----------------------------------------------------------------------===// 291 292 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 293 if (!operands[1]) 294 return nullptr; 295 296 auto rhsShape = llvm::to_vector<6>( 297 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 298 if (rhsShape.empty()) 299 return lhs(); 300 301 if (!operands[0]) 302 return nullptr; 303 304 auto lhsShape = llvm::to_vector<6>( 305 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 306 if (lhsShape.empty()) 307 return rhs(); 308 309 SmallVector<int64_t, 6> resultShape; 310 // If the shapes are not compatible, we can't fold it. 311 // TODO: Fold to an "error". 312 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 313 return nullptr; 314 Builder builder(getContext()); 315 return builder.getIndexTensorAttr(resultShape); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // ConcatOp 320 //===----------------------------------------------------------------------===// 321 322 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 323 if (!operands[0] || !operands[1]) 324 return nullptr; 325 auto lhsShape = llvm::to_vector<6>( 326 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 327 auto rhsShape = llvm::to_vector<6>( 328 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 329 SmallVector<int64_t, 6> resultShape; 330 resultShape.append(lhsShape.begin(), lhsShape.end()); 331 resultShape.append(rhsShape.begin(), rhsShape.end()); 332 Builder builder(getContext()); 333 return builder.getIndexTensorAttr(resultShape); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // ConstShapeOp 338 //===----------------------------------------------------------------------===// 339 340 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 341 p << "shape.const_shape "; 342 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 343 p << "["; 344 interleaveComma(op.shape().getValues<int64_t>(), p, 345 [&](int64_t i) { p << i; }); 346 p << "] : "; 347 p.printType(op.getType()); 348 } 349 350 static ParseResult parseConstShapeOp(OpAsmParser &parser, 351 OperationState &result) { 352 if (parser.parseOptionalAttrDict(result.attributes)) 353 return failure(); 354 // We piggy-back on ArrayAttr parsing, though we don't internally store the 355 // shape as an ArrayAttr. 356 // TODO: Implement custom parser and maybe make syntax a bit more concise. 357 Attribute extentsRaw; 358 NamedAttrList dummy; 359 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 360 return failure(); 361 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 362 if (!extentsArray) 363 return failure(); 364 SmallVector<int64_t, 6> ints; 365 for (Attribute extent : extentsArray) { 366 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 367 if (!attr) 368 return failure(); 369 ints.push_back(attr.getInt()); 370 } 371 Builder &builder = parser.getBuilder(); 372 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 373 Type resultTy; 374 if (parser.parseColonType(resultTy)) 375 return failure(); 376 result.types.push_back(resultTy); 377 return success(); 378 } 379 380 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 381 382 //===----------------------------------------------------------------------===// 383 // CstrBroadcastableOp 384 //===----------------------------------------------------------------------===// 385 386 namespace { 387 // Given an input shape Value, try to obtain the shape's values. 388 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 389 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 390 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 391 if (!type.hasRank()) 392 return failure(); 393 shapeValues = llvm::to_vector<6>(type.getShape()); 394 return success(); 395 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 396 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 397 return success(); 398 } else { 399 return failure(); 400 } 401 } 402 403 // For shapes that were created by some operations, we can obtain partial 404 // information on the shapes and sometimes determine if they will be 405 // broadcastable with that. 406 struct CstrBroadcastablePartialInfo 407 : public OpRewritePattern<CstrBroadcastableOp> { 408 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 409 410 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 411 PatternRewriter &rewriter) const override { 412 SmallVector<int64_t, 6> lhsShape, rhsShape; 413 if (failed(getShapeVec(op.lhs(), lhsShape))) 414 return failure(); 415 if (failed(getShapeVec(op.rhs(), rhsShape))) 416 return failure(); 417 if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 418 return failure(); 419 420 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 421 return success(); 422 } 423 }; 424 425 // Scalars are always broadcastable. 426 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> { 427 using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern; 428 429 LogicalResult matchAndRewrite(CstrBroadcastableOp op, 430 PatternRewriter &rewriter) const override { 431 SmallVector<int64_t, 6> shape; 432 if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0) 433 return failure(); 434 if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0) 435 return failure(); 436 437 rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true); 438 return success(); 439 } 440 }; 441 442 } // namespace 443 444 void CstrBroadcastableOp::getCanonicalizationPatterns( 445 OwningRewritePatternList &patterns, MLIRContext *context) { 446 // Canonicalization patterns have overlap with the considerations during 447 // folding in case additional shape information is inferred at some point that 448 // does not result in folding. 449 patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo, 450 CstrBroadcastableScalar>(context); 451 } 452 453 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 454 // Both operands are not needed if one is a scalar. 455 if (operands[0] && 456 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 457 return BoolAttr::get(true, getContext()); 458 if (operands[1] && 459 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 460 return BoolAttr::get(true, getContext()); 461 462 if (operands[0] && operands[1]) { 463 auto lhsShape = llvm::to_vector<6>( 464 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 465 auto rhsShape = llvm::to_vector<6>( 466 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 467 SmallVector<int64_t, 6> resultShape; 468 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 469 return BoolAttr::get(true, getContext()); 470 } 471 472 // Lastly, see if folding can be completed based on what constraints are known 473 // on the input shapes. 474 SmallVector<int64_t, 6> lhsShape, rhsShape; 475 if (failed(getShapeVec(lhs(), lhsShape))) 476 return nullptr; 477 if (failed(getShapeVec(rhs(), rhsShape))) 478 return nullptr; 479 480 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 481 return BoolAttr::get(true, getContext()); 482 483 // Because a failing witness result here represents an eventual assertion 484 // failure, we do not replace it with a constant witness. 485 return nullptr; 486 } 487 488 //===----------------------------------------------------------------------===// 489 // CstrEqOp 490 //===----------------------------------------------------------------------===// 491 492 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 493 MLIRContext *context) { 494 // If inputs are equal, return passing witness 495 patterns.insert<CstrEqEqOps>(context); 496 } 497 498 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 499 if (llvm::all_of(operands, 500 [&](Attribute a) { return a && a == operands[0]; })) 501 return BoolAttr::get(true, getContext()); 502 503 // Because a failing witness result here represents an eventual assertion 504 // failure, we do not try to replace it with a constant witness. Similarly, we 505 // cannot if there are any non-const inputs. 506 return nullptr; 507 } 508 509 //===----------------------------------------------------------------------===// 510 // ConstSizeOp 511 //===----------------------------------------------------------------------===// 512 513 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 514 int64_t value) { 515 build(builder, result, builder.getIndexAttr(value)); 516 } 517 518 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 519 520 void ConstSizeOp::getAsmResultNames( 521 llvm::function_ref<void(Value, StringRef)> setNameFn) { 522 SmallString<4> buffer; 523 llvm::raw_svector_ostream os(buffer); 524 os << "c" << value(); 525 setNameFn(getResult(), os.str()); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // ConstWitnessOp 530 //===----------------------------------------------------------------------===// 531 532 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 533 534 //===----------------------------------------------------------------------===// 535 // ShapeEqOp 536 //===----------------------------------------------------------------------===// 537 538 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 539 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 540 if (lhs == nullptr) 541 return {}; 542 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 543 if (rhs == nullptr) 544 return {}; 545 return BoolAttr::get(lhs == rhs, getContext()); 546 } 547 548 //===----------------------------------------------------------------------===// 549 // IndexToSizeOp 550 //===----------------------------------------------------------------------===// 551 552 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 553 // Constant values of both types, `shape.size` and `index`, are represented as 554 // `IntegerAttr`s which makes constant folding simple. 555 if (Attribute arg = operands[0]) 556 return arg; 557 return {}; 558 } 559 560 void IndexToSizeOp::getCanonicalizationPatterns( 561 OwningRewritePatternList &patterns, MLIRContext *context) { 562 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 563 } 564 565 //===----------------------------------------------------------------------===// 566 // FromExtentsOp 567 //===----------------------------------------------------------------------===// 568 569 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 570 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 571 return nullptr; 572 SmallVector<int64_t, 6> extents; 573 for (auto attr : operands) 574 extents.push_back(attr.cast<IntegerAttr>().getInt()); 575 Builder builder(getContext()); 576 return builder.getIndexTensorAttr(extents); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // GetExtentOp 581 //===----------------------------------------------------------------------===// 582 583 Optional<int64_t> GetExtentOp::getConstantDim() { 584 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 585 return constSizeOp.value().getLimitedValue(); 586 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 587 return constantOp.value().cast<IntegerAttr>().getInt(); 588 return llvm::None; 589 } 590 591 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 592 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 593 if (!elements) 594 return nullptr; 595 Optional<int64_t> dim = getConstantDim(); 596 if (!dim.hasValue()) 597 return nullptr; 598 if (dim.getValue() >= elements.getNumElements()) 599 return nullptr; 600 return elements.getValue({(uint64_t)dim.getValue()}); 601 } 602 603 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 604 int64_t dim) { 605 auto loc = result.location; 606 auto dimAttr = builder.getIndexAttr(dim); 607 if (shape.getType().isa<ShapeType>()) { 608 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 609 build(builder, result, builder.getType<SizeType>(), shape, dim); 610 } else { 611 Value dim = 612 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 613 build(builder, result, builder.getIndexType(), shape, dim); 614 } 615 } 616 617 //===----------------------------------------------------------------------===// 618 // RankOp 619 //===----------------------------------------------------------------------===// 620 621 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 622 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 623 if (!shape) 624 return {}; 625 int64_t rank = shape.getNumElements(); 626 Builder builder(getContext()); 627 return builder.getIndexAttr(rank); 628 } 629 630 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 631 /// Constant folding fails in cases where only the rank is constant, not the 632 /// shape itself. 633 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 634 /// 635 /// Example: 636 /// 637 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 638 /// %rank = shape.rank %shape 639 /// 640 /// becomes 641 /// 642 /// %rank = shape.const_size 3 643 644 namespace { 645 struct RankShapeOfCanonicalizationPattern 646 : public OpRewritePattern<shape::RankOp> { 647 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 648 649 LogicalResult matchAndRewrite(shape::RankOp op, 650 PatternRewriter &rewriter) const override { 651 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 652 if (!shapeOfOp) 653 return failure(); 654 auto rankedTensorType = 655 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 656 if (!rankedTensorType) 657 return failure(); 658 int64_t rank = rankedTensorType.getRank(); 659 if (op.getType().isa<IndexType>()) { 660 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 661 } else if (op.getType().isa<shape::SizeType>()) { 662 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 663 } else { 664 return failure(); 665 } 666 return success(); 667 } 668 }; 669 } // namespace 670 671 void shape::RankOp::getCanonicalizationPatterns( 672 OwningRewritePatternList &patterns, MLIRContext *context) { 673 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // NumElementsOp 678 //===----------------------------------------------------------------------===// 679 680 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 681 682 // Fold only when argument constant. 683 Attribute shape = operands[0]; 684 if (!shape) 685 return {}; 686 687 APInt product(64, 1); 688 for (auto value : shape.cast<DenseIntElementsAttr>()) 689 product *= value; 690 Builder builder(getContext()); 691 return builder.getIndexAttr(product.getLimitedValue()); 692 } 693 694 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 695 Value shape) { 696 if (shape.getType().isa<ShapedType>()) { 697 auto type = builder.getIndexType(); 698 return build(builder, result, type, shape); 699 } 700 auto type = SizeType::get(builder.getContext()); 701 return build(builder, result, type, shape); 702 } 703 704 //===----------------------------------------------------------------------===// 705 // MulOp 706 //===----------------------------------------------------------------------===// 707 708 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 709 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 710 if (!lhs) 711 return nullptr; 712 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 713 if (!rhs) 714 return nullptr; 715 APInt folded = lhs.getValue() * rhs.getValue(); 716 Type indexTy = IndexType::get(getContext()); 717 return IntegerAttr::get(indexTy, folded); 718 } 719 720 //===----------------------------------------------------------------------===// 721 // ShapeOfOp 722 //===----------------------------------------------------------------------===// 723 724 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 725 auto type = getOperand().getType().dyn_cast<ShapedType>(); 726 if (!type || !type.hasStaticShape()) 727 return nullptr; 728 Builder builder(getContext()); 729 return builder.getIndexTensorAttr(type.getShape()); 730 } 731 732 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 733 Type type = arg.getType().isa<ShapedType>() 734 ? (Type)getExtentTensorType(builder.getContext()) 735 : (Type)builder.getType<ShapeType>(); 736 return ShapeOfOp::build(builder, result, type, arg); 737 } 738 739 namespace { 740 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 741 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 742 743 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 744 PatternRewriter &rewriter) const override { 745 if (!op.arg().getType().isa<ShapedType>()) 746 return failure(); 747 if (op.getType().isa<ShapedType>()) 748 return failure(); 749 750 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 751 return success(); 752 } 753 }; 754 } // namespace 755 756 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 757 MLIRContext *context) { 758 patterns.insert<ShapeOfWithTensor>(context); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // SizeToIndexOp 763 //===----------------------------------------------------------------------===// 764 765 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 766 // Constant values of both types, `shape.size` and `index`, are represented as 767 // `IntegerAttr`s which makes constant folding simple. 768 if (Attribute arg = operands[0]) 769 return arg; 770 return impl::foldCastOp(*this); 771 } 772 773 void SizeToIndexOp::getCanonicalizationPatterns( 774 OwningRewritePatternList &patterns, MLIRContext *context) { 775 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 776 } 777 778 //===----------------------------------------------------------------------===// 779 // YieldOp 780 //===----------------------------------------------------------------------===// 781 782 static LogicalResult verify(shape::YieldOp op) { 783 auto *parentOp = op.getParentOp(); 784 auto results = parentOp->getResults(); 785 auto operands = op.getOperands(); 786 787 if (parentOp->getNumResults() != op.getNumOperands()) 788 return op.emitOpError() << "number of operands does not match number of " 789 "results of its parent"; 790 for (auto e : llvm::zip(results, operands)) 791 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 792 return op.emitOpError() 793 << "types mismatch between yield op and its parent"; 794 795 return success(); 796 } 797 798 //===----------------------------------------------------------------------===// 799 // SplitAtOp 800 //===----------------------------------------------------------------------===// 801 802 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 803 SmallVectorImpl<OpFoldResult> &results) { 804 if (!operands[0] || !operands[1]) 805 return failure(); 806 auto shapeVec = llvm::to_vector<6>( 807 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 808 auto shape = llvm::makeArrayRef(shapeVec); 809 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 810 // Verify that the split point is in the correct range. 811 // TODO: Constant fold to an "error". 812 int64_t rank = shape.size(); 813 if (!(-rank <= splitPoint && splitPoint <= rank)) 814 return failure(); 815 if (splitPoint < 0) 816 splitPoint += shape.size(); 817 Builder builder(operands[0].getContext()); 818 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 819 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 820 return success(); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // ToExtentTensorOp 825 //===----------------------------------------------------------------------===// 826 827 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 828 if (!operands[0]) 829 return impl::foldCastOp(*this); 830 Builder builder(getContext()); 831 auto shape = llvm::to_vector<6>( 832 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 833 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 834 builder.getIndexType()); 835 return DenseIntElementsAttr::get(type, shape); 836 } 837 838 //===----------------------------------------------------------------------===// 839 // ReduceOp 840 //===----------------------------------------------------------------------===// 841 842 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 843 ValueRange initVals) { 844 result.addOperands(shape); 845 result.addOperands(initVals); 846 847 Region *bodyRegion = result.addRegion(); 848 bodyRegion->push_back(new Block); 849 Block &bodyBlock = bodyRegion->front(); 850 bodyBlock.addArgument(builder.getIndexType()); 851 852 Type elementType; 853 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 854 elementType = tensorType.getElementType(); 855 else 856 elementType = SizeType::get(builder.getContext()); 857 bodyBlock.addArgument(elementType); 858 859 for (Type initValType : initVals.getTypes()) { 860 bodyBlock.addArgument(initValType); 861 result.addTypes(initValType); 862 } 863 } 864 865 static LogicalResult verify(ReduceOp op) { 866 // Verify block arg types. 867 Block &block = op.region().front(); 868 869 // The block takes index, extent, and aggregated values as arguments. 870 auto blockArgsCount = op.initVals().size() + 2; 871 if (block.getNumArguments() != blockArgsCount) 872 return op.emitOpError() << "ReduceOp body is expected to have " 873 << blockArgsCount << " arguments"; 874 875 // The first block argument is the index and must always be of type `index`. 876 if (!block.getArgument(0).getType().isa<IndexType>()) 877 return op.emitOpError( 878 "argument 0 of ReduceOp body is expected to be of IndexType"); 879 880 // The second block argument is the extent and must be of type `size` or 881 // `index`, depending on whether the reduce operation is applied to a shape or 882 // to an extent tensor. 883 Type extentTy = block.getArgument(1).getType(); 884 if (op.shape().getType().isa<ShapeType>()) { 885 if (!extentTy.isa<SizeType>()) 886 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 887 "SizeType if the ReduceOp operates on a ShapeType"); 888 } else { 889 if (!extentTy.isa<IndexType>()) 890 return op.emitOpError( 891 "argument 1 of ReduceOp body is expected to be of IndexType if the " 892 "ReduceOp operates on an extent tensor"); 893 } 894 895 for (auto type : llvm::enumerate(op.initVals())) 896 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 897 return op.emitOpError() 898 << "type mismatch between argument " << type.index() + 2 899 << " of ReduceOp body and initial value " << type.index(); 900 return success(); 901 } 902 903 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 904 // Parse operands. 905 SmallVector<OpAsmParser::OperandType, 3> operands; 906 Type shapeOrExtentTensorType; 907 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 908 OpAsmParser::Delimiter::Paren) || 909 parser.parseColonType(shapeOrExtentTensorType) || 910 parser.parseOptionalArrowTypeList(result.types)) 911 return failure(); 912 913 // Resolve operands. 914 auto initVals = llvm::makeArrayRef(operands).drop_front(); 915 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 916 result.operands) || 917 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 918 result.operands)) 919 return failure(); 920 921 // Parse the body. 922 Region *body = result.addRegion(); 923 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 924 return failure(); 925 926 // Parse attributes. 927 if (parser.parseOptionalAttrDict(result.attributes)) 928 return failure(); 929 930 return success(); 931 } 932 933 static void print(OpAsmPrinter &p, ReduceOp op) { 934 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 935 << ") : " << op.shape().getType(); 936 p.printOptionalArrowTypeList(op.getResultTypes()); 937 p.printRegion(op.region()); 938 p.printOptionalAttrDict(op.getAttrs()); 939 } 940 941 namespace mlir { 942 namespace shape { 943 944 #define GET_OP_CLASSES 945 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 946 947 } // namespace shape 948 } // namespace mlir 949