1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 using namespace mlir::shape; 25 26 namespace { 27 #include "ShapeCanonicalization.inc" 28 } 29 30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 31 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 32 } 33 34 static bool isErrorPropagationPossible(TypeRange operandTypes) { 35 return llvm::any_of(operandTypes, [](Type ty) { 36 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 37 }); 38 } 39 40 static LogicalResult verifySizeOrIndexOp(Operation *op) { 41 assert(op != nullptr && op->getNumResults() == 1); 42 Type resultTy = op->getResultTypes().front(); 43 if (isErrorPropagationPossible(op->getOperandTypes())) { 44 if (!resultTy.isa<SizeType>()) 45 return op->emitOpError() 46 << "if at least one of the operands can hold error values then " 47 "the result must be of type `size` to propagate them"; 48 } 49 return success(); 50 } 51 52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 53 assert(op != nullptr && op->getNumResults() == 1); 54 Type resultTy = op->getResultTypes().front(); 55 if (isErrorPropagationPossible(op->getOperandTypes())) { 56 if (!resultTy.isa<ShapeType>()) 57 return op->emitOpError() 58 << "if at least one of the operands can hold error values then " 59 "the result must be of type `shape` to propagate them"; 60 } 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 /// This class defines the interface for inlining shape dialect ops. 70 struct ShapeInlinerInterface : public DialectInlinerInterface { 71 using DialectInlinerInterface::DialectInlinerInterface; 72 73 // Returns true if the given region 'src' can be inlined into the region 74 // 'dest' that is attached to an operation registered to the current dialect. 75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 76 BlockAndValueMapping &) const final { 77 return true; 78 } 79 80 // Returns true if the given operation 'op', that is registered to this 81 // dialect, can be inlined into the region 'dest' that is attached to an 82 // operation registered to the current dialect. 83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 84 BlockAndValueMapping &) const final { 85 return true; 86 } 87 }; 88 } // namespace 89 90 void ShapeDialect::initialize() { 91 addOperations< 92 #define GET_OP_LIST 93 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 94 >(); 95 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 96 addInterfaces<ShapeInlinerInterface>(); 97 // Allow unknown operations during prototyping and testing. As the dialect is 98 // still evolving it makes it simple to start with an unregistered ops and 99 // try different variants before actually defining the op. 100 allowUnknownOperations(); 101 } 102 103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 104 Attribute value, Type type, 105 Location loc) { 106 if (type.isa<ShapeType>() || 107 type == getExtentTensorType(builder.getContext())) 108 return builder.create<ConstShapeOp>(loc, type, 109 value.cast<DenseIntElementsAttr>()); 110 if (type.isa<SizeType>()) 111 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 112 if (type.isa<WitnessType>()) 113 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 114 if (ConstantOp::isBuildableWith(value, type)) 115 return builder.create<ConstantOp>(loc, type, value); 116 return nullptr; 117 } 118 119 /// Parse a type registered to this dialect. 120 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 121 StringRef keyword; 122 if (parser.parseKeyword(&keyword)) 123 return Type(); 124 125 if (keyword == "shape") 126 return ShapeType::get(getContext()); 127 if (keyword == "size") 128 return SizeType::get(getContext()); 129 if (keyword == "value_shape") 130 return ValueShapeType::get(getContext()); 131 if (keyword == "witness") 132 return WitnessType::get(getContext()); 133 134 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 135 return Type(); 136 } 137 138 /// Print a type registered to this dialect. 139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 140 TypeSwitch<Type>(type) 141 .Case<ShapeType>([&](Type) { os << "shape"; }) 142 .Case<SizeType>([&](Type) { os << "size"; }) 143 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 144 .Case<WitnessType>([&](Type) { os << "witness"; }) 145 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 146 } 147 148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 149 NamedAttribute attribute) { 150 // Verify shape.lib attribute. 151 if (attribute.first == "shape.lib") { 152 if (!op->hasTrait<OpTrait::SymbolTable>()) 153 return op->emitError( 154 "shape.lib attribute may only be on op implementing SymbolTable"); 155 156 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 157 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 158 if (!symbol) 159 return op->emitError("shape function library ") 160 << symbolRef << " not found"; 161 return isa<shape::FunctionLibraryOp>(symbol) 162 ? success() 163 : op->emitError() 164 << symbolRef << " required to be shape function library"; 165 } 166 167 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 168 // Verify all entries are function libraries and mappings in libraries 169 // refer to unique ops. 170 DenseSet<Identifier> key; 171 for (auto it : arr) { 172 if (!it.isa<SymbolRefAttr>()) 173 return op->emitError( 174 "only SymbolRefAttr allowed in shape.lib attribute array"); 175 176 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 177 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 178 if (!shapeFnLib) 179 return op->emitError() 180 << it << " does not refer to FunctionLibraryOp"; 181 for (auto mapping : shapeFnLib.mapping()) { 182 if (!key.insert(mapping.first).second) { 183 return op->emitError("only one op to shape mapping allowed, found " 184 "multiple for `") 185 << mapping.first << "`"; 186 } 187 } 188 } 189 return success(); 190 } 191 192 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 193 "allowed as shape.lib attribute"); 194 } 195 return success(); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // AnyOp 200 //===----------------------------------------------------------------------===// 201 202 // TODO: Canonicalization should be implemented for shapes that can be 203 // determined through mixtures of the known dimensions of the inputs. 204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 205 // Only the last operand is checked because AnyOp is commutative. 206 if (operands.back()) 207 return operands.back(); 208 209 return nullptr; 210 } 211 212 //===----------------------------------------------------------------------===// 213 // AssumingOp 214 //===----------------------------------------------------------------------===// 215 216 static ParseResult parseAssumingOp(OpAsmParser &parser, 217 OperationState &result) { 218 result.regions.reserve(1); 219 Region *doRegion = result.addRegion(); 220 221 auto &builder = parser.getBuilder(); 222 OpAsmParser::OperandType cond; 223 if (parser.parseOperand(cond) || 224 parser.resolveOperand(cond, builder.getType<WitnessType>(), 225 result.operands)) 226 return failure(); 227 228 // Parse optional results type list. 229 if (parser.parseOptionalArrowTypeList(result.types)) 230 return failure(); 231 232 // Parse the region and add a terminator if elided. 233 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 234 return failure(); 235 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 236 237 // Parse the optional attribute list. 238 if (parser.parseOptionalAttrDict(result.attributes)) 239 return failure(); 240 return success(); 241 } 242 243 static void print(OpAsmPrinter &p, AssumingOp op) { 244 bool yieldsResults = !op.results().empty(); 245 246 p << AssumingOp::getOperationName() << " " << op.witness(); 247 if (yieldsResults) { 248 p << " -> (" << op.getResultTypes() << ")"; 249 } 250 p.printRegion(op.doRegion(), 251 /*printEntryBlockArgs=*/false, 252 /*printBlockTerminators=*/yieldsResults); 253 p.printOptionalAttrDict(op->getAttrs()); 254 } 255 256 namespace { 257 // Removes AssumingOp with a passing witness and inlines the region. 258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 259 using OpRewritePattern<AssumingOp>::OpRewritePattern; 260 261 LogicalResult matchAndRewrite(AssumingOp op, 262 PatternRewriter &rewriter) const override { 263 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 264 if (!witness || !witness.passingAttr()) 265 return failure(); 266 267 AssumingOp::inlineRegionIntoParent(op, rewriter); 268 return success(); 269 } 270 }; 271 } // namespace 272 273 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 274 MLIRContext *context) { 275 // If taking a passing witness, inline region. 276 patterns.add<AssumingWithTrue>(context); 277 } 278 279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 280 void AssumingOp::getSuccessorRegions( 281 Optional<unsigned> index, ArrayRef<Attribute> operands, 282 SmallVectorImpl<RegionSuccessor> ®ions) { 283 // AssumingOp has unconditional control flow into the region and back to the 284 // parent, so return the correct RegionSuccessor purely based on the index 285 // being None or 0. 286 if (index.hasValue()) { 287 regions.push_back(RegionSuccessor(getResults())); 288 return; 289 } 290 291 regions.push_back(RegionSuccessor(&doRegion())); 292 } 293 294 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 295 PatternRewriter &rewriter) { 296 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 297 auto *assumingBlock = op.getBody(); 298 auto initPosition = rewriter.getInsertionPoint(); 299 auto *blockAfterAssuming = 300 rewriter.splitBlock(blockBeforeAssuming, initPosition); 301 302 // Remove the AssumingOp and AssumingYieldOp. 303 auto &yieldOp = assumingBlock->back(); 304 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 305 rewriter.replaceOp(op, yieldOp.getOperands()); 306 rewriter.eraseOp(&yieldOp); 307 308 // Merge blocks together as there was no branching behavior from the 309 // AssumingOp. 310 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 311 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 312 } 313 314 void AssumingOp::build( 315 OpBuilder &builder, OperationState &result, Value witness, 316 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 317 318 result.addOperands(witness); 319 Region *bodyRegion = result.addRegion(); 320 bodyRegion->push_back(new Block); 321 Block &bodyBlock = bodyRegion->front(); 322 323 // Build body. 324 OpBuilder::InsertionGuard guard(builder); 325 builder.setInsertionPointToStart(&bodyBlock); 326 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 327 builder.create<AssumingYieldOp>(result.location, yieldValues); 328 329 SmallVector<Type, 2> assumingTypes; 330 for (Value v : yieldValues) 331 assumingTypes.push_back(v.getType()); 332 result.addTypes(assumingTypes); 333 } 334 335 //===----------------------------------------------------------------------===// 336 // AssumingAllOp 337 //===----------------------------------------------------------------------===// 338 339 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 340 MLIRContext *context) { 341 patterns.add<AssumingAllOneOp>(context); 342 } 343 344 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 345 // Iterate in reverse to first handle all constant operands. They are 346 // guaranteed to be the tail of the inputs because this is commutative. 347 for (int idx = operands.size() - 1; idx >= 0; idx--) { 348 Attribute a = operands[idx]; 349 // Cannot fold if any inputs are not constant; 350 if (!a) 351 return nullptr; 352 353 // We do not need to keep statically known values after handling them in 354 // this method. 355 getOperation()->eraseOperand(idx); 356 357 // Always false if any input is statically known false 358 if (!a.cast<BoolAttr>().getValue()) 359 return a; 360 } 361 // If this is reached, all inputs were statically known passing. 362 return BoolAttr::get(getContext(), true); 363 } 364 365 static LogicalResult verify(AssumingAllOp op) { 366 // Ensure that AssumingAllOp contains at least one operand 367 if (op.getNumOperands() == 0) 368 return op.emitOpError("no operands specified"); 369 370 return success(); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // BroadcastOp 375 //===----------------------------------------------------------------------===// 376 377 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 378 if (operands.size() == 1) 379 return shapes().front(); 380 381 // TODO: Support folding with more than 2 input shapes 382 if (shapes().size() > 2) 383 return nullptr; 384 385 if (!operands[1]) 386 return nullptr; 387 388 auto rhsShape = llvm::to_vector<6>( 389 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 390 if (rhsShape.empty()) 391 return shapes()[0]; 392 393 if (!operands[0]) 394 return nullptr; 395 396 auto lhsShape = llvm::to_vector<6>( 397 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 398 if (lhsShape.empty()) 399 return shapes()[1]; 400 401 SmallVector<int64_t, 6> resultShape; 402 // If the shapes are not compatible, we can't fold it. 403 // TODO: Fold to an "error". 404 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 405 return nullptr; 406 Builder builder(getContext()); 407 return builder.getIndexTensorAttr(resultShape); 408 } 409 410 static LogicalResult verify(BroadcastOp op) { 411 return verifyShapeOrExtentTensorOp(op); 412 } 413 414 namespace { 415 template <typename OpTy> 416 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 417 using OpRewritePattern<OpTy>::OpRewritePattern; 418 419 LogicalResult matchAndRewrite(OpTy op, 420 PatternRewriter &rewriter) const override { 421 // Find unique operands. 422 SmallVector<Value, 2> unique; 423 for (Value v : op.getOperands()) { 424 if (!llvm::is_contained(unique, v)) 425 unique.push_back(v); 426 } 427 428 // Reduce op to equivalent with unique operands. 429 if (unique.size() < op.getNumOperands()) { 430 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 431 op->getAttrs()); 432 return success(); 433 } 434 435 return failure(); 436 } 437 }; 438 439 struct BroadcastForwardSingleOperandPattern 440 : public OpRewritePattern<BroadcastOp> { 441 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 442 443 LogicalResult matchAndRewrite(BroadcastOp op, 444 PatternRewriter &rewriter) const override { 445 if (op.getNumOperands() == 1) { 446 rewriter.replaceOp(op, op.shapes().front()); 447 return success(); 448 } 449 return failure(); 450 } 451 }; 452 } // namespace 453 454 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 455 MLIRContext *context) { 456 patterns.add<BroadcastForwardSingleOperandPattern, 457 RemoveDuplicateOperandsPattern<BroadcastOp>>(context); 458 } 459 460 //===----------------------------------------------------------------------===// 461 // ConcatOp 462 //===----------------------------------------------------------------------===// 463 464 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 465 if (!operands[0] || !operands[1]) 466 return nullptr; 467 auto lhsShape = llvm::to_vector<6>( 468 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 469 auto rhsShape = llvm::to_vector<6>( 470 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 471 SmallVector<int64_t, 6> resultShape; 472 resultShape.append(lhsShape.begin(), lhsShape.end()); 473 resultShape.append(rhsShape.begin(), rhsShape.end()); 474 Builder builder(getContext()); 475 return builder.getIndexTensorAttr(resultShape); 476 } 477 478 //===----------------------------------------------------------------------===// 479 // ConstShapeOp 480 //===----------------------------------------------------------------------===// 481 482 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 483 p << "shape.const_shape "; 484 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 485 p << "["; 486 interleaveComma(op.shape().getValues<int64_t>(), p, 487 [&](int64_t i) { p << i; }); 488 p << "] : "; 489 p.printType(op.getType()); 490 } 491 492 static ParseResult parseConstShapeOp(OpAsmParser &parser, 493 OperationState &result) { 494 if (parser.parseOptionalAttrDict(result.attributes)) 495 return failure(); 496 // We piggy-back on ArrayAttr parsing, though we don't internally store the 497 // shape as an ArrayAttr. 498 // TODO: Implement custom parser and maybe make syntax a bit more concise. 499 Attribute extentsRaw; 500 NamedAttrList dummy; 501 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 502 return failure(); 503 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 504 if (!extentsArray) 505 return failure(); 506 SmallVector<int64_t, 6> ints; 507 for (Attribute extent : extentsArray) { 508 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 509 if (!attr) 510 return failure(); 511 ints.push_back(attr.getInt()); 512 } 513 Builder &builder = parser.getBuilder(); 514 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 515 Type resultTy; 516 if (parser.parseColonType(resultTy)) 517 return failure(); 518 result.types.push_back(resultTy); 519 return success(); 520 } 521 522 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 523 524 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 525 MLIRContext *context) { 526 patterns.add<TensorCastConstShape>(context); 527 } 528 529 //===----------------------------------------------------------------------===// 530 // CstrBroadcastableOp 531 //===----------------------------------------------------------------------===// 532 533 namespace { 534 // Given an input shape Value, try to obtain the shape's values. 535 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 536 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 537 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 538 if (!type.hasRank()) 539 return failure(); 540 shapeValues = llvm::to_vector<6>(type.getShape()); 541 return success(); 542 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 543 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 544 return success(); 545 } else { 546 return failure(); 547 } 548 } 549 } // namespace 550 551 void CstrBroadcastableOp::getCanonicalizationPatterns( 552 RewritePatternSet &patterns, MLIRContext *context) { 553 // Canonicalization patterns have overlap with the considerations during 554 // folding in case additional shape information is inferred at some point that 555 // does not result in folding. 556 patterns.add<CstrBroadcastableEqOps, 557 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>>(context); 558 } 559 560 // Return true if there is exactly one attribute not representing a scalar 561 // broadcast. 562 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 563 bool nonScalarSeen = false; 564 for (Attribute a : attributes) { 565 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 566 if (nonScalarSeen) 567 return false; 568 nonScalarSeen = true; 569 } 570 } 571 return true; 572 } 573 574 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 575 // No broadcasting is needed if all operands but one are scalar. 576 if (hasAtMostSingleNonScalar(operands)) 577 return BoolAttr::get(getContext(), true); 578 579 if ([&] { 580 SmallVector<SmallVector<int64_t, 6>, 6> extents; 581 for (const auto &operand : operands) { 582 if (!operand) 583 return false; 584 extents.push_back(llvm::to_vector<6>( 585 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 586 } 587 return OpTrait::util::staticallyKnownBroadcastable(extents); 588 }()) 589 return BoolAttr::get(getContext(), true); 590 591 // Lastly, see if folding can be completed based on what constraints are known 592 // on the input shapes. 593 if ([&] { 594 SmallVector<SmallVector<int64_t, 6>, 6> extents; 595 for (auto shapeValue : shapes()) { 596 extents.emplace_back(); 597 if (failed(getShapeVec(shapeValue, extents.back()))) 598 return false; 599 } 600 return OpTrait::util::staticallyKnownBroadcastable(extents); 601 }()) 602 return BoolAttr::get(getContext(), true); 603 604 // Because a failing witness result here represents an eventual assertion 605 // failure, we do not replace it with a constant witness. 606 return nullptr; 607 } 608 609 static LogicalResult verify(CstrBroadcastableOp op) { 610 // Ensure that AssumingAllOp contains at least one operand 611 if (op.getNumOperands() < 2) 612 return op.emitOpError("required at least 2 input shapes"); 613 return success(); 614 } 615 616 //===----------------------------------------------------------------------===// 617 // CstrEqOp 618 //===----------------------------------------------------------------------===// 619 620 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 621 MLIRContext *context) { 622 // If inputs are equal, return passing witness 623 patterns.add<CstrEqEqOps>(context); 624 } 625 626 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 627 if (llvm::all_of(operands, 628 [&](Attribute a) { return a && a == operands[0]; })) 629 return BoolAttr::get(getContext(), true); 630 631 // Because a failing witness result here represents an eventual assertion 632 // failure, we do not try to replace it with a constant witness. Similarly, we 633 // cannot if there are any non-const inputs. 634 return nullptr; 635 } 636 637 //===----------------------------------------------------------------------===// 638 // ConstSizeOp 639 //===----------------------------------------------------------------------===// 640 641 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 642 int64_t value) { 643 build(builder, result, builder.getIndexAttr(value)); 644 } 645 646 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 647 648 void ConstSizeOp::getAsmResultNames( 649 llvm::function_ref<void(Value, StringRef)> setNameFn) { 650 SmallString<4> buffer; 651 llvm::raw_svector_ostream os(buffer); 652 os << "c" << value(); 653 setNameFn(getResult(), os.str()); 654 } 655 656 //===----------------------------------------------------------------------===// 657 // ConstWitnessOp 658 //===----------------------------------------------------------------------===// 659 660 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 661 662 //===----------------------------------------------------------------------===// 663 // CstrRequireOp 664 //===----------------------------------------------------------------------===// 665 666 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 667 return operands[0]; 668 } 669 670 //===----------------------------------------------------------------------===// 671 // DivOp 672 //===----------------------------------------------------------------------===// 673 674 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 675 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 676 if (!lhs) 677 return nullptr; 678 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 679 if (!rhs) 680 return nullptr; 681 682 // Division in APInt does not follow floor(lhs, rhs) when the result is 683 // negative. Rather, APInt rounds toward zero. 684 APInt quotient, remainder; 685 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 686 if (quotient.isNegative() && !remainder.isNullValue()) { 687 quotient -= 1; 688 } 689 690 Type indexTy = IndexType::get(getContext()); 691 return IntegerAttr::get(indexTy, quotient); 692 } 693 694 //===----------------------------------------------------------------------===// 695 // ShapeEqOp 696 //===----------------------------------------------------------------------===// 697 698 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 699 bool allSame = true; 700 if (!operands.empty() && !operands[0]) 701 return {}; 702 for (Attribute operand : operands.drop_front(1)) { 703 if (!operand) 704 return {}; 705 allSame = allSame && operand == operands[0]; 706 } 707 return BoolAttr::get(getContext(), allSame); 708 } 709 710 //===----------------------------------------------------------------------===// 711 // IndexToSizeOp 712 //===----------------------------------------------------------------------===// 713 714 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 715 // Constant values of both types, `shape.size` and `index`, are represented as 716 // `IntegerAttr`s which makes constant folding simple. 717 if (Attribute arg = operands[0]) 718 return arg; 719 return {}; 720 } 721 722 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 723 MLIRContext *context) { 724 patterns.add<SizeToIndexToSizeCanonicalization>(context); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // FromExtentsOp 729 //===----------------------------------------------------------------------===// 730 731 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 732 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 733 return nullptr; 734 SmallVector<int64_t, 6> extents; 735 for (auto attr : operands) 736 extents.push_back(attr.cast<IntegerAttr>().getInt()); 737 Builder builder(getContext()); 738 return builder.getIndexTensorAttr(extents); 739 } 740 741 //===----------------------------------------------------------------------===// 742 // FunctionLibraryOp 743 //===----------------------------------------------------------------------===// 744 745 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 746 StringRef name) { 747 ensureTerminator(*result.addRegion(), builder, result.location); 748 result.attributes.push_back(builder.getNamedAttr( 749 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 750 } 751 752 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 753 auto attr = mapping() 754 .get(op->getName().getIdentifier()) 755 .dyn_cast_or_null<FlatSymbolRefAttr>(); 756 if (!attr) 757 return nullptr; 758 return lookupSymbol<FuncOp>(attr); 759 } 760 761 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 762 OperationState &result) { 763 // Parse the op name. 764 StringAttr nameAttr; 765 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 766 result.attributes)) 767 return failure(); 768 769 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 770 return failure(); 771 772 auto *bodyRegion = result.addRegion(); 773 if (parser.parseRegion(*bodyRegion)) 774 return failure(); 775 776 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 777 result.location); 778 if (parser.parseKeyword("mapping")) 779 return failure(); 780 781 DictionaryAttr mappingAttr; 782 if (parser.parseAttribute(mappingAttr, 783 parser.getBuilder().getType<NoneType>(), "mapping", 784 result.attributes)) 785 return failure(); 786 return success(); 787 } 788 789 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 790 p << op.getOperationName() << ' '; 791 p.printSymbolName(op.getName()); 792 p.printOptionalAttrDictWithKeyword( 793 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 794 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 795 /*printBlockTerminators=*/false); 796 p << " mapping "; 797 p.printAttributeWithoutType(op.mappingAttr()); 798 } 799 800 //===----------------------------------------------------------------------===// 801 // GetExtentOp 802 //===----------------------------------------------------------------------===// 803 804 Optional<int64_t> GetExtentOp::getConstantDim() { 805 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 806 return constSizeOp.value().getLimitedValue(); 807 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 808 return constantOp.value().cast<IntegerAttr>().getInt(); 809 return llvm::None; 810 } 811 812 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 813 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 814 if (!elements) 815 return nullptr; 816 Optional<int64_t> dim = getConstantDim(); 817 if (!dim.hasValue()) 818 return nullptr; 819 if (dim.getValue() >= elements.getNumElements()) 820 return nullptr; 821 return elements.getValue({(uint64_t)dim.getValue()}); 822 } 823 824 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 825 int64_t dim) { 826 auto loc = result.location; 827 auto dimAttr = builder.getIndexAttr(dim); 828 if (shape.getType().isa<ShapeType>()) { 829 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 830 build(builder, result, builder.getType<SizeType>(), shape, dim); 831 } else { 832 Value dim = 833 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 834 build(builder, result, builder.getIndexType(), shape, dim); 835 } 836 } 837 838 //===----------------------------------------------------------------------===// 839 // IsBroadcastableOp 840 //===----------------------------------------------------------------------===// 841 842 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 843 MLIRContext *context) { 844 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 845 } 846 847 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 848 // Can always broadcast fewer than two shapes. 849 if (operands.size() < 2) { 850 return BoolAttr::get(getContext(), true); 851 } 852 853 return nullptr; 854 } 855 856 //===----------------------------------------------------------------------===// 857 // RankOp 858 //===----------------------------------------------------------------------===// 859 860 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 861 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 862 if (!shape) 863 return {}; 864 int64_t rank = shape.getNumElements(); 865 Builder builder(getContext()); 866 return builder.getIndexAttr(rank); 867 } 868 869 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 870 /// Constant folding fails in cases where only the rank is constant, not the 871 /// shape itself. 872 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 873 /// 874 /// Example: 875 /// 876 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 877 /// %rank = shape.rank %shape 878 /// 879 /// becomes 880 /// 881 /// %rank = shape.const_size 3 882 883 namespace { 884 struct RankShapeOfCanonicalizationPattern 885 : public OpRewritePattern<shape::RankOp> { 886 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 887 888 LogicalResult matchAndRewrite(shape::RankOp op, 889 PatternRewriter &rewriter) const override { 890 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 891 if (!shapeOfOp) 892 return failure(); 893 auto rankedTensorType = 894 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 895 if (!rankedTensorType) 896 return failure(); 897 int64_t rank = rankedTensorType.getRank(); 898 if (op.getType().isa<IndexType>()) { 899 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 900 } else if (op.getType().isa<shape::SizeType>()) { 901 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 902 } else { 903 return failure(); 904 } 905 return success(); 906 } 907 }; 908 } // namespace 909 910 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 911 MLIRContext *context) { 912 patterns.add<RankShapeOfCanonicalizationPattern>(context); 913 } 914 915 //===----------------------------------------------------------------------===// 916 // NumElementsOp 917 //===----------------------------------------------------------------------===// 918 919 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 920 921 // Fold only when argument constant. 922 Attribute shape = operands[0]; 923 if (!shape) 924 return {}; 925 926 APInt product(64, 1); 927 for (auto value : shape.cast<DenseIntElementsAttr>()) 928 product *= value; 929 Builder builder(getContext()); 930 return builder.getIndexAttr(product.getLimitedValue()); 931 } 932 933 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 934 Value shape) { 935 if (shape.getType().isa<ShapedType>()) { 936 auto type = builder.getIndexType(); 937 return build(builder, result, type, shape); 938 } 939 auto type = SizeType::get(builder.getContext()); 940 return build(builder, result, type, shape); 941 } 942 943 //===----------------------------------------------------------------------===// 944 // MulOp 945 //===----------------------------------------------------------------------===// 946 947 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 948 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 949 if (!lhs) 950 return nullptr; 951 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 952 if (!rhs) 953 return nullptr; 954 APInt folded = lhs.getValue() * rhs.getValue(); 955 Type indexTy = IndexType::get(getContext()); 956 return IntegerAttr::get(indexTy, folded); 957 } 958 959 //===----------------------------------------------------------------------===// 960 // ShapeOfOp 961 //===----------------------------------------------------------------------===// 962 963 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 964 auto type = getOperand().getType().dyn_cast<ShapedType>(); 965 if (!type || !type.hasStaticShape()) 966 return nullptr; 967 Builder builder(getContext()); 968 return builder.getIndexTensorAttr(type.getShape()); 969 } 970 971 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 972 Type type = arg.getType().isa<ShapedType>() 973 ? (Type)getExtentTensorType(builder.getContext()) 974 : (Type)builder.getType<ShapeType>(); 975 return ShapeOfOp::build(builder, result, type, arg); 976 } 977 978 namespace { 979 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 980 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 981 982 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 983 PatternRewriter &rewriter) const override { 984 if (!op.arg().getType().isa<ShapedType>()) 985 return failure(); 986 if (op.getType().isa<ShapedType>()) 987 return failure(); 988 989 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 990 return success(); 991 } 992 }; 993 } // namespace 994 995 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 996 MLIRContext *context) { 997 patterns.add<ShapeOfWithTensor>(context); 998 } 999 1000 //===----------------------------------------------------------------------===// 1001 // SizeToIndexOp 1002 //===----------------------------------------------------------------------===// 1003 1004 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1005 // Constant values of both types, `shape.size` and `index`, are represented as 1006 // `IntegerAttr`s which makes constant folding simple. 1007 if (Attribute arg = operands[0]) 1008 return arg; 1009 return impl::foldCastOp(*this); 1010 } 1011 1012 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1013 MLIRContext *context) { 1014 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // YieldOp 1019 //===----------------------------------------------------------------------===// 1020 1021 static LogicalResult verify(shape::YieldOp op) { 1022 auto *parentOp = op->getParentOp(); 1023 auto results = parentOp->getResults(); 1024 auto operands = op.getOperands(); 1025 1026 if (parentOp->getNumResults() != op.getNumOperands()) 1027 return op.emitOpError() << "number of operands does not match number of " 1028 "results of its parent"; 1029 for (auto e : llvm::zip(results, operands)) 1030 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1031 return op.emitOpError() 1032 << "types mismatch between yield op and its parent"; 1033 1034 return success(); 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // SplitAtOp 1039 //===----------------------------------------------------------------------===// 1040 1041 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1042 SmallVectorImpl<OpFoldResult> &results) { 1043 if (!operands[0] || !operands[1]) 1044 return failure(); 1045 auto shapeVec = llvm::to_vector<6>( 1046 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1047 auto shape = llvm::makeArrayRef(shapeVec); 1048 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1049 // Verify that the split point is in the correct range. 1050 // TODO: Constant fold to an "error". 1051 int64_t rank = shape.size(); 1052 if (!(-rank <= splitPoint && splitPoint <= rank)) 1053 return failure(); 1054 if (splitPoint < 0) 1055 splitPoint += shape.size(); 1056 Builder builder(operands[0].getContext()); 1057 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1058 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1059 return success(); 1060 } 1061 1062 //===----------------------------------------------------------------------===// 1063 // ToExtentTensorOp 1064 //===----------------------------------------------------------------------===// 1065 1066 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1067 if (!operands[0]) 1068 return impl::foldCastOp(*this); 1069 Builder builder(getContext()); 1070 auto shape = llvm::to_vector<6>( 1071 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1072 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1073 builder.getIndexType()); 1074 return DenseIntElementsAttr::get(type, shape); 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // ReduceOp 1079 //===----------------------------------------------------------------------===// 1080 1081 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1082 ValueRange initVals) { 1083 result.addOperands(shape); 1084 result.addOperands(initVals); 1085 1086 Region *bodyRegion = result.addRegion(); 1087 bodyRegion->push_back(new Block); 1088 Block &bodyBlock = bodyRegion->front(); 1089 bodyBlock.addArgument(builder.getIndexType()); 1090 1091 Type elementType; 1092 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1093 elementType = tensorType.getElementType(); 1094 else 1095 elementType = SizeType::get(builder.getContext()); 1096 bodyBlock.addArgument(elementType); 1097 1098 for (Type initValType : initVals.getTypes()) { 1099 bodyBlock.addArgument(initValType); 1100 result.addTypes(initValType); 1101 } 1102 } 1103 1104 static LogicalResult verify(ReduceOp op) { 1105 // Verify block arg types. 1106 Block &block = op.region().front(); 1107 1108 // The block takes index, extent, and aggregated values as arguments. 1109 auto blockArgsCount = op.initVals().size() + 2; 1110 if (block.getNumArguments() != blockArgsCount) 1111 return op.emitOpError() << "ReduceOp body is expected to have " 1112 << blockArgsCount << " arguments"; 1113 1114 // The first block argument is the index and must always be of type `index`. 1115 if (!block.getArgument(0).getType().isa<IndexType>()) 1116 return op.emitOpError( 1117 "argument 0 of ReduceOp body is expected to be of IndexType"); 1118 1119 // The second block argument is the extent and must be of type `size` or 1120 // `index`, depending on whether the reduce operation is applied to a shape or 1121 // to an extent tensor. 1122 Type extentTy = block.getArgument(1).getType(); 1123 if (op.shape().getType().isa<ShapeType>()) { 1124 if (!extentTy.isa<SizeType>()) 1125 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1126 "SizeType if the ReduceOp operates on a ShapeType"); 1127 } else { 1128 if (!extentTy.isa<IndexType>()) 1129 return op.emitOpError( 1130 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1131 "ReduceOp operates on an extent tensor"); 1132 } 1133 1134 for (auto type : llvm::enumerate(op.initVals())) 1135 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1136 return op.emitOpError() 1137 << "type mismatch between argument " << type.index() + 2 1138 << " of ReduceOp body and initial value " << type.index(); 1139 return success(); 1140 } 1141 1142 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1143 // Parse operands. 1144 SmallVector<OpAsmParser::OperandType, 3> operands; 1145 Type shapeOrExtentTensorType; 1146 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1147 OpAsmParser::Delimiter::Paren) || 1148 parser.parseColonType(shapeOrExtentTensorType) || 1149 parser.parseOptionalArrowTypeList(result.types)) 1150 return failure(); 1151 1152 // Resolve operands. 1153 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1154 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1155 result.operands) || 1156 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1157 result.operands)) 1158 return failure(); 1159 1160 // Parse the body. 1161 Region *body = result.addRegion(); 1162 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1163 return failure(); 1164 1165 // Parse attributes. 1166 if (parser.parseOptionalAttrDict(result.attributes)) 1167 return failure(); 1168 1169 return success(); 1170 } 1171 1172 static void print(OpAsmPrinter &p, ReduceOp op) { 1173 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 1174 << ") : " << op.shape().getType(); 1175 p.printOptionalArrowTypeList(op.getResultTypes()); 1176 p.printRegion(op.region()); 1177 p.printOptionalAttrDict(op->getAttrs()); 1178 } 1179 1180 #define GET_OP_CLASSES 1181 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1182