1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 using namespace mlir::shape; 25 26 namespace { 27 #include "ShapeCanonicalization.inc" 28 } 29 30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 31 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 32 } 33 34 static bool isErrorPropagationPossible(TypeRange operandTypes) { 35 return llvm::any_of(operandTypes, [](Type ty) { 36 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 37 }); 38 } 39 40 static LogicalResult verifySizeOrIndexOp(Operation *op) { 41 assert(op != nullptr && op->getNumResults() == 1); 42 Type resultTy = op->getResultTypes().front(); 43 if (isErrorPropagationPossible(op->getOperandTypes())) { 44 if (!resultTy.isa<SizeType>()) 45 return op->emitOpError() 46 << "if at least one of the operands can hold error values then " 47 "the result must be of type `size` to propagate them"; 48 } 49 return success(); 50 } 51 52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 53 assert(op != nullptr && op->getNumResults() == 1); 54 Type resultTy = op->getResultTypes().front(); 55 if 
(isErrorPropagationPossible(op->getOperandTypes())) { 56 if (!resultTy.isa<ShapeType>()) 57 return op->emitOpError() 58 << "if at least one of the operands can hold error values then " 59 "the result must be of type `shape` to propagate them"; 60 } 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 /// This class defines the interface for inlining shape dialect ops. 70 struct ShapeInlinerInterface : public DialectInlinerInterface { 71 using DialectInlinerInterface::DialectInlinerInterface; 72 73 // Returns true if the given region 'src' can be inlined into the region 74 // 'dest' that is attached to an operation registered to the current dialect. 75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 76 BlockAndValueMapping &) const final { 77 return true; 78 } 79 80 // Returns true if the given operation 'op', that is registered to this 81 // dialect, can be inlined into the region 'dest' that is attached to an 82 // operation registered to the current dialect. 83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 84 BlockAndValueMapping &) const final { 85 return true; 86 } 87 }; 88 } // namespace 89 90 void ShapeDialect::initialize() { 91 addOperations< 92 #define GET_OP_LIST 93 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 94 >(); 95 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 96 addInterfaces<ShapeInlinerInterface>(); 97 // Allow unknown operations during prototyping and testing. As the dialect is 98 // still evolving it makes it simple to start with an unregistered ops and 99 // try different variants before actually defining the op. 
100 allowUnknownOperations(); 101 } 102 103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 104 Attribute value, Type type, 105 Location loc) { 106 if (type.isa<ShapeType>() || 107 type == getExtentTensorType(builder.getContext())) 108 return builder.create<ConstShapeOp>(loc, type, 109 value.cast<DenseIntElementsAttr>()); 110 if (type.isa<SizeType>()) 111 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 112 if (type.isa<WitnessType>()) 113 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 114 if (ConstantOp::isBuildableWith(value, type)) 115 return builder.create<ConstantOp>(loc, type, value); 116 return nullptr; 117 } 118 119 /// Parse a type registered to this dialect. 120 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 121 StringRef keyword; 122 if (parser.parseKeyword(&keyword)) 123 return Type(); 124 125 if (keyword == "shape") 126 return ShapeType::get(getContext()); 127 if (keyword == "size") 128 return SizeType::get(getContext()); 129 if (keyword == "value_shape") 130 return ValueShapeType::get(getContext()); 131 if (keyword == "witness") 132 return WitnessType::get(getContext()); 133 134 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 135 return Type(); 136 } 137 138 /// Print a type registered to this dialect. 139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 140 TypeSwitch<Type>(type) 141 .Case<ShapeType>([&](Type) { os << "shape"; }) 142 .Case<SizeType>([&](Type) { os << "size"; }) 143 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 144 .Case<WitnessType>([&](Type) { os << "witness"; }) 145 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 146 } 147 148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 149 NamedAttribute attribute) { 150 // Verify shape.lib attribute. 
151 if (attribute.first == "shape.lib") { 152 if (!op->hasTrait<OpTrait::SymbolTable>()) 153 return op->emitError( 154 "shape.lib attribute may only be on op implementing SymbolTable"); 155 156 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 157 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 158 if (!symbol) 159 return op->emitError("shape function library ") 160 << symbolRef << " not found"; 161 return isa<shape::FunctionLibraryOp>(symbol) 162 ? success() 163 : op->emitError() 164 << symbolRef << " required to be shape function library"; 165 } 166 167 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 168 // Verify all entries are function libraries and mappings in libraries 169 // refer to unique ops. 170 DenseSet<Identifier> key; 171 for (auto it : arr) { 172 if (!it.isa<SymbolRefAttr>()) 173 return op->emitError( 174 "only SymbolRefAttr allowed in shape.lib attribute array"); 175 176 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 177 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 178 if (!shapeFnLib) 179 return op->emitError() 180 << it << " does not refer to FunctionLibraryOp"; 181 for (auto mapping : shapeFnLib.mapping()) { 182 if (!key.insert(mapping.first).second) { 183 return op->emitError("only one op to shape mapping allowed, found " 184 "multiple for `") 185 << mapping.first << "`"; 186 } 187 } 188 } 189 return success(); 190 } 191 192 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 193 "allowed as shape.lib attribute"); 194 } 195 return success(); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // AnyOp 200 //===----------------------------------------------------------------------===// 201 202 // TODO: Canonicalization should be implemented for shapes that can be 203 // determined through mixtures of the known dimensions of the inputs. 
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 205 // Only the last operand is checked because AnyOp is commutative. 206 if (operands.back()) 207 return operands.back(); 208 209 return nullptr; 210 } 211 212 //===----------------------------------------------------------------------===// 213 // AssumingOp 214 //===----------------------------------------------------------------------===// 215 216 static ParseResult parseAssumingOp(OpAsmParser &parser, 217 OperationState &result) { 218 result.regions.reserve(1); 219 Region *doRegion = result.addRegion(); 220 221 auto &builder = parser.getBuilder(); 222 OpAsmParser::OperandType cond; 223 if (parser.parseOperand(cond) || 224 parser.resolveOperand(cond, builder.getType<WitnessType>(), 225 result.operands)) 226 return failure(); 227 228 // Parse optional results type list. 229 if (parser.parseOptionalArrowTypeList(result.types)) 230 return failure(); 231 232 // Parse the region and add a terminator if elided. 233 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 234 return failure(); 235 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 236 237 // Parse the optional attribute list. 238 if (parser.parseOptionalAttrDict(result.attributes)) 239 return failure(); 240 return success(); 241 } 242 243 static void print(OpAsmPrinter &p, AssumingOp op) { 244 bool yieldsResults = !op.results().empty(); 245 246 p << AssumingOp::getOperationName() << " " << op.witness(); 247 if (yieldsResults) { 248 p << " -> (" << op.getResultTypes() << ")"; 249 } 250 p.printRegion(op.doRegion(), 251 /*printEntryBlockArgs=*/false, 252 /*printBlockTerminators=*/yieldsResults); 253 p.printOptionalAttrDict(op.getAttrs()); 254 } 255 256 namespace { 257 // Removes AssumingOp with a passing witness and inlines the region. 
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 259 using OpRewritePattern<AssumingOp>::OpRewritePattern; 260 261 LogicalResult matchAndRewrite(AssumingOp op, 262 PatternRewriter &rewriter) const override { 263 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 264 if (!witness || !witness.passingAttr()) 265 return failure(); 266 267 AssumingOp::inlineRegionIntoParent(op, rewriter); 268 return success(); 269 } 270 }; 271 } // namespace 272 273 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 274 MLIRContext *context) { 275 // If taking a passing witness, inline region. 276 patterns.insert<AssumingWithTrue>(context); 277 } 278 279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 280 void AssumingOp::getSuccessorRegions( 281 Optional<unsigned> index, ArrayRef<Attribute> operands, 282 SmallVectorImpl<RegionSuccessor> ®ions) { 283 // AssumingOp has unconditional control flow into the region and back to the 284 // parent, so return the correct RegionSuccessor purely based on the index 285 // being None or 0. 286 if (index.hasValue()) { 287 regions.push_back(RegionSuccessor(getResults())); 288 return; 289 } 290 291 regions.push_back(RegionSuccessor(&doRegion())); 292 } 293 294 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 295 PatternRewriter &rewriter) { 296 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 297 auto *assumingBlock = op.getBody(); 298 auto initPosition = rewriter.getInsertionPoint(); 299 auto *blockAfterAssuming = 300 rewriter.splitBlock(blockBeforeAssuming, initPosition); 301 302 // Remove the AssumingOp and AssumingYieldOp. 303 auto &yieldOp = assumingBlock->back(); 304 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 305 rewriter.replaceOp(op, yieldOp.getOperands()); 306 rewriter.eraseOp(&yieldOp); 307 308 // Merge blocks together as there was no branching behavior from the 309 // AssumingOp. 
310 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 311 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // AssumingAllOp 316 //===----------------------------------------------------------------------===// 317 318 void AssumingAllOp::getCanonicalizationPatterns( 319 OwningRewritePatternList &patterns, MLIRContext *context) { 320 patterns.insert<AssumingAllOneOp>(context); 321 } 322 323 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 324 // Iterate in reverse to first handle all constant operands. They are 325 // guaranteed to be the tail of the inputs because this is commutative. 326 for (int idx = operands.size() - 1; idx >= 0; idx--) { 327 Attribute a = operands[idx]; 328 // Cannot fold if any inputs are not constant; 329 if (!a) 330 return nullptr; 331 332 // We do not need to keep statically known values after handling them in 333 // this method. 334 getOperation()->eraseOperand(idx); 335 336 // Always false if any input is statically known false 337 if (!a.cast<BoolAttr>().getValue()) 338 return a; 339 } 340 // If this is reached, all inputs were statically known passing. 
341 return BoolAttr::get(getContext(), true); 342 } 343 344 static LogicalResult verify(AssumingAllOp op) { 345 // Ensure that AssumingAllOp contains at least one operand 346 if (op.getNumOperands() == 0) 347 return op.emitOpError("no operands specified"); 348 349 return success(); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // BroadcastOp 354 //===----------------------------------------------------------------------===// 355 356 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 357 if (!operands[1]) 358 return nullptr; 359 360 // TODO: Support folding with more than 2 input shapes 361 if (operands.size() > 2 && !operands[2].isa<StringAttr>()) 362 return nullptr; 363 364 auto rhsShape = llvm::to_vector<6>( 365 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 366 if (rhsShape.empty()) 367 return shapes()[0]; 368 369 if (!operands[0]) 370 return nullptr; 371 372 auto lhsShape = llvm::to_vector<6>( 373 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 374 if (lhsShape.empty()) 375 return shapes()[1]; 376 377 SmallVector<int64_t, 6> resultShape; 378 // If the shapes are not compatible, we can't fold it. 379 // TODO: Fold to an "error". 
380 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 381 return nullptr; 382 Builder builder(getContext()); 383 return builder.getIndexTensorAttr(resultShape); 384 } 385 386 static LogicalResult verify(BroadcastOp op) { 387 // Ensure that AssumingAllOp contains at least one operand 388 if (op.getNumOperands() < 2) 389 return op.emitOpError("required at least 2 input shapes"); 390 391 return verifyShapeOrExtentTensorOp(op); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // ConcatOp 396 //===----------------------------------------------------------------------===// 397 398 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 399 if (!operands[0] || !operands[1]) 400 return nullptr; 401 auto lhsShape = llvm::to_vector<6>( 402 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 403 auto rhsShape = llvm::to_vector<6>( 404 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 405 SmallVector<int64_t, 6> resultShape; 406 resultShape.append(lhsShape.begin(), lhsShape.end()); 407 resultShape.append(rhsShape.begin(), rhsShape.end()); 408 Builder builder(getContext()); 409 return builder.getIndexTensorAttr(resultShape); 410 } 411 412 //===----------------------------------------------------------------------===// 413 // ConstShapeOp 414 //===----------------------------------------------------------------------===// 415 416 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 417 p << "shape.const_shape "; 418 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 419 p << "["; 420 interleaveComma(op.shape().getValues<int64_t>(), p, 421 [&](int64_t i) { p << i; }); 422 p << "] : "; 423 p.printType(op.getType()); 424 } 425 426 static ParseResult parseConstShapeOp(OpAsmParser &parser, 427 OperationState &result) { 428 if (parser.parseOptionalAttrDict(result.attributes)) 429 return failure(); 430 // We piggy-back on ArrayAttr parsing, though we don't 
internally store the 431 // shape as an ArrayAttr. 432 // TODO: Implement custom parser and maybe make syntax a bit more concise. 433 Attribute extentsRaw; 434 NamedAttrList dummy; 435 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 436 return failure(); 437 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 438 if (!extentsArray) 439 return failure(); 440 SmallVector<int64_t, 6> ints; 441 for (Attribute extent : extentsArray) { 442 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 443 if (!attr) 444 return failure(); 445 ints.push_back(attr.getInt()); 446 } 447 Builder &builder = parser.getBuilder(); 448 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 449 Type resultTy; 450 if (parser.parseColonType(resultTy)) 451 return failure(); 452 result.types.push_back(resultTy); 453 return success(); 454 } 455 456 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 457 458 void ConstShapeOp::getCanonicalizationPatterns( 459 OwningRewritePatternList &patterns, MLIRContext *context) { 460 patterns.insert<TensorCastConstShape>(context); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // CstrBroadcastableOp 465 //===----------------------------------------------------------------------===// 466 467 namespace { 468 // Given an input shape Value, try to obtain the shape's values. 
469 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 470 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 471 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 472 if (!type.hasRank()) 473 return failure(); 474 shapeValues = llvm::to_vector<6>(type.getShape()); 475 return success(); 476 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 477 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 478 return success(); 479 } else { 480 return failure(); 481 } 482 } 483 } // namespace 484 485 void CstrBroadcastableOp::getCanonicalizationPatterns( 486 OwningRewritePatternList &patterns, MLIRContext *context) { 487 // Canonicalization patterns have overlap with the considerations during 488 // folding in case additional shape information is inferred at some point that 489 // does not result in folding. 490 patterns.insert<CstrBroadcastableEqOps>(context); 491 } 492 493 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 494 // TODO: Add folding for the nary case 495 if (operands.size() != 2) 496 return nullptr; 497 498 // Both operands are not needed if one is a scalar. 499 if (operands[0] && 500 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 501 return BoolAttr::get(getContext(), true); 502 if (operands[1] && 503 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 504 return BoolAttr::get(getContext(), true); 505 506 if (operands[0] && operands[1]) { 507 auto lhsShape = llvm::to_vector<6>( 508 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 509 auto rhsShape = llvm::to_vector<6>( 510 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 511 SmallVector<int64_t, 6> resultShape; 512 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 513 return BoolAttr::get(getContext(), true); 514 } 515 516 // Lastly, see if folding can be completed based on what constraints are known 517 // on the input shapes. 
518 SmallVector<int64_t, 6> lhsShape, rhsShape; 519 if (failed(getShapeVec(shapes()[0], lhsShape))) 520 return nullptr; 521 if (failed(getShapeVec(shapes()[1], rhsShape))) 522 return nullptr; 523 524 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 525 return BoolAttr::get(getContext(), true); 526 527 // Because a failing witness result here represents an eventual assertion 528 // failure, we do not replace it with a constant witness. 529 return nullptr; 530 } 531 532 static LogicalResult verify(CstrBroadcastableOp op) { 533 // Ensure that AssumingAllOp contains at least one operand 534 if (op.getNumOperands() < 2) 535 return op.emitOpError("required at least 2 input shapes"); 536 return success(); 537 } 538 539 //===----------------------------------------------------------------------===// 540 // CstrEqOp 541 //===----------------------------------------------------------------------===// 542 543 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 544 MLIRContext *context) { 545 // If inputs are equal, return passing witness 546 patterns.insert<CstrEqEqOps>(context); 547 } 548 549 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 550 if (llvm::all_of(operands, 551 [&](Attribute a) { return a && a == operands[0]; })) 552 return BoolAttr::get(getContext(), true); 553 554 // Because a failing witness result here represents an eventual assertion 555 // failure, we do not try to replace it with a constant witness. Similarly, we 556 // cannot if there are any non-const inputs. 
557 return nullptr; 558 } 559 560 //===----------------------------------------------------------------------===// 561 // ConstSizeOp 562 //===----------------------------------------------------------------------===// 563 564 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 565 int64_t value) { 566 build(builder, result, builder.getIndexAttr(value)); 567 } 568 569 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 570 571 void ConstSizeOp::getAsmResultNames( 572 llvm::function_ref<void(Value, StringRef)> setNameFn) { 573 SmallString<4> buffer; 574 llvm::raw_svector_ostream os(buffer); 575 os << "c" << value(); 576 setNameFn(getResult(), os.str()); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // ConstWitnessOp 581 //===----------------------------------------------------------------------===// 582 583 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 584 585 //===----------------------------------------------------------------------===// 586 // CstrRequireOp 587 //===----------------------------------------------------------------------===// 588 589 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 590 return operands[0]; 591 } 592 593 //===----------------------------------------------------------------------===// 594 // ShapeEqOp 595 //===----------------------------------------------------------------------===// 596 597 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 598 if (lhs() == rhs()) 599 return BoolAttr::get(getContext(), true); 600 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 601 if (lhs == nullptr) 602 return {}; 603 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 604 if (rhs == nullptr) 605 return {}; 606 return BoolAttr::get(getContext(), lhs == rhs); 607 } 608 609 //===----------------------------------------------------------------------===// 610 // 
IndexToSizeOp 611 //===----------------------------------------------------------------------===// 612 613 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 614 // Constant values of both types, `shape.size` and `index`, are represented as 615 // `IntegerAttr`s which makes constant folding simple. 616 if (Attribute arg = operands[0]) 617 return arg; 618 return {}; 619 } 620 621 void IndexToSizeOp::getCanonicalizationPatterns( 622 OwningRewritePatternList &patterns, MLIRContext *context) { 623 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 624 } 625 626 //===----------------------------------------------------------------------===// 627 // FromExtentsOp 628 //===----------------------------------------------------------------------===// 629 630 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 631 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 632 return nullptr; 633 SmallVector<int64_t, 6> extents; 634 for (auto attr : operands) 635 extents.push_back(attr.cast<IntegerAttr>().getInt()); 636 Builder builder(getContext()); 637 return builder.getIndexTensorAttr(extents); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // FunctionLibraryOp 642 //===----------------------------------------------------------------------===// 643 644 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 645 StringRef name) { 646 ensureTerminator(*result.addRegion(), builder, result.location); 647 result.attributes.push_back(builder.getNamedAttr( 648 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 649 } 650 651 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 652 auto attr = mapping() 653 .get(op->getName().getIdentifier()) 654 .dyn_cast_or_null<FlatSymbolRefAttr>(); 655 if (!attr) 656 return nullptr; 657 return lookupSymbol<FuncOp>(attr); 658 } 659 660 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 661 OperationState 
&result) { 662 // Parse the op name. 663 StringAttr nameAttr; 664 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 665 result.attributes)) 666 return failure(); 667 668 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 669 return failure(); 670 671 auto *bodyRegion = result.addRegion(); 672 if (parser.parseRegion(*bodyRegion)) 673 return failure(); 674 675 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 676 result.location); 677 if (parser.parseKeyword("mapping")) 678 return failure(); 679 680 DictionaryAttr mappingAttr; 681 if (parser.parseAttribute(mappingAttr, 682 parser.getBuilder().getType<NoneType>(), "mapping", 683 result.attributes)) 684 return failure(); 685 return success(); 686 } 687 688 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 689 p << op.getOperationName() << ' '; 690 p.printSymbolName(op.getName()); 691 p.printOptionalAttrDictWithKeyword( 692 op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 693 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 694 /*printBlockTerminators=*/false); 695 p << " mapping "; 696 p.printAttributeWithoutType(op.mappingAttr()); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // GetExtentOp 701 //===----------------------------------------------------------------------===// 702 703 Optional<int64_t> GetExtentOp::getConstantDim() { 704 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 705 return constSizeOp.value().getLimitedValue(); 706 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 707 return constantOp.value().cast<IntegerAttr>().getInt(); 708 return llvm::None; 709 } 710 711 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 712 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 713 if (!elements) 714 return nullptr; 715 Optional<int64_t> dim = getConstantDim(); 716 if (!dim.hasValue()) 717 return nullptr; 
718 if (dim.getValue() >= elements.getNumElements()) 719 return nullptr; 720 return elements.getValue({(uint64_t)dim.getValue()}); 721 } 722 723 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 724 int64_t dim) { 725 auto loc = result.location; 726 auto dimAttr = builder.getIndexAttr(dim); 727 if (shape.getType().isa<ShapeType>()) { 728 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 729 build(builder, result, builder.getType<SizeType>(), shape, dim); 730 } else { 731 Value dim = 732 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 733 build(builder, result, builder.getIndexType(), shape, dim); 734 } 735 } 736 737 //===----------------------------------------------------------------------===// 738 // IsBroadcastableOp 739 //===----------------------------------------------------------------------===// 740 741 static LogicalResult verify(IsBroadcastableOp op) { 742 // Ensure that AssumingAllOp contains at least one operand 743 if (op.getNumOperands() < 2) 744 return op.emitOpError("required at least 2 input shapes"); 745 return success(); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // RankOp 750 //===----------------------------------------------------------------------===// 751 752 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 753 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 754 if (!shape) 755 return {}; 756 int64_t rank = shape.getNumElements(); 757 Builder builder(getContext()); 758 return builder.getIndexAttr(rank); 759 } 760 761 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 762 /// Constant folding fails in cases where only the rank is constant, not the 763 /// shape itself. 764 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 
///
/// Example:
///
///   %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
///   %rank = shape.rank %shape
///
/// becomes
///
///   %rank = shape.const_size 3

namespace {
// Replaces the rank of a `shape_of` on a RankedTensorType with a constant.
// The constant is a standard `constant` index for an `index` result and a
// `shape.const_size` for a `!shape.size` result. Any other result type is left
// alone.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // Only ranked tensors have a statically known rank.
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

// Folds to the product of all extents when the shape operand is a constant.
// The product is accumulated in a 64-bit APInt and returned as an index
// attribute. NOTE(review): APInt multiplication wraps on overflow, so a very
// large product is not diagnosed here.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

// Infers the result type from the operand: `index` for a ShapedType operand
// (an extent tensor), `!shape.size` otherwise.
void NumElementsOp::build(OpBuilder &builder, OperationState &result,
                          Value shape) {
  if (shape.getType().isa<ShapedType>()) {
    auto type = builder.getIndexType();
    return build(builder, result, type, shape);
  }
  auto type = SizeType::get(builder.getContext());
  return build(builder, result, type, shape);
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

// Folds the product of two constant integer operands into an `index`-typed
// IntegerAttr.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

// Folds to a constant extent tensor when the operand's type is a ShapedType
// with a fully static shape.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

// Infers the result type: the extent tensor type `tensor<?xindex>` for a
// ShapedType argument, `!shape.shape` otherwise.
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  Type type = arg.getType().isa<ShapedType>()
                  ? (Type)getExtentTensorType(builder.getContext())
                  : (Type)builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, type, arg);
}

namespace {
// Rewrites a `shape_of` whose argument is a ShapedType but whose result is not
// a ShapedType into a new `shape_of` created by the inferring builder above,
// i.e. one producing the extent tensor type. NOTE(review): this replaces the
// result with a value of a different type, so users must accept the extent
// tensor form.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

// Forwards a constant operand unchanged; otherwise defers to
// `impl::foldCastOp` (presumably folds away a cast whose operand already has
// the result type — confirm against its definition).
OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return impl::foldCastOp(*this);
}

void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

// Verifies that the yielded values match the parent op's results in number and,
// pairwise, in type.
static LogicalResult verify(shape::YieldOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "number of operands does not match number of "
                               "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
943 int64_t rank = shape.size(); 944 if (!(-rank <= splitPoint && splitPoint <= rank)) 945 return failure(); 946 if (splitPoint < 0) 947 splitPoint += shape.size(); 948 Builder builder(operands[0].getContext()); 949 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 950 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 951 return success(); 952 } 953 954 //===----------------------------------------------------------------------===// 955 // ToExtentTensorOp 956 //===----------------------------------------------------------------------===// 957 958 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 959 if (!operands[0]) 960 return impl::foldCastOp(*this); 961 Builder builder(getContext()); 962 auto shape = llvm::to_vector<6>( 963 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 964 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 965 builder.getIndexType()); 966 return DenseIntElementsAttr::get(type, shape); 967 } 968 969 //===----------------------------------------------------------------------===// 970 // ReduceOp 971 //===----------------------------------------------------------------------===// 972 973 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 974 ValueRange initVals) { 975 result.addOperands(shape); 976 result.addOperands(initVals); 977 978 Region *bodyRegion = result.addRegion(); 979 bodyRegion->push_back(new Block); 980 Block &bodyBlock = bodyRegion->front(); 981 bodyBlock.addArgument(builder.getIndexType()); 982 983 Type elementType; 984 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 985 elementType = tensorType.getElementType(); 986 else 987 elementType = SizeType::get(builder.getContext()); 988 bodyBlock.addArgument(elementType); 989 990 for (Type initValType : initVals.getTypes()) { 991 bodyBlock.addArgument(initValType); 992 result.addTypes(initValType); 993 } 994 } 995 996 static LogicalResult 
verify(ReduceOp op) { 997 // Verify block arg types. 998 Block &block = op.region().front(); 999 1000 // The block takes index, extent, and aggregated values as arguments. 1001 auto blockArgsCount = op.initVals().size() + 2; 1002 if (block.getNumArguments() != blockArgsCount) 1003 return op.emitOpError() << "ReduceOp body is expected to have " 1004 << blockArgsCount << " arguments"; 1005 1006 // The first block argument is the index and must always be of type `index`. 1007 if (!block.getArgument(0).getType().isa<IndexType>()) 1008 return op.emitOpError( 1009 "argument 0 of ReduceOp body is expected to be of IndexType"); 1010 1011 // The second block argument is the extent and must be of type `size` or 1012 // `index`, depending on whether the reduce operation is applied to a shape or 1013 // to an extent tensor. 1014 Type extentTy = block.getArgument(1).getType(); 1015 if (op.shape().getType().isa<ShapeType>()) { 1016 if (!extentTy.isa<SizeType>()) 1017 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1018 "SizeType if the ReduceOp operates on a ShapeType"); 1019 } else { 1020 if (!extentTy.isa<IndexType>()) 1021 return op.emitOpError( 1022 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1023 "ReduceOp operates on an extent tensor"); 1024 } 1025 1026 for (auto type : llvm::enumerate(op.initVals())) 1027 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1028 return op.emitOpError() 1029 << "type mismatch between argument " << type.index() + 2 1030 << " of ReduceOp body and initial value " << type.index(); 1031 return success(); 1032 } 1033 1034 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1035 // Parse operands. 
1036 SmallVector<OpAsmParser::OperandType, 3> operands; 1037 Type shapeOrExtentTensorType; 1038 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1039 OpAsmParser::Delimiter::Paren) || 1040 parser.parseColonType(shapeOrExtentTensorType) || 1041 parser.parseOptionalArrowTypeList(result.types)) 1042 return failure(); 1043 1044 // Resolve operands. 1045 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1046 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1047 result.operands) || 1048 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1049 result.operands)) 1050 return failure(); 1051 1052 // Parse the body. 1053 Region *body = result.addRegion(); 1054 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1055 return failure(); 1056 1057 // Parse attributes. 1058 if (parser.parseOptionalAttrDict(result.attributes)) 1059 return failure(); 1060 1061 return success(); 1062 } 1063 1064 static void print(OpAsmPrinter &p, ReduceOp op) { 1065 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 1066 << ") : " << op.shape().getType(); 1067 p.printOptionalArrowTypeList(op.getResultTypes()); 1068 p.printRegion(op.region()); 1069 p.printOptionalAttrDict(op.getAttrs()); 1070 } 1071 1072 #define GET_OP_CLASSES 1073 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1074