1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Traits.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/IR/StandardTypes.h" 16 #include "llvm/ADT/SmallString.h" 17 #include "llvm/Support/raw_ostream.h" 18 19 using namespace mlir; 20 using namespace mlir::shape; 21 22 namespace { 23 #include "ShapeCanonicalization.inc" 24 } 25 26 ShapeDialect::ShapeDialect(MLIRContext *context) 27 : Dialect(getDialectNamespace(), context) { 28 addOperations< 29 #define GET_OP_LIST 30 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 31 >(); 32 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 33 WitnessType>(); 34 // Allow unknown operations during prototyping and testing. As the dialect is 35 // still evolving it makes it simple to start with an unregistered ops and 36 // try different variants before actually defining the op. 37 allowUnknownOperations(); 38 } 39 40 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 41 Attribute value, Type type, 42 Location loc) { 43 if (auto shapeType = type.dyn_cast<ShapeType>()) 44 return builder.create<ConstShapeOp>(loc, type, 45 value.cast<DenseIntElementsAttr>()); 46 if (auto sizeType = type.dyn_cast<SizeType>()) 47 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 48 if (auto witnessType = type.dyn_cast<WitnessType>()) 49 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 50 return nullptr; 51 } 52 53 /// Parse a type registered to this dialect. 54 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 55 StringRef keyword; 56 if (parser.parseKeyword(&keyword)) 57 return Type(); 58 59 if (keyword == "component") 60 return ComponentType::get(getContext()); 61 if (keyword == "element") 62 return ElementType::get(getContext()); 63 if (keyword == "shape") 64 return ShapeType::get(getContext()); 65 if (keyword == "size") 66 return SizeType::get(getContext()); 67 if (keyword == "value_shape") 68 return ValueShapeType::get(getContext()); 69 if (keyword == "witness") 70 return WitnessType::get(getContext()); 71 72 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 73 return Type(); 74 } 75 76 /// Print a type registered to this dialect. 77 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 78 switch (type.getKind()) { 79 case ShapeTypes::Component: 80 os << "component"; 81 return; 82 case ShapeTypes::Element: 83 os << "element"; 84 return; 85 case ShapeTypes::Size: 86 os << "size"; 87 return; 88 case ShapeTypes::Shape: 89 os << "shape"; 90 return; 91 case ShapeTypes::ValueShape: 92 os << "value_shape"; 93 return; 94 case ShapeTypes::Witness: 95 os << "witness"; 96 return; 97 default: 98 llvm_unreachable("unexpected 'shape' type kind"); 99 } 100 } 101 102 //===----------------------------------------------------------------------===// 103 // AnyOp 104 //===----------------------------------------------------------------------===// 105 106 // TODO: Canonicalization should be implemented for shapes that can be 107 // determined through mixtures of the known dimensions of the inputs. 108 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 109 // Only the last operand is checked because AnyOp is commutative. 110 if (operands.back()) 111 return operands.back(); 112 113 return nullptr; 114 } 115 116 //===----------------------------------------------------------------------===// 117 // AssumingOp 118 //===----------------------------------------------------------------------===// 119 120 static ParseResult parseAssumingOp(OpAsmParser &parser, 121 OperationState &result) { 122 result.regions.reserve(1); 123 Region *doRegion = result.addRegion(); 124 125 auto &builder = parser.getBuilder(); 126 OpAsmParser::OperandType cond; 127 if (parser.parseOperand(cond) || 128 parser.resolveOperand(cond, builder.getType<WitnessType>(), 129 result.operands)) 130 return failure(); 131 132 // Parse optional results type list. 133 if (parser.parseOptionalArrowTypeList(result.types)) 134 return failure(); 135 136 // Parse the region and add a terminator if elided. 137 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 138 return failure(); 139 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 140 141 // Parse the optional attribute list. 142 if (parser.parseOptionalAttrDict(result.attributes)) 143 return failure(); 144 return success(); 145 } 146 147 static void print(OpAsmPrinter &p, AssumingOp op) { 148 bool yieldsResults = !op.results().empty(); 149 150 p << AssumingOp::getOperationName() << " " << op.witness(); 151 if (yieldsResults) { 152 p << " -> (" << op.getResultTypes() << ")"; 153 } 154 p.printRegion(op.doRegion(), 155 /*printEntryBlockArgs=*/false, 156 /*printBlockTerminators=*/yieldsResults); 157 p.printOptionalAttrDict(op.getAttrs()); 158 } 159 160 namespace { 161 // Removes AssumingOp with a passing witness and inlines the region. 162 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 163 using OpRewritePattern<AssumingOp>::OpRewritePattern; 164 165 LogicalResult matchAndRewrite(AssumingOp op, 166 PatternRewriter &rewriter) const override { 167 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 168 if (!witness || !witness.passingAttr()) 169 return failure(); 170 171 AssumingOp::inlineRegionIntoParent(op, rewriter); 172 return success(); 173 } 174 }; 175 } // namespace 176 177 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 178 MLIRContext *context) { 179 // If taking a passing witness, inline region. 180 patterns.insert<AssumingWithTrue>(context); 181 } 182 183 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 184 PatternRewriter &rewriter) { 185 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 186 auto *assumingBlock = op.getBody(); 187 auto initPosition = rewriter.getInsertionPoint(); 188 auto *blockAfterAssuming = 189 rewriter.splitBlock(blockBeforeAssuming, initPosition); 190 191 // Remove the AssumingOp and AssumingYieldOp. 192 auto &yieldOp = assumingBlock->back(); 193 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 194 rewriter.replaceOp(op, yieldOp.getOperands()); 195 rewriter.eraseOp(&yieldOp); 196 197 // Merge blocks together as there was no branching behavior from the 198 // AssumingOp. 199 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 200 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // AssumingAllOp 205 //===----------------------------------------------------------------------===// 206 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 207 // Iterate in reverse to first handle all constant operands. They are 208 // guaranteed to be the tail of the inputs because this is commutative. 209 for (int idx = operands.size() - 1; idx >= 0; idx--) { 210 Attribute a = operands[idx]; 211 // Cannot fold if any inputs are not constant; 212 if (!a) 213 return nullptr; 214 215 // We do not need to keep statically known values after handling them in 216 // this method. 217 getOperation()->eraseOperand(idx); 218 219 // Always false if any input is statically known false 220 if (!a.cast<BoolAttr>().getValue()) 221 return a; 222 } 223 // If this is reached, all inputs were statically known passing. 224 return BoolAttr::get(true, getContext()); 225 } 226 227 static LogicalResult verify(AssumingAllOp op) { 228 // Ensure that AssumingAllOp contains at least one operand 229 if (op.getNumOperands() == 0) 230 return op.emitOpError("no operands specified"); 231 232 return success(); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // BroadcastOp 237 //===----------------------------------------------------------------------===// 238 239 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 240 if (!operands[0] || !operands[1]) 241 return nullptr; 242 auto lhsShape = llvm::to_vector<6>( 243 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 244 auto rhsShape = llvm::to_vector<6>( 245 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 246 SmallVector<int64_t, 6> resultShape; 247 // If the shapes are not compatible, we can't fold it. 248 // TODO: Fold to an "error". 249 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 250 return nullptr; 251 Builder builder(getContext()); 252 return builder.getIndexTensorAttr(resultShape); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // ConcatOp 257 //===----------------------------------------------------------------------===// 258 259 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 260 if (!operands[0] || !operands[1]) 261 return nullptr; 262 auto lhsShape = llvm::to_vector<6>( 263 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 264 auto rhsShape = llvm::to_vector<6>( 265 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 266 SmallVector<int64_t, 6> resultShape; 267 resultShape.append(lhsShape.begin(), lhsShape.end()); 268 resultShape.append(rhsShape.begin(), rhsShape.end()); 269 Builder builder(getContext()); 270 return builder.getIndexTensorAttr(resultShape); 271 } 272 273 //===----------------------------------------------------------------------===// 274 // ConstShapeOp 275 //===----------------------------------------------------------------------===// 276 277 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 278 p << "shape.const_shape "; 279 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 280 p << "["; 281 interleaveComma(op.shape().getValues<int64_t>(), p, 282 [&](int64_t i) { p << i; }); 283 p << "]"; 284 } 285 286 static ParseResult parseConstShapeOp(OpAsmParser &parser, 287 OperationState &result) { 288 if (parser.parseOptionalAttrDict(result.attributes)) 289 return failure(); 290 // We piggy-back on ArrayAttr parsing, though we don't internally store the 291 // shape as an ArrayAttr. 292 // TODO: Implement custom parser and maybe make syntax a bit more concise. 293 Attribute extentsRaw; 294 NamedAttrList dummy; 295 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 296 return failure(); 297 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 298 if (!extentsArray) 299 return failure(); 300 SmallVector<int64_t, 6> ints; 301 for (Attribute extent : extentsArray) { 302 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 303 if (!attr) 304 return failure(); 305 ints.push_back(attr.getInt()); 306 } 307 Builder &builder = parser.getBuilder(); 308 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 309 310 result.types.push_back(ShapeType::get(builder.getContext())); 311 return success(); 312 } 313 314 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 315 316 //===----------------------------------------------------------------------===// 317 // CstrBroadcastableOp 318 //===----------------------------------------------------------------------===// 319 320 void CstrBroadcastableOp::getCanonicalizationPatterns( 321 OwningRewritePatternList &patterns, MLIRContext *context) { 322 // If inputs are equal, return passing witness 323 patterns.insert<CstrBroadcastableEqOps>(context); 324 } 325 326 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 327 if (!operands[0] || !operands[1]) 328 return nullptr; 329 auto lhsShape = llvm::to_vector<6>( 330 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 331 auto rhsShape = llvm::to_vector<6>( 332 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 333 SmallVector<int64_t, 6> resultShape; 334 if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 335 return BoolAttr::get(true, getContext()); 336 337 // Because a failing witness result here represents an eventual assertion 338 // failure, we do not replace it with a constant witness. 339 return nullptr; 340 } 341 342 //===----------------------------------------------------------------------===// 343 // CstrEqOp 344 //===----------------------------------------------------------------------===// 345 346 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 347 MLIRContext *context) { 348 // If inputs are equal, return passing witness 349 patterns.insert<CstrEqEqOps>(context); 350 } 351 352 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 353 if (llvm::all_of(operands, 354 [&](Attribute a) { return a && a == operands[0]; })) 355 return BoolAttr::get(true, getContext()); 356 357 // Because a failing witness result here represents an eventual assertion 358 // failure, we do not try to replace it with a constant witness. Similarly, we 359 // cannot if there are any non-const inputs. 360 return nullptr; 361 } 362 363 //===----------------------------------------------------------------------===// 364 // ConstSizeOp 365 //===----------------------------------------------------------------------===// 366 367 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 368 369 void ConstSizeOp::getAsmResultNames( 370 llvm::function_ref<void(Value, StringRef)> setNameFn) { 371 SmallString<4> buffer; 372 llvm::raw_svector_ostream os(buffer); 373 os << "c" << value(); 374 setNameFn(getResult(), os.str()); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // ConstWitnessOp 379 //===----------------------------------------------------------------------===// 380 381 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 382 383 //===----------------------------------------------------------------------===// 384 // IndexToSizeOp 385 //===----------------------------------------------------------------------===// 386 387 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 388 // Constant values of both types, `shape.size` and `index`, are represented as 389 // `IntegerAttr`s which makes constant folding simple. 390 if (Attribute arg = operands[0]) 391 return arg; 392 return {}; 393 } 394 395 //===----------------------------------------------------------------------===// 396 // FromExtentsOp 397 //===----------------------------------------------------------------------===// 398 399 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 400 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 401 return nullptr; 402 SmallVector<int64_t, 6> extents; 403 for (auto attr : operands) 404 extents.push_back(attr.cast<IntegerAttr>().getInt()); 405 Builder builder(getContext()); 406 return builder.getIndexTensorAttr(extents); 407 } 408 409 //===----------------------------------------------------------------------===// 410 // GetExtentOp 411 //===----------------------------------------------------------------------===// 412 413 Optional<int64_t> GetExtentOp::getConstantDim() { 414 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) { 415 return constSizeOp.value().getLimitedValue(); 416 } 417 return llvm::None; 418 } 419 420 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 421 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 422 if (!elements) 423 return nullptr; 424 Optional<int64_t> dim = getConstantDim(); 425 if (!dim.hasValue()) 426 return nullptr; 427 if (dim.getValue() >= elements.getNumElements()) 428 return nullptr; 429 return elements.getValue({(uint64_t)dim.getValue()}); 430 } 431 432 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 433 int64_t dim) { 434 auto loc = result.location; 435 auto dimAttr = builder.getIndexAttr(dim); 436 Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr); 437 build(builder, result, shape, dimValue); 438 } 439 440 //===----------------------------------------------------------------------===// 441 // NumElementsOp 442 //===----------------------------------------------------------------------===// 443 444 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 445 446 // Fold only when argument constant. 447 Attribute shape = operands[0]; 448 if (!shape) 449 return {}; 450 451 APInt product(64, 1); 452 for (auto value : shape.cast<DenseIntElementsAttr>()) 453 product *= value; 454 Builder builder(getContext()); 455 return builder.getIndexAttr(product.getLimitedValue()); 456 } 457 458 //===----------------------------------------------------------------------===// 459 // ShapeOfOp 460 //===----------------------------------------------------------------------===// 461 462 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 463 auto type = getOperand().getType().dyn_cast<ShapedType>(); 464 if (!type || !type.hasStaticShape()) 465 return nullptr; 466 Builder builder(getContext()); 467 return builder.getIndexTensorAttr(type.getShape()); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // SizeToIndexOp 472 //===----------------------------------------------------------------------===// 473 474 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 475 // Constant values of both types, `shape.size` and `index`, are represented as 476 // `IntegerAttr`s which makes constant folding simple. 477 if (Attribute arg = operands[0]) 478 return arg; 479 return {}; 480 } 481 482 //===----------------------------------------------------------------------===// 483 // YieldOp 484 //===----------------------------------------------------------------------===// 485 486 static LogicalResult verify(YieldOp op) { 487 auto *parentOp = op.getParentOp(); 488 auto results = parentOp->getResults(); 489 auto operands = op.getOperands(); 490 491 if (parentOp->getNumResults() != op.getNumOperands()) 492 return op.emitOpError() << "number of operands does not match number of " 493 "results of its parent"; 494 for (auto e : llvm::zip(results, operands)) 495 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 496 return op.emitOpError() 497 << "types mismatch between yield op and its parent"; 498 499 return success(); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // SplitAtOp 504 //===----------------------------------------------------------------------===// 505 506 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 507 SmallVectorImpl<OpFoldResult> &results) { 508 if (!operands[0] || !operands[1]) 509 return failure(); 510 auto shapeVec = llvm::to_vector<6>( 511 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 512 auto shape = llvm::makeArrayRef(shapeVec); 513 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 514 // Verify that the split point is in the correct range. 515 // TODO: Constant fold to an "error". 516 int64_t rank = shape.size(); 517 if (!(-rank <= splitPoint && splitPoint <= rank)) 518 return failure(); 519 if (splitPoint < 0) 520 splitPoint += shape.size(); 521 Builder builder(operands[0].getContext()); 522 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 523 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 524 return success(); 525 } 526 527 //===----------------------------------------------------------------------===// 528 // ToExtentTensorOp 529 //===----------------------------------------------------------------------===// 530 531 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 532 if (!operands[0]) 533 return nullptr; 534 Builder builder(getContext()); 535 auto shape = llvm::to_vector<6>( 536 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 537 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 538 builder.getIndexType()); 539 return DenseIntElementsAttr::get(type, shape); 540 } 541 542 //===----------------------------------------------------------------------===// 543 // ReduceOp 544 //===----------------------------------------------------------------------===// 545 546 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 547 ValueRange initVals) { 548 result.addOperands(shape); 549 result.addOperands(initVals); 550 551 Region *bodyRegion = result.addRegion(); 552 bodyRegion->push_back(new Block); 553 Block &bodyBlock = bodyRegion->front(); 554 bodyBlock.addArgument(builder.getIndexType()); 555 bodyBlock.addArgument(SizeType::get(builder.getContext())); 556 557 for (Type initValType : initVals.getTypes()) { 558 bodyBlock.addArgument(initValType); 559 result.addTypes(initValType); 560 } 561 } 562 563 static LogicalResult verify(ReduceOp op) { 564 // Verify block arg types. 565 Block &block = op.region().front(); 566 567 auto blockArgsCount = op.initVals().size() + 2; 568 if (block.getNumArguments() != blockArgsCount) 569 return op.emitOpError() << "ReduceOp body is expected to have " 570 << blockArgsCount << " arguments"; 571 572 if (block.getArgument(0).getType() != IndexType::get(op.getContext())) 573 return op.emitOpError( 574 "argument 0 of ReduceOp body is expected to be of IndexType"); 575 576 if (block.getArgument(1).getType() != SizeType::get(op.getContext())) 577 return op.emitOpError( 578 "argument 1 of ReduceOp body is expected to be of SizeType"); 579 580 for (auto type : llvm::enumerate(op.initVals())) 581 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 582 return op.emitOpError() 583 << "type mismatch between argument " << type.index() + 2 584 << " of ReduceOp body and initial value " << type.index(); 585 return success(); 586 } 587 588 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 589 auto *ctx = parser.getBuilder().getContext(); 590 // Parse operands. 591 SmallVector<OpAsmParser::OperandType, 3> operands; 592 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 593 OpAsmParser::Delimiter::Paren) || 594 parser.parseOptionalArrowTypeList(result.types)) 595 return failure(); 596 597 // Resolve operands. 598 auto initVals = llvm::makeArrayRef(operands).drop_front(); 599 if (parser.resolveOperand(operands.front(), ShapeType::get(ctx), 600 result.operands) || 601 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 602 result.operands)) 603 return failure(); 604 605 // Parse the body. 606 Region *body = result.addRegion(); 607 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 608 return failure(); 609 610 // Parse attributes. 611 if (parser.parseOptionalAttrDict(result.attributes)) 612 return failure(); 613 614 return success(); 615 } 616 617 static void print(OpAsmPrinter &p, ReduceOp op) { 618 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 619 << ") "; 620 p.printOptionalArrowTypeList(op.getResultTypes()); 621 p.printRegion(op.region()); 622 p.printOptionalAttrDict(op.getAttrs()); 623 } 624 625 namespace mlir { 626 namespace shape { 627 628 #define GET_OP_CLASSES 629 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 630 631 } // namespace shape 632 } // namespace mlir 633