1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SmallString.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 #include "llvm/Support/raw_ostream.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 25 namespace { 26 #include "ShapeCanonicalization.inc" 27 } 28 29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 30 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 31 } 32 33 static bool isErrorPropagationPossible(TypeRange operandTypes) { 34 for (Type ty : operandTypes) 35 if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>()) 36 return true; 37 return false; 38 } 39 40 static LogicalResult verifySizeOrIndexOp(Operation *op) { 41 assert(op != nullptr && op->getNumResults() == 1); 42 Type resultTy = op->getResultTypes().front(); 43 if (isErrorPropagationPossible(op->getOperandTypes())) { 44 if (!resultTy.isa<SizeType>()) 45 return op->emitOpError() 46 << "if at least one of the operands can hold error values then " 47 "the result must be of type `size` to propagate them"; 48 } 49 return success(); 50 } 51 52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 53 assert(op != nullptr && op->getNumResults() == 1); 54 Type resultTy = op->getResultTypes().front(); 55 if (isErrorPropagationPossible(op->getOperandTypes())) { 56 if (!resultTy.isa<ShapeType>()) 57 return op->emitOpError() 58 << "if at least one of the operands can hold error values then " 59 "the result must be of type `shape` to propagate them"; 60 } 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 /// This class defines the interface for inlining shape dialect ops. 70 struct ShapeInlinerInterface : public DialectInlinerInterface { 71 using DialectInlinerInterface::DialectInlinerInterface; 72 73 // Returns true if the given region 'src' can be inlined into the region 74 // 'dest' that is attached to an operation registered to the current dialect. 75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 76 BlockAndValueMapping &) const final { 77 return true; 78 } 79 80 // Returns true if the given operation 'op', that is registered to this 81 // dialect, can be inlined into the region 'dest' that is attached to an 82 // operation registered to the current dialect. 83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 84 BlockAndValueMapping &) const final { 85 return true; 86 } 87 }; 88 } // namespace 89 90 void ShapeDialect::initialize() { 91 addOperations< 92 #define GET_OP_LIST 93 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 94 >(); 95 addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType, 96 WitnessType>(); 97 addInterfaces<ShapeInlinerInterface>(); 98 // Allow unknown operations during prototyping and testing. As the dialect is 99 // still evolving it makes it simple to start with an unregistered ops and 100 // try different variants before actually defining the op. 101 allowUnknownOperations(); 102 } 103 104 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 105 Attribute value, Type type, 106 Location loc) { 107 if (type.isa<ShapeType>() || 108 type == getExtentTensorType(builder.getContext())) 109 return builder.create<ConstShapeOp>(loc, type, 110 value.cast<DenseIntElementsAttr>()); 111 if (type.isa<SizeType>()) 112 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 113 if (type.isa<WitnessType>()) 114 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 115 if (type.isa<IndexType>()) 116 return builder.create<ConstantOp>(loc, type, value); 117 return nullptr; 118 } 119 120 /// Parse a type registered to this dialect. 121 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 122 StringRef keyword; 123 if (parser.parseKeyword(&keyword)) 124 return Type(); 125 126 if (keyword == "component") 127 return ComponentType::get(getContext()); 128 if (keyword == "element") 129 return ElementType::get(getContext()); 130 if (keyword == "shape") 131 return ShapeType::get(getContext()); 132 if (keyword == "size") 133 return SizeType::get(getContext()); 134 if (keyword == "value_shape") 135 return ValueShapeType::get(getContext()); 136 if (keyword == "witness") 137 return WitnessType::get(getContext()); 138 139 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 140 return Type(); 141 } 142 143 /// Print a type registered to this dialect. 144 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 145 TypeSwitch<Type>(type) 146 .Case<ComponentType>([&](Type) { os << "component"; }) 147 .Case<ElementType>([&](Type) { os << "element"; }) 148 .Case<ShapeType>([&](Type) { os << "shape"; }) 149 .Case<SizeType>([&](Type) { os << "size"; }) 150 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 151 .Case<WitnessType>([&](Type) { os << "witness"; }) 152 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 153 } 154 155 //===----------------------------------------------------------------------===// 156 // AnyOp 157 //===----------------------------------------------------------------------===// 158 159 // TODO: Canonicalization should be implemented for shapes that can be 160 // determined through mixtures of the known dimensions of the inputs. 161 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 162 // Only the last operand is checked because AnyOp is commutative. 163 if (operands.back()) 164 return operands.back(); 165 166 return nullptr; 167 } 168 169 //===----------------------------------------------------------------------===// 170 // AssumingOp 171 //===----------------------------------------------------------------------===// 172 173 static ParseResult parseAssumingOp(OpAsmParser &parser, 174 OperationState &result) { 175 result.regions.reserve(1); 176 Region *doRegion = result.addRegion(); 177 178 auto &builder = parser.getBuilder(); 179 OpAsmParser::OperandType cond; 180 if (parser.parseOperand(cond) || 181 parser.resolveOperand(cond, builder.getType<WitnessType>(), 182 result.operands)) 183 return failure(); 184 185 // Parse optional results type list. 186 if (parser.parseOptionalArrowTypeList(result.types)) 187 return failure(); 188 189 // Parse the region and add a terminator if elided. 190 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 191 return failure(); 192 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 193 194 // Parse the optional attribute list. 195 if (parser.parseOptionalAttrDict(result.attributes)) 196 return failure(); 197 return success(); 198 } 199 200 static void print(OpAsmPrinter &p, AssumingOp op) { 201 bool yieldsResults = !op.results().empty(); 202 203 p << AssumingOp::getOperationName() << " " << op.witness(); 204 if (yieldsResults) { 205 p << " -> (" << op.getResultTypes() << ")"; 206 } 207 p.printRegion(op.doRegion(), 208 /*printEntryBlockArgs=*/false, 209 /*printBlockTerminators=*/yieldsResults); 210 p.printOptionalAttrDict(op.getAttrs()); 211 } 212 213 namespace { 214 // Removes AssumingOp with a passing witness and inlines the region. 215 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 216 using OpRewritePattern<AssumingOp>::OpRewritePattern; 217 218 LogicalResult matchAndRewrite(AssumingOp op, 219 PatternRewriter &rewriter) const override { 220 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 221 if (!witness || !witness.passingAttr()) 222 return failure(); 223 224 AssumingOp::inlineRegionIntoParent(op, rewriter); 225 return success(); 226 } 227 }; 228 } // namespace 229 230 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 231 MLIRContext *context) { 232 // If taking a passing witness, inline region. 233 patterns.insert<AssumingWithTrue>(context); 234 } 235 236 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 237 void AssumingOp::getSuccessorRegions( 238 Optional<unsigned> index, ArrayRef<Attribute> operands, 239 SmallVectorImpl<RegionSuccessor> ®ions) { 240 // AssumingOp has unconditional control flow into the region and back to the 241 // parent, so return the correct RegionSuccessor purely based on the index 242 // being None or 0. 243 if (index.hasValue()) { 244 regions.push_back(RegionSuccessor(getResults())); 245 return; 246 } 247 248 regions.push_back(RegionSuccessor(&doRegion())); 249 } 250 251 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 252 PatternRewriter &rewriter) { 253 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 254 auto *assumingBlock = op.getBody(); 255 auto initPosition = rewriter.getInsertionPoint(); 256 auto *blockAfterAssuming = 257 rewriter.splitBlock(blockBeforeAssuming, initPosition); 258 259 // Remove the AssumingOp and AssumingYieldOp. 260 auto &yieldOp = assumingBlock->back(); 261 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 262 rewriter.replaceOp(op, yieldOp.getOperands()); 263 rewriter.eraseOp(&yieldOp); 264 265 // Merge blocks together as there was no branching behavior from the 266 // AssumingOp. 267 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 268 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // AssumingAllOp 273 //===----------------------------------------------------------------------===// 274 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 275 // Iterate in reverse to first handle all constant operands. They are 276 // guaranteed to be the tail of the inputs because this is commutative. 277 for (int idx = operands.size() - 1; idx >= 0; idx--) { 278 Attribute a = operands[idx]; 279 // Cannot fold if any inputs are not constant; 280 if (!a) 281 return nullptr; 282 283 // We do not need to keep statically known values after handling them in 284 // this method. 285 getOperation()->eraseOperand(idx); 286 287 // Always false if any input is statically known false 288 if (!a.cast<BoolAttr>().getValue()) 289 return a; 290 } 291 // If this is reached, all inputs were statically known passing. 292 return BoolAttr::get(true, getContext()); 293 } 294 295 static LogicalResult verify(AssumingAllOp op) { 296 // Ensure that AssumingAllOp contains at least one operand 297 if (op.getNumOperands() == 0) 298 return op.emitOpError("no operands specified"); 299 300 return success(); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // BroadcastOp 305 //===----------------------------------------------------------------------===// 306 307 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 308 if (!operands[1]) 309 return nullptr; 310 311 auto rhsShape = llvm::to_vector<6>( 312 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 313 if (rhsShape.empty()) 314 return lhs(); 315 316 if (!operands[0]) 317 return nullptr; 318 319 auto lhsShape = llvm::to_vector<6>( 320 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 321 if (lhsShape.empty()) 322 return rhs(); 323 324 SmallVector<int64_t, 6> resultShape; 325 // If the shapes are not compatible, we can't fold it. 326 // TODO: Fold to an "error". 327 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 328 return nullptr; 329 Builder builder(getContext()); 330 return builder.getIndexTensorAttr(resultShape); 331 } 332 333 //===----------------------------------------------------------------------===// 334 // ConcatOp 335 //===----------------------------------------------------------------------===// 336 337 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 338 if (!operands[0] || !operands[1]) 339 return nullptr; 340 auto lhsShape = llvm::to_vector<6>( 341 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 342 auto rhsShape = llvm::to_vector<6>( 343 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 344 SmallVector<int64_t, 6> resultShape; 345 resultShape.append(lhsShape.begin(), lhsShape.end()); 346 resultShape.append(rhsShape.begin(), rhsShape.end()); 347 Builder builder(getContext()); 348 return builder.getIndexTensorAttr(resultShape); 349 } 350 351 //===----------------------------------------------------------------------===// 352 // ConstShapeOp 353 //===----------------------------------------------------------------------===// 354 355 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 356 p << "shape.const_shape "; 357 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 358 p << "["; 359 interleaveComma(op.shape().getValues<int64_t>(), p, 360 [&](int64_t i) { p << i; }); 361 p << "] : "; 362 p.printType(op.getType()); 363 } 364 365 static ParseResult parseConstShapeOp(OpAsmParser &parser, 366 OperationState &result) { 367 if (parser.parseOptionalAttrDict(result.attributes)) 368 return failure(); 369 // We piggy-back on ArrayAttr parsing, though we don't internally store the 370 // shape as an ArrayAttr. 371 // TODO: Implement custom parser and maybe make syntax a bit more concise. 372 Attribute extentsRaw; 373 NamedAttrList dummy; 374 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 375 return failure(); 376 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 377 if (!extentsArray) 378 return failure(); 379 SmallVector<int64_t, 6> ints; 380 for (Attribute extent : extentsArray) { 381 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 382 if (!attr) 383 return failure(); 384 ints.push_back(attr.getInt()); 385 } 386 Builder &builder = parser.getBuilder(); 387 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 388 Type resultTy; 389 if (parser.parseColonType(resultTy)) 390 return failure(); 391 result.types.push_back(resultTy); 392 return success(); 393 } 394 395 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 396 397 //===----------------------------------------------------------------------===// 398 // CstrBroadcastableOp 399 //===----------------------------------------------------------------------===// 400 401 namespace { 402 // Given an input shape Value, try to obtain the shape's values. 403 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 404 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 405 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 406 if (!type.hasRank()) 407 return failure(); 408 shapeValues = llvm::to_vector<6>(type.getShape()); 409 return success(); 410 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 411 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 412 return success(); 413 } else { 414 return failure(); 415 } 416 } 417 } // namespace 418 419 void CstrBroadcastableOp::getCanonicalizationPatterns( 420 OwningRewritePatternList &patterns, MLIRContext *context) { 421 // Canonicalization patterns have overlap with the considerations during 422 // folding in case additional shape information is inferred at some point that 423 // does not result in folding. 424 patterns.insert<CstrBroadcastableEqOps>(context); 425 } 426 427 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 428 // Both operands are not needed if one is a scalar. 429 if (operands[0] && 430 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 431 return BoolAttr::get(true, getContext()); 432 if (operands[1] && 433 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 434 return BoolAttr::get(true, getContext()); 435 436 if (operands[0] && operands[1]) { 437 auto lhsShape = llvm::to_vector<6>( 438 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 439 auto rhsShape = llvm::to_vector<6>( 440 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 441 SmallVector<int64_t, 6> resultShape; 442 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 443 return BoolAttr::get(true, getContext()); 444 } 445 446 // Lastly, see if folding can be completed based on what constraints are known 447 // on the input shapes. 448 SmallVector<int64_t, 6> lhsShape, rhsShape; 449 if (failed(getShapeVec(lhs(), lhsShape))) 450 return nullptr; 451 if (failed(getShapeVec(rhs(), rhsShape))) 452 return nullptr; 453 454 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 455 return BoolAttr::get(true, getContext()); 456 457 // Because a failing witness result here represents an eventual assertion 458 // failure, we do not replace it with a constant witness. 459 return nullptr; 460 } 461 462 //===----------------------------------------------------------------------===// 463 // CstrEqOp 464 //===----------------------------------------------------------------------===// 465 466 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 467 MLIRContext *context) { 468 // If inputs are equal, return passing witness 469 patterns.insert<CstrEqEqOps>(context); 470 } 471 472 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 473 if (llvm::all_of(operands, 474 [&](Attribute a) { return a && a == operands[0]; })) 475 return BoolAttr::get(true, getContext()); 476 477 // Because a failing witness result here represents an eventual assertion 478 // failure, we do not try to replace it with a constant witness. Similarly, we 479 // cannot if there are any non-const inputs. 480 return nullptr; 481 } 482 483 //===----------------------------------------------------------------------===// 484 // ConstSizeOp 485 //===----------------------------------------------------------------------===// 486 487 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 488 int64_t value) { 489 build(builder, result, builder.getIndexAttr(value)); 490 } 491 492 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 493 494 void ConstSizeOp::getAsmResultNames( 495 llvm::function_ref<void(Value, StringRef)> setNameFn) { 496 SmallString<4> buffer; 497 llvm::raw_svector_ostream os(buffer); 498 os << "c" << value(); 499 setNameFn(getResult(), os.str()); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // ConstWitnessOp 504 //===----------------------------------------------------------------------===// 505 506 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 507 508 //===----------------------------------------------------------------------===// 509 // CstrRequireOp 510 //===----------------------------------------------------------------------===// 511 512 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 513 return operands[0]; 514 } 515 516 //===----------------------------------------------------------------------===// 517 // ShapeEqOp 518 //===----------------------------------------------------------------------===// 519 520 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 521 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 522 if (lhs == nullptr) 523 return {}; 524 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 525 if (rhs == nullptr) 526 return {}; 527 return BoolAttr::get(lhs == rhs, getContext()); 528 } 529 530 //===----------------------------------------------------------------------===// 531 // IndexToSizeOp 532 //===----------------------------------------------------------------------===// 533 534 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 535 // Constant values of both types, `shape.size` and `index`, are represented as 536 // `IntegerAttr`s which makes constant folding simple. 537 if (Attribute arg = operands[0]) 538 return arg; 539 return {}; 540 } 541 542 void IndexToSizeOp::getCanonicalizationPatterns( 543 OwningRewritePatternList &patterns, MLIRContext *context) { 544 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 545 } 546 547 //===----------------------------------------------------------------------===// 548 // FromExtentsOp 549 //===----------------------------------------------------------------------===// 550 551 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 552 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 553 return nullptr; 554 SmallVector<int64_t, 6> extents; 555 for (auto attr : operands) 556 extents.push_back(attr.cast<IntegerAttr>().getInt()); 557 Builder builder(getContext()); 558 return builder.getIndexTensorAttr(extents); 559 } 560 561 //===----------------------------------------------------------------------===// 562 // FunctionLibraryOp 563 //===----------------------------------------------------------------------===// 564 565 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 566 StringRef name) { 567 ensureTerminator(*result.addRegion(), builder, result.location); 568 result.attributes.push_back(builder.getNamedAttr( 569 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 570 } 571 572 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 573 auto attr = mapping() 574 .get(op->getName().getIdentifier()) 575 .dyn_cast_or_null<FlatSymbolRefAttr>(); 576 if (!attr) 577 return nullptr; 578 return lookupSymbol<FuncOp>(attr); 579 } 580 581 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 582 OperationState &result) { 583 // Parse the op name. 584 StringAttr nameAttr; 585 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 586 result.attributes)) 587 return failure(); 588 589 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 590 return failure(); 591 592 auto *bodyRegion = result.addRegion(); 593 if (parser.parseRegion(*bodyRegion)) 594 return failure(); 595 596 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 597 result.location); 598 if (parser.parseKeyword("mapping")) 599 return failure(); 600 601 DictionaryAttr mappingAttr; 602 if (parser.parseAttribute(mappingAttr, 603 parser.getBuilder().getType<NoneType>(), "mapping", 604 result.attributes)) 605 return failure(); 606 return success(); 607 } 608 609 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 610 p << op.getOperationName() << ' '; 611 p.printSymbolName(op.getName()); 612 p.printOptionalAttrDictWithKeyword( 613 op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 614 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 615 /*printBlockTerminators=*/false); 616 p << " mapping "; 617 p.printAttributeWithoutType(op.mappingAttr()); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // GetExtentOp 622 //===----------------------------------------------------------------------===// 623 624 Optional<int64_t> GetExtentOp::getConstantDim() { 625 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 626 return constSizeOp.value().getLimitedValue(); 627 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 628 return constantOp.value().cast<IntegerAttr>().getInt(); 629 return llvm::None; 630 } 631 632 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 633 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 634 if (!elements) 635 return nullptr; 636 Optional<int64_t> dim = getConstantDim(); 637 if (!dim.hasValue()) 638 return nullptr; 639 if (dim.getValue() >= elements.getNumElements()) 640 return nullptr; 641 return elements.getValue({(uint64_t)dim.getValue()}); 642 } 643 644 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 645 int64_t dim) { 646 auto loc = result.location; 647 auto dimAttr = builder.getIndexAttr(dim); 648 if (shape.getType().isa<ShapeType>()) { 649 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 650 build(builder, result, builder.getType<SizeType>(), shape, dim); 651 } else { 652 Value dim = 653 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 654 build(builder, result, builder.getIndexType(), shape, dim); 655 } 656 } 657 658 //===----------------------------------------------------------------------===// 659 // RankOp 660 //===----------------------------------------------------------------------===// 661 662 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 663 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 664 if (!shape) 665 return {}; 666 int64_t rank = shape.getNumElements(); 667 Builder builder(getContext()); 668 return builder.getIndexAttr(rank); 669 } 670 671 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 672 /// Constant folding fails in cases where only the rank is constant, not the 673 /// shape itself. 674 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 675 /// 676 /// Example: 677 /// 678 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 679 /// %rank = shape.rank %shape 680 /// 681 /// becomes 682 /// 683 /// %rank = shape.const_size 3 684 685 namespace { 686 struct RankShapeOfCanonicalizationPattern 687 : public OpRewritePattern<shape::RankOp> { 688 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 689 690 LogicalResult matchAndRewrite(shape::RankOp op, 691 PatternRewriter &rewriter) const override { 692 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 693 if (!shapeOfOp) 694 return failure(); 695 auto rankedTensorType = 696 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 697 if (!rankedTensorType) 698 return failure(); 699 int64_t rank = rankedTensorType.getRank(); 700 if (op.getType().isa<IndexType>()) { 701 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 702 } else if (op.getType().isa<shape::SizeType>()) { 703 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 704 } else { 705 return failure(); 706 } 707 return success(); 708 } 709 }; 710 } // namespace 711 712 void shape::RankOp::getCanonicalizationPatterns( 713 OwningRewritePatternList &patterns, MLIRContext *context) { 714 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 715 } 716 717 //===----------------------------------------------------------------------===// 718 // NumElementsOp 719 //===----------------------------------------------------------------------===// 720 721 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 722 723 // Fold only when argument constant. 724 Attribute shape = operands[0]; 725 if (!shape) 726 return {}; 727 728 APInt product(64, 1); 729 for (auto value : shape.cast<DenseIntElementsAttr>()) 730 product *= value; 731 Builder builder(getContext()); 732 return builder.getIndexAttr(product.getLimitedValue()); 733 } 734 735 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 736 Value shape) { 737 if (shape.getType().isa<ShapedType>()) { 738 auto type = builder.getIndexType(); 739 return build(builder, result, type, shape); 740 } 741 auto type = SizeType::get(builder.getContext()); 742 return build(builder, result, type, shape); 743 } 744 745 //===----------------------------------------------------------------------===// 746 // MulOp 747 //===----------------------------------------------------------------------===// 748 749 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 750 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 751 if (!lhs) 752 return nullptr; 753 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 754 if (!rhs) 755 return nullptr; 756 APInt folded = lhs.getValue() * rhs.getValue(); 757 Type indexTy = IndexType::get(getContext()); 758 return IntegerAttr::get(indexTy, folded); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // ShapeOfOp 763 //===----------------------------------------------------------------------===// 764 765 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 766 auto type = getOperand().getType().dyn_cast<ShapedType>(); 767 if (!type || !type.hasStaticShape()) 768 return nullptr; 769 Builder builder(getContext()); 770 return builder.getIndexTensorAttr(type.getShape()); 771 } 772 773 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 774 Type type = arg.getType().isa<ShapedType>() 775 ? (Type)getExtentTensorType(builder.getContext()) 776 : (Type)builder.getType<ShapeType>(); 777 return ShapeOfOp::build(builder, result, type, arg); 778 } 779 780 namespace { 781 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 782 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 783 784 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 785 PatternRewriter &rewriter) const override { 786 if (!op.arg().getType().isa<ShapedType>()) 787 return failure(); 788 if (op.getType().isa<ShapedType>()) 789 return failure(); 790 791 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 792 return success(); 793 } 794 }; 795 } // namespace 796 797 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 798 MLIRContext *context) { 799 patterns.insert<ShapeOfWithTensor>(context); 800 } 801 802 //===----------------------------------------------------------------------===// 803 // SizeToIndexOp 804 //===----------------------------------------------------------------------===// 805 806 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 807 // Constant values of both types, `shape.size` and `index`, are represented as 808 // `IntegerAttr`s which makes constant folding simple. 809 if (Attribute arg = operands[0]) 810 return arg; 811 return impl::foldCastOp(*this); 812 } 813 814 void SizeToIndexOp::getCanonicalizationPatterns( 815 OwningRewritePatternList &patterns, MLIRContext *context) { 816 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 817 } 818 819 //===----------------------------------------------------------------------===// 820 // YieldOp 821 //===----------------------------------------------------------------------===// 822 823 static LogicalResult verify(shape::YieldOp op) { 824 auto *parentOp = op.getParentOp(); 825 auto results = parentOp->getResults(); 826 auto operands = op.getOperands(); 827 828 if (parentOp->getNumResults() != op.getNumOperands()) 829 return op.emitOpError() << "number of operands does not match number of " 830 "results of its parent"; 831 for (auto e : llvm::zip(results, operands)) 832 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 833 return op.emitOpError() 834 << "types mismatch between yield op and its parent"; 835 836 return success(); 837 } 838 839 //===----------------------------------------------------------------------===// 840 // SplitAtOp 841 //===----------------------------------------------------------------------===// 842 843 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 844 SmallVectorImpl<OpFoldResult> &results) { 845 if (!operands[0] || !operands[1]) 846 return failure(); 847 auto shapeVec = llvm::to_vector<6>( 848 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 849 auto shape = llvm::makeArrayRef(shapeVec); 850 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 851 // Verify that the split point is in the correct range. 852 // TODO: Constant fold to an "error". 853 int64_t rank = shape.size(); 854 if (!(-rank <= splitPoint && splitPoint <= rank)) 855 return failure(); 856 if (splitPoint < 0) 857 splitPoint += shape.size(); 858 Builder builder(operands[0].getContext()); 859 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 860 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 861 return success(); 862 } 863 864 //===----------------------------------------------------------------------===// 865 // ToExtentTensorOp 866 //===----------------------------------------------------------------------===// 867 868 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 869 if (!operands[0]) 870 return impl::foldCastOp(*this); 871 Builder builder(getContext()); 872 auto shape = llvm::to_vector<6>( 873 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 874 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 875 builder.getIndexType()); 876 return DenseIntElementsAttr::get(type, shape); 877 } 878 879 //===----------------------------------------------------------------------===// 880 // ReduceOp 881 //===----------------------------------------------------------------------===// 882 883 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 884 ValueRange initVals) { 885 result.addOperands(shape); 886 result.addOperands(initVals); 887 888 Region *bodyRegion = result.addRegion(); 889 bodyRegion->push_back(new Block); 890 Block &bodyBlock = bodyRegion->front(); 891 bodyBlock.addArgument(builder.getIndexType()); 892 893 Type elementType; 894 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 895 elementType = tensorType.getElementType(); 896 else 897 elementType = SizeType::get(builder.getContext()); 898 bodyBlock.addArgument(elementType); 899 900 for (Type initValType : initVals.getTypes()) { 901 bodyBlock.addArgument(initValType); 902 result.addTypes(initValType); 903 } 904 } 905 906 static LogicalResult verify(ReduceOp op) { 907 // Verify block arg types. 908 Block &block = op.region().front(); 909 910 // The block takes index, extent, and aggregated values as arguments. 911 auto blockArgsCount = op.initVals().size() + 2; 912 if (block.getNumArguments() != blockArgsCount) 913 return op.emitOpError() << "ReduceOp body is expected to have " 914 << blockArgsCount << " arguments"; 915 916 // The first block argument is the index and must always be of type `index`. 917 if (!block.getArgument(0).getType().isa<IndexType>()) 918 return op.emitOpError( 919 "argument 0 of ReduceOp body is expected to be of IndexType"); 920 921 // The second block argument is the extent and must be of type `size` or 922 // `index`, depending on whether the reduce operation is applied to a shape or 923 // to an extent tensor. 924 Type extentTy = block.getArgument(1).getType(); 925 if (op.shape().getType().isa<ShapeType>()) { 926 if (!extentTy.isa<SizeType>()) 927 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 928 "SizeType if the ReduceOp operates on a ShapeType"); 929 } else { 930 if (!extentTy.isa<IndexType>()) 931 return op.emitOpError( 932 "argument 1 of ReduceOp body is expected to be of IndexType if the " 933 "ReduceOp operates on an extent tensor"); 934 } 935 936 for (auto type : llvm::enumerate(op.initVals())) 937 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 938 return op.emitOpError() 939 << "type mismatch between argument " << type.index() + 2 940 << " of ReduceOp body and initial value " << type.index(); 941 return success(); 942 } 943 944 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 945 // Parse operands. 946 SmallVector<OpAsmParser::OperandType, 3> operands; 947 Type shapeOrExtentTensorType; 948 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 949 OpAsmParser::Delimiter::Paren) || 950 parser.parseColonType(shapeOrExtentTensorType) || 951 parser.parseOptionalArrowTypeList(result.types)) 952 return failure(); 953 954 // Resolve operands. 955 auto initVals = llvm::makeArrayRef(operands).drop_front(); 956 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 957 result.operands) || 958 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 959 result.operands)) 960 return failure(); 961 962 // Parse the body. 963 Region *body = result.addRegion(); 964 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 965 return failure(); 966 967 // Parse attributes. 968 if (parser.parseOptionalAttrDict(result.attributes)) 969 return failure(); 970 971 return success(); 972 } 973 974 static void print(OpAsmPrinter &p, ReduceOp op) { 975 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 976 << ") : " << op.shape().getType(); 977 p.printOptionalArrowTypeList(op.getResultTypes()); 978 p.printRegion(op.region()); 979 p.printOptionalAttrDict(op.getAttrs()); 980 } 981 982 #define GET_OP_CLASSES 983 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 984