1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SmallString.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 #include "llvm/Support/raw_ostream.h" 21 22 using namespace mlir; 23 using namespace mlir::shape; 24 25 namespace { 26 #include "ShapeCanonicalization.inc" 27 } 28 29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 30 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 31 } 32 33 static bool isErrorPropagationPossible(TypeRange operandTypes) { 34 return llvm::any_of(operandTypes, [](Type ty) { 35 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 36 }); 37 } 38 39 static LogicalResult verifySizeOrIndexOp(Operation *op) { 40 assert(op != nullptr && op->getNumResults() == 1); 41 Type resultTy = op->getResultTypes().front(); 42 if (isErrorPropagationPossible(op->getOperandTypes())) { 43 if (!resultTy.isa<SizeType>()) 44 return op->emitOpError() 45 << "if at least one of the operands can hold error values then " 46 "the result must be of type `size` to propagate them"; 47 } 48 return success(); 49 } 50 51 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 52 assert(op != nullptr && op->getNumResults() == 1); 53 Type resultTy = op->getResultTypes().front(); 54 if (isErrorPropagationPossible(op->getOperandTypes())) { 55 if (!resultTy.isa<ShapeType>()) 56 return op->emitOpError() 57 << "if at least one of the operands can hold error values then " 58 "the result must be of type `shape` to propagate them"; 59 } 60 return success(); 61 } 62 63 //===----------------------------------------------------------------------===// 64 // InlinerInterface 65 //===----------------------------------------------------------------------===// 66 67 namespace { 68 /// This class defines the interface for inlining shape dialect ops. 69 struct ShapeInlinerInterface : public DialectInlinerInterface { 70 using DialectInlinerInterface::DialectInlinerInterface; 71 72 // Returns true if the given region 'src' can be inlined into the region 73 // 'dest' that is attached to an operation registered to the current dialect. 74 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 75 BlockAndValueMapping &) const final { 76 return true; 77 } 78 79 // Returns true if the given operation 'op', that is registered to this 80 // dialect, can be inlined into the region 'dest' that is attached to an 81 // operation registered to the current dialect. 82 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 83 BlockAndValueMapping &) const final { 84 return true; 85 } 86 }; 87 } // namespace 88 89 void ShapeDialect::initialize() { 90 addOperations< 91 #define GET_OP_LIST 92 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 93 >(); 94 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 95 addInterfaces<ShapeInlinerInterface>(); 96 // Allow unknown operations during prototyping and testing. As the dialect is 97 // still evolving it makes it simple to start with an unregistered ops and 98 // try different variants before actually defining the op. 99 allowUnknownOperations(); 100 } 101 102 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 103 Attribute value, Type type, 104 Location loc) { 105 if (type.isa<ShapeType>() || 106 type == getExtentTensorType(builder.getContext())) 107 return builder.create<ConstShapeOp>(loc, type, 108 value.cast<DenseIntElementsAttr>()); 109 if (type.isa<SizeType>()) 110 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 111 if (type.isa<WitnessType>()) 112 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 113 if (ConstantOp::isBuildableWith(value, type)) 114 return builder.create<ConstantOp>(loc, type, value); 115 return nullptr; 116 } 117 118 /// Parse a type registered to this dialect. 119 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 120 StringRef keyword; 121 if (parser.parseKeyword(&keyword)) 122 return Type(); 123 124 if (keyword == "shape") 125 return ShapeType::get(getContext()); 126 if (keyword == "size") 127 return SizeType::get(getContext()); 128 if (keyword == "value_shape") 129 return ValueShapeType::get(getContext()); 130 if (keyword == "witness") 131 return WitnessType::get(getContext()); 132 133 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 134 return Type(); 135 } 136 137 /// Print a type registered to this dialect. 138 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 139 TypeSwitch<Type>(type) 140 .Case<ShapeType>([&](Type) { os << "shape"; }) 141 .Case<SizeType>([&](Type) { os << "size"; }) 142 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 143 .Case<WitnessType>([&](Type) { os << "witness"; }) 144 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // AnyOp 149 //===----------------------------------------------------------------------===// 150 151 // TODO: Canonicalization should be implemented for shapes that can be 152 // determined through mixtures of the known dimensions of the inputs. 153 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 154 // Only the last operand is checked because AnyOp is commutative. 155 if (operands.back()) 156 return operands.back(); 157 158 return nullptr; 159 } 160 161 //===----------------------------------------------------------------------===// 162 // AssumingOp 163 //===----------------------------------------------------------------------===// 164 165 static ParseResult parseAssumingOp(OpAsmParser &parser, 166 OperationState &result) { 167 result.regions.reserve(1); 168 Region *doRegion = result.addRegion(); 169 170 auto &builder = parser.getBuilder(); 171 OpAsmParser::OperandType cond; 172 if (parser.parseOperand(cond) || 173 parser.resolveOperand(cond, builder.getType<WitnessType>(), 174 result.operands)) 175 return failure(); 176 177 // Parse optional results type list. 178 if (parser.parseOptionalArrowTypeList(result.types)) 179 return failure(); 180 181 // Parse the region and add a terminator if elided. 182 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 183 return failure(); 184 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 185 186 // Parse the optional attribute list. 187 if (parser.parseOptionalAttrDict(result.attributes)) 188 return failure(); 189 return success(); 190 } 191 192 static void print(OpAsmPrinter &p, AssumingOp op) { 193 bool yieldsResults = !op.results().empty(); 194 195 p << AssumingOp::getOperationName() << " " << op.witness(); 196 if (yieldsResults) { 197 p << " -> (" << op.getResultTypes() << ")"; 198 } 199 p.printRegion(op.doRegion(), 200 /*printEntryBlockArgs=*/false, 201 /*printBlockTerminators=*/yieldsResults); 202 p.printOptionalAttrDict(op.getAttrs()); 203 } 204 205 namespace { 206 // Removes AssumingOp with a passing witness and inlines the region. 207 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 208 using OpRewritePattern<AssumingOp>::OpRewritePattern; 209 210 LogicalResult matchAndRewrite(AssumingOp op, 211 PatternRewriter &rewriter) const override { 212 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 213 if (!witness || !witness.passingAttr()) 214 return failure(); 215 216 AssumingOp::inlineRegionIntoParent(op, rewriter); 217 return success(); 218 } 219 }; 220 } // namespace 221 222 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 223 MLIRContext *context) { 224 // If taking a passing witness, inline region. 225 patterns.insert<AssumingWithTrue>(context); 226 } 227 228 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 229 void AssumingOp::getSuccessorRegions( 230 Optional<unsigned> index, ArrayRef<Attribute> operands, 231 SmallVectorImpl<RegionSuccessor> ®ions) { 232 // AssumingOp has unconditional control flow into the region and back to the 233 // parent, so return the correct RegionSuccessor purely based on the index 234 // being None or 0. 235 if (index.hasValue()) { 236 regions.push_back(RegionSuccessor(getResults())); 237 return; 238 } 239 240 regions.push_back(RegionSuccessor(&doRegion())); 241 } 242 243 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 244 PatternRewriter &rewriter) { 245 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 246 auto *assumingBlock = op.getBody(); 247 auto initPosition = rewriter.getInsertionPoint(); 248 auto *blockAfterAssuming = 249 rewriter.splitBlock(blockBeforeAssuming, initPosition); 250 251 // Remove the AssumingOp and AssumingYieldOp. 252 auto &yieldOp = assumingBlock->back(); 253 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 254 rewriter.replaceOp(op, yieldOp.getOperands()); 255 rewriter.eraseOp(&yieldOp); 256 257 // Merge blocks together as there was no branching behavior from the 258 // AssumingOp. 259 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 260 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // AssumingAllOp 265 //===----------------------------------------------------------------------===// 266 267 void AssumingAllOp::getCanonicalizationPatterns( 268 OwningRewritePatternList &patterns, MLIRContext *context) { 269 patterns.insert<AssumingAllOneOp>(context); 270 } 271 272 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 273 // Iterate in reverse to first handle all constant operands. They are 274 // guaranteed to be the tail of the inputs because this is commutative. 275 for (int idx = operands.size() - 1; idx >= 0; idx--) { 276 Attribute a = operands[idx]; 277 // Cannot fold if any inputs are not constant; 278 if (!a) 279 return nullptr; 280 281 // We do not need to keep statically known values after handling them in 282 // this method. 283 getOperation()->eraseOperand(idx); 284 285 // Always false if any input is statically known false 286 if (!a.cast<BoolAttr>().getValue()) 287 return a; 288 } 289 // If this is reached, all inputs were statically known passing. 290 return BoolAttr::get(true, getContext()); 291 } 292 293 static LogicalResult verify(AssumingAllOp op) { 294 // Ensure that AssumingAllOp contains at least one operand 295 if (op.getNumOperands() == 0) 296 return op.emitOpError("no operands specified"); 297 298 return success(); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // BroadcastOp 303 //===----------------------------------------------------------------------===// 304 305 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 306 if (!operands[1]) 307 return nullptr; 308 309 auto rhsShape = llvm::to_vector<6>( 310 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 311 if (rhsShape.empty()) 312 return lhs(); 313 314 if (!operands[0]) 315 return nullptr; 316 317 auto lhsShape = llvm::to_vector<6>( 318 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 319 if (lhsShape.empty()) 320 return rhs(); 321 322 SmallVector<int64_t, 6> resultShape; 323 // If the shapes are not compatible, we can't fold it. 324 // TODO: Fold to an "error". 325 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 326 return nullptr; 327 Builder builder(getContext()); 328 return builder.getIndexTensorAttr(resultShape); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // ConcatOp 333 //===----------------------------------------------------------------------===// 334 335 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 336 if (!operands[0] || !operands[1]) 337 return nullptr; 338 auto lhsShape = llvm::to_vector<6>( 339 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 340 auto rhsShape = llvm::to_vector<6>( 341 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 342 SmallVector<int64_t, 6> resultShape; 343 resultShape.append(lhsShape.begin(), lhsShape.end()); 344 resultShape.append(rhsShape.begin(), rhsShape.end()); 345 Builder builder(getContext()); 346 return builder.getIndexTensorAttr(resultShape); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // ConstShapeOp 351 //===----------------------------------------------------------------------===// 352 353 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 354 p << "shape.const_shape "; 355 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 356 p << "["; 357 interleaveComma(op.shape().getValues<int64_t>(), p, 358 [&](int64_t i) { p << i; }); 359 p << "] : "; 360 p.printType(op.getType()); 361 } 362 363 static ParseResult parseConstShapeOp(OpAsmParser &parser, 364 OperationState &result) { 365 if (parser.parseOptionalAttrDict(result.attributes)) 366 return failure(); 367 // We piggy-back on ArrayAttr parsing, though we don't internally store the 368 // shape as an ArrayAttr. 369 // TODO: Implement custom parser and maybe make syntax a bit more concise. 370 Attribute extentsRaw; 371 NamedAttrList dummy; 372 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 373 return failure(); 374 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 375 if (!extentsArray) 376 return failure(); 377 SmallVector<int64_t, 6> ints; 378 for (Attribute extent : extentsArray) { 379 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 380 if (!attr) 381 return failure(); 382 ints.push_back(attr.getInt()); 383 } 384 Builder &builder = parser.getBuilder(); 385 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 386 Type resultTy; 387 if (parser.parseColonType(resultTy)) 388 return failure(); 389 result.types.push_back(resultTy); 390 return success(); 391 } 392 393 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 394 395 void ConstShapeOp::getCanonicalizationPatterns( 396 OwningRewritePatternList &patterns, MLIRContext *context) { 397 patterns.insert<TensorCastConstShape>(context); 398 } 399 400 //===----------------------------------------------------------------------===// 401 // CstrBroadcastableOp 402 //===----------------------------------------------------------------------===// 403 404 namespace { 405 // Given an input shape Value, try to obtain the shape's values. 406 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 407 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 408 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 409 if (!type.hasRank()) 410 return failure(); 411 shapeValues = llvm::to_vector<6>(type.getShape()); 412 return success(); 413 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 414 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 415 return success(); 416 } else { 417 return failure(); 418 } 419 } 420 } // namespace 421 422 void CstrBroadcastableOp::getCanonicalizationPatterns( 423 OwningRewritePatternList &patterns, MLIRContext *context) { 424 // Canonicalization patterns have overlap with the considerations during 425 // folding in case additional shape information is inferred at some point that 426 // does not result in folding. 427 patterns.insert<CstrBroadcastableEqOps>(context); 428 } 429 430 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 431 // Both operands are not needed if one is a scalar. 432 if (operands[0] && 433 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 434 return BoolAttr::get(true, getContext()); 435 if (operands[1] && 436 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 437 return BoolAttr::get(true, getContext()); 438 439 if (operands[0] && operands[1]) { 440 auto lhsShape = llvm::to_vector<6>( 441 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 442 auto rhsShape = llvm::to_vector<6>( 443 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 444 SmallVector<int64_t, 6> resultShape; 445 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 446 return BoolAttr::get(true, getContext()); 447 } 448 449 // Lastly, see if folding can be completed based on what constraints are known 450 // on the input shapes. 451 SmallVector<int64_t, 6> lhsShape, rhsShape; 452 if (failed(getShapeVec(lhs(), lhsShape))) 453 return nullptr; 454 if (failed(getShapeVec(rhs(), rhsShape))) 455 return nullptr; 456 457 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 458 return BoolAttr::get(true, getContext()); 459 460 // Because a failing witness result here represents an eventual assertion 461 // failure, we do not replace it with a constant witness. 462 return nullptr; 463 } 464 465 //===----------------------------------------------------------------------===// 466 // CstrEqOp 467 //===----------------------------------------------------------------------===// 468 469 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 470 MLIRContext *context) { 471 // If inputs are equal, return passing witness 472 patterns.insert<CstrEqEqOps>(context); 473 } 474 475 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 476 if (llvm::all_of(operands, 477 [&](Attribute a) { return a && a == operands[0]; })) 478 return BoolAttr::get(true, getContext()); 479 480 // Because a failing witness result here represents an eventual assertion 481 // failure, we do not try to replace it with a constant witness. Similarly, we 482 // cannot if there are any non-const inputs. 483 return nullptr; 484 } 485 486 //===----------------------------------------------------------------------===// 487 // ConstSizeOp 488 //===----------------------------------------------------------------------===// 489 490 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 491 int64_t value) { 492 build(builder, result, builder.getIndexAttr(value)); 493 } 494 495 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 496 497 void ConstSizeOp::getAsmResultNames( 498 llvm::function_ref<void(Value, StringRef)> setNameFn) { 499 SmallString<4> buffer; 500 llvm::raw_svector_ostream os(buffer); 501 os << "c" << value(); 502 setNameFn(getResult(), os.str()); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // ConstWitnessOp 507 //===----------------------------------------------------------------------===// 508 509 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 510 511 //===----------------------------------------------------------------------===// 512 // CstrRequireOp 513 //===----------------------------------------------------------------------===// 514 515 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 516 return operands[0]; 517 } 518 519 //===----------------------------------------------------------------------===// 520 // ShapeEqOp 521 //===----------------------------------------------------------------------===// 522 523 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 524 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 525 if (lhs == nullptr) 526 return {}; 527 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 528 if (rhs == nullptr) 529 return {}; 530 return BoolAttr::get(lhs == rhs, getContext()); 531 } 532 533 //===----------------------------------------------------------------------===// 534 // IndexToSizeOp 535 //===----------------------------------------------------------------------===// 536 537 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 538 // Constant values of both types, `shape.size` and `index`, are represented as 539 // `IntegerAttr`s which makes constant folding simple. 540 if (Attribute arg = operands[0]) 541 return arg; 542 return {}; 543 } 544 545 void IndexToSizeOp::getCanonicalizationPatterns( 546 OwningRewritePatternList &patterns, MLIRContext *context) { 547 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // FromExtentsOp 552 //===----------------------------------------------------------------------===// 553 554 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 555 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 556 return nullptr; 557 SmallVector<int64_t, 6> extents; 558 for (auto attr : operands) 559 extents.push_back(attr.cast<IntegerAttr>().getInt()); 560 Builder builder(getContext()); 561 return builder.getIndexTensorAttr(extents); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // FunctionLibraryOp 566 //===----------------------------------------------------------------------===// 567 568 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 569 StringRef name) { 570 ensureTerminator(*result.addRegion(), builder, result.location); 571 result.attributes.push_back(builder.getNamedAttr( 572 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 573 } 574 575 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 576 auto attr = mapping() 577 .get(op->getName().getIdentifier()) 578 .dyn_cast_or_null<FlatSymbolRefAttr>(); 579 if (!attr) 580 return nullptr; 581 return lookupSymbol<FuncOp>(attr); 582 } 583 584 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 585 OperationState &result) { 586 // Parse the op name. 587 StringAttr nameAttr; 588 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 589 result.attributes)) 590 return failure(); 591 592 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 593 return failure(); 594 595 auto *bodyRegion = result.addRegion(); 596 if (parser.parseRegion(*bodyRegion)) 597 return failure(); 598 599 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 600 result.location); 601 if (parser.parseKeyword("mapping")) 602 return failure(); 603 604 DictionaryAttr mappingAttr; 605 if (parser.parseAttribute(mappingAttr, 606 parser.getBuilder().getType<NoneType>(), "mapping", 607 result.attributes)) 608 return failure(); 609 return success(); 610 } 611 612 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 613 p << op.getOperationName() << ' '; 614 p.printSymbolName(op.getName()); 615 p.printOptionalAttrDictWithKeyword( 616 op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 617 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 618 /*printBlockTerminators=*/false); 619 p << " mapping "; 620 p.printAttributeWithoutType(op.mappingAttr()); 621 } 622 623 //===----------------------------------------------------------------------===// 624 // GetExtentOp 625 //===----------------------------------------------------------------------===// 626 627 Optional<int64_t> GetExtentOp::getConstantDim() { 628 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 629 return constSizeOp.value().getLimitedValue(); 630 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 631 return constantOp.value().cast<IntegerAttr>().getInt(); 632 return llvm::None; 633 } 634 635 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 636 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 637 if (!elements) 638 return nullptr; 639 Optional<int64_t> dim = getConstantDim(); 640 if (!dim.hasValue()) 641 return nullptr; 642 if (dim.getValue() >= elements.getNumElements()) 643 return nullptr; 644 return elements.getValue({(uint64_t)dim.getValue()}); 645 } 646 647 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 648 int64_t dim) { 649 auto loc = result.location; 650 auto dimAttr = builder.getIndexAttr(dim); 651 if (shape.getType().isa<ShapeType>()) { 652 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 653 build(builder, result, builder.getType<SizeType>(), shape, dim); 654 } else { 655 Value dim = 656 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 657 build(builder, result, builder.getIndexType(), shape, dim); 658 } 659 } 660 661 //===----------------------------------------------------------------------===// 662 // RankOp 663 //===----------------------------------------------------------------------===// 664 665 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 666 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 667 if (!shape) 668 return {}; 669 int64_t rank = shape.getNumElements(); 670 Builder builder(getContext()); 671 return builder.getIndexAttr(rank); 672 } 673 674 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 675 /// Constant folding fails in cases where only the rank is constant, not the 676 /// shape itself. 677 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 678 /// 679 /// Example: 680 /// 681 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 682 /// %rank = shape.rank %shape 683 /// 684 /// becomes 685 /// 686 /// %rank = shape.const_size 3 687 688 namespace { 689 struct RankShapeOfCanonicalizationPattern 690 : public OpRewritePattern<shape::RankOp> { 691 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 692 693 LogicalResult matchAndRewrite(shape::RankOp op, 694 PatternRewriter &rewriter) const override { 695 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 696 if (!shapeOfOp) 697 return failure(); 698 auto rankedTensorType = 699 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 700 if (!rankedTensorType) 701 return failure(); 702 int64_t rank = rankedTensorType.getRank(); 703 if (op.getType().isa<IndexType>()) { 704 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 705 } else if (op.getType().isa<shape::SizeType>()) { 706 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 707 } else { 708 return failure(); 709 } 710 return success(); 711 } 712 }; 713 } // namespace 714 715 void shape::RankOp::getCanonicalizationPatterns( 716 OwningRewritePatternList &patterns, MLIRContext *context) { 717 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 718 } 719 720 //===----------------------------------------------------------------------===// 721 // NumElementsOp 722 //===----------------------------------------------------------------------===// 723 724 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 725 726 // Fold only when argument constant. 727 Attribute shape = operands[0]; 728 if (!shape) 729 return {}; 730 731 APInt product(64, 1); 732 for (auto value : shape.cast<DenseIntElementsAttr>()) 733 product *= value; 734 Builder builder(getContext()); 735 return builder.getIndexAttr(product.getLimitedValue()); 736 } 737 738 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 739 Value shape) { 740 if (shape.getType().isa<ShapedType>()) { 741 auto type = builder.getIndexType(); 742 return build(builder, result, type, shape); 743 } 744 auto type = SizeType::get(builder.getContext()); 745 return build(builder, result, type, shape); 746 } 747 748 //===----------------------------------------------------------------------===// 749 // MulOp 750 //===----------------------------------------------------------------------===// 751 752 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 753 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 754 if (!lhs) 755 return nullptr; 756 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 757 if (!rhs) 758 return nullptr; 759 APInt folded = lhs.getValue() * rhs.getValue(); 760 Type indexTy = IndexType::get(getContext()); 761 return IntegerAttr::get(indexTy, folded); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // ShapeOfOp 766 //===----------------------------------------------------------------------===// 767 768 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 769 auto type = getOperand().getType().dyn_cast<ShapedType>(); 770 if (!type || !type.hasStaticShape()) 771 return nullptr; 772 Builder builder(getContext()); 773 return builder.getIndexTensorAttr(type.getShape()); 774 } 775 776 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 777 Type type = arg.getType().isa<ShapedType>() 778 ? (Type)getExtentTensorType(builder.getContext()) 779 : (Type)builder.getType<ShapeType>(); 780 return ShapeOfOp::build(builder, result, type, arg); 781 } 782 783 namespace { 784 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 785 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 786 787 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 788 PatternRewriter &rewriter) const override { 789 if (!op.arg().getType().isa<ShapedType>()) 790 return failure(); 791 if (op.getType().isa<ShapedType>()) 792 return failure(); 793 794 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 795 return success(); 796 } 797 }; 798 } // namespace 799 800 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 801 MLIRContext *context) { 802 patterns.insert<ShapeOfWithTensor>(context); 803 } 804 805 //===----------------------------------------------------------------------===// 806 // SizeToIndexOp 807 //===----------------------------------------------------------------------===// 808 809 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 810 // Constant values of both types, `shape.size` and `index`, are represented as 811 // `IntegerAttr`s which makes constant folding simple. 812 if (Attribute arg = operands[0]) 813 return arg; 814 return impl::foldCastOp(*this); 815 } 816 817 void SizeToIndexOp::getCanonicalizationPatterns( 818 OwningRewritePatternList &patterns, MLIRContext *context) { 819 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 820 } 821 822 //===----------------------------------------------------------------------===// 823 // YieldOp 824 //===----------------------------------------------------------------------===// 825 826 static LogicalResult verify(shape::YieldOp op) { 827 auto *parentOp = op->getParentOp(); 828 auto results = parentOp->getResults(); 829 auto operands = op.getOperands(); 830 831 if (parentOp->getNumResults() != op.getNumOperands()) 832 return op.emitOpError() << "number of operands does not match number of " 833 "results of its parent"; 834 for (auto e : llvm::zip(results, operands)) 835 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 836 return op.emitOpError() 837 << "types mismatch between yield op and its parent"; 838 839 return success(); 840 } 841 842 //===----------------------------------------------------------------------===// 843 // SplitAtOp 844 //===----------------------------------------------------------------------===// 845 846 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 847 SmallVectorImpl<OpFoldResult> &results) { 848 if (!operands[0] || !operands[1]) 849 return failure(); 850 auto shapeVec = llvm::to_vector<6>( 851 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 852 auto shape = llvm::makeArrayRef(shapeVec); 853 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 854 // Verify that the split point is in the correct range. 855 // TODO: Constant fold to an "error". 856 int64_t rank = shape.size(); 857 if (!(-rank <= splitPoint && splitPoint <= rank)) 858 return failure(); 859 if (splitPoint < 0) 860 splitPoint += shape.size(); 861 Builder builder(operands[0].getContext()); 862 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 863 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 864 return success(); 865 } 866 867 //===----------------------------------------------------------------------===// 868 // ToExtentTensorOp 869 //===----------------------------------------------------------------------===// 870 871 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 872 if (!operands[0]) 873 return impl::foldCastOp(*this); 874 Builder builder(getContext()); 875 auto shape = llvm::to_vector<6>( 876 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 877 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 878 builder.getIndexType()); 879 return DenseIntElementsAttr::get(type, shape); 880 } 881 882 //===----------------------------------------------------------------------===// 883 // ReduceOp 884 //===----------------------------------------------------------------------===// 885 886 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 887 ValueRange initVals) { 888 result.addOperands(shape); 889 result.addOperands(initVals); 890 891 Region *bodyRegion = result.addRegion(); 892 bodyRegion->push_back(new Block); 893 Block &bodyBlock = bodyRegion->front(); 894 bodyBlock.addArgument(builder.getIndexType()); 895 896 Type elementType; 897 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 898 elementType = tensorType.getElementType(); 899 else 900 elementType = SizeType::get(builder.getContext()); 901 bodyBlock.addArgument(elementType); 902 903 for (Type initValType : initVals.getTypes()) { 904 bodyBlock.addArgument(initValType); 905 result.addTypes(initValType); 906 } 907 } 908 909 static LogicalResult verify(ReduceOp op) { 910 // Verify block arg types. 911 Block &block = op.region().front(); 912 913 // The block takes index, extent, and aggregated values as arguments. 914 auto blockArgsCount = op.initVals().size() + 2; 915 if (block.getNumArguments() != blockArgsCount) 916 return op.emitOpError() << "ReduceOp body is expected to have " 917 << blockArgsCount << " arguments"; 918 919 // The first block argument is the index and must always be of type `index`. 920 if (!block.getArgument(0).getType().isa<IndexType>()) 921 return op.emitOpError( 922 "argument 0 of ReduceOp body is expected to be of IndexType"); 923 924 // The second block argument is the extent and must be of type `size` or 925 // `index`, depending on whether the reduce operation is applied to a shape or 926 // to an extent tensor. 927 Type extentTy = block.getArgument(1).getType(); 928 if (op.shape().getType().isa<ShapeType>()) { 929 if (!extentTy.isa<SizeType>()) 930 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 931 "SizeType if the ReduceOp operates on a ShapeType"); 932 } else { 933 if (!extentTy.isa<IndexType>()) 934 return op.emitOpError( 935 "argument 1 of ReduceOp body is expected to be of IndexType if the " 936 "ReduceOp operates on an extent tensor"); 937 } 938 939 for (auto type : llvm::enumerate(op.initVals())) 940 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 941 return op.emitOpError() 942 << "type mismatch between argument " << type.index() + 2 943 << " of ReduceOp body and initial value " << type.index(); 944 return success(); 945 } 946 947 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 948 // Parse operands. 949 SmallVector<OpAsmParser::OperandType, 3> operands; 950 Type shapeOrExtentTensorType; 951 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 952 OpAsmParser::Delimiter::Paren) || 953 parser.parseColonType(shapeOrExtentTensorType) || 954 parser.parseOptionalArrowTypeList(result.types)) 955 return failure(); 956 957 // Resolve operands. 958 auto initVals = llvm::makeArrayRef(operands).drop_front(); 959 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 960 result.operands) || 961 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 962 result.operands)) 963 return failure(); 964 965 // Parse the body. 966 Region *body = result.addRegion(); 967 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 968 return failure(); 969 970 // Parse attributes. 971 if (parser.parseOptionalAttrDict(result.attributes)) 972 return failure(); 973 974 return success(); 975 } 976 977 static void print(OpAsmPrinter &p, ReduceOp op) { 978 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 979 << ") : " << op.shape().getType(); 980 p.printOptionalArrowTypeList(op.getResultTypes()); 981 p.printRegion(op.region()); 982 p.printOptionalAttrDict(op.getAttrs()); 983 } 984 985 #define GET_OP_CLASSES 986 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 987