1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 using namespace mlir::shape; 25 26 namespace { 27 #include "ShapeCanonicalization.inc" 28 } 29 30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { 31 return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); 32 } 33 34 static bool isErrorPropagationPossible(TypeRange operandTypes) { 35 return llvm::any_of(operandTypes, [](Type ty) { 36 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 37 }); 38 } 39 40 static LogicalResult verifySizeOrIndexOp(Operation *op) { 41 assert(op != nullptr && op->getNumResults() == 1); 42 Type resultTy = op->getResultTypes().front(); 43 if (isErrorPropagationPossible(op->getOperandTypes())) { 44 if (!resultTy.isa<SizeType>()) 45 return op->emitOpError() 46 << "if at least one of the operands can hold error values then " 47 "the result must be of type `size` to propagate them"; 48 } 49 return success(); 50 } 51 52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 53 assert(op != nullptr && op->getNumResults() == 1); 54 Type resultTy = op->getResultTypes().front(); 55 if (isErrorPropagationPossible(op->getOperandTypes())) { 56 if (!resultTy.isa<ShapeType>()) 57 return op->emitOpError() 58 << "if at least one of the operands can hold error values then " 59 "the result must be of type `shape` to propagate them"; 60 } 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // InlinerInterface 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 /// This class defines the interface for inlining shape dialect ops. 70 struct ShapeInlinerInterface : public DialectInlinerInterface { 71 using DialectInlinerInterface::DialectInlinerInterface; 72 73 // Returns true if the given region 'src' can be inlined into the region 74 // 'dest' that is attached to an operation registered to the current dialect. 75 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 76 BlockAndValueMapping &) const final { 77 return true; 78 } 79 80 // Returns true if the given operation 'op', that is registered to this 81 // dialect, can be inlined into the region 'dest' that is attached to an 82 // operation registered to the current dialect. 83 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 84 BlockAndValueMapping &) const final { 85 return true; 86 } 87 }; 88 } // namespace 89 90 void ShapeDialect::initialize() { 91 addOperations< 92 #define GET_OP_LIST 93 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 94 >(); 95 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 96 addInterfaces<ShapeInlinerInterface>(); 97 // Allow unknown operations during prototyping and testing. As the dialect is 98 // still evolving it makes it simple to start with an unregistered ops and 99 // try different variants before actually defining the op. 100 allowUnknownOperations(); 101 } 102 103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 104 Attribute value, Type type, 105 Location loc) { 106 if (type.isa<ShapeType>() || 107 type == getExtentTensorType(builder.getContext())) 108 return builder.create<ConstShapeOp>(loc, type, 109 value.cast<DenseIntElementsAttr>()); 110 if (type.isa<SizeType>()) 111 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 112 if (type.isa<WitnessType>()) 113 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 114 if (ConstantOp::isBuildableWith(value, type)) 115 return builder.create<ConstantOp>(loc, type, value); 116 return nullptr; 117 } 118 119 /// Parse a type registered to this dialect. 120 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 121 StringRef keyword; 122 if (parser.parseKeyword(&keyword)) 123 return Type(); 124 125 if (keyword == "shape") 126 return ShapeType::get(getContext()); 127 if (keyword == "size") 128 return SizeType::get(getContext()); 129 if (keyword == "value_shape") 130 return ValueShapeType::get(getContext()); 131 if (keyword == "witness") 132 return WitnessType::get(getContext()); 133 134 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 135 return Type(); 136 } 137 138 /// Print a type registered to this dialect. 139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 140 TypeSwitch<Type>(type) 141 .Case<ShapeType>([&](Type) { os << "shape"; }) 142 .Case<SizeType>([&](Type) { os << "size"; }) 143 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 144 .Case<WitnessType>([&](Type) { os << "witness"; }) 145 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 146 } 147 148 //===----------------------------------------------------------------------===// 149 // AnyOp 150 //===----------------------------------------------------------------------===// 151 152 // TODO: Canonicalization should be implemented for shapes that can be 153 // determined through mixtures of the known dimensions of the inputs. 154 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 155 // Only the last operand is checked because AnyOp is commutative. 156 if (operands.back()) 157 return operands.back(); 158 159 return nullptr; 160 } 161 162 //===----------------------------------------------------------------------===// 163 // AssumingOp 164 //===----------------------------------------------------------------------===// 165 166 static ParseResult parseAssumingOp(OpAsmParser &parser, 167 OperationState &result) { 168 result.regions.reserve(1); 169 Region *doRegion = result.addRegion(); 170 171 auto &builder = parser.getBuilder(); 172 OpAsmParser::OperandType cond; 173 if (parser.parseOperand(cond) || 174 parser.resolveOperand(cond, builder.getType<WitnessType>(), 175 result.operands)) 176 return failure(); 177 178 // Parse optional results type list. 179 if (parser.parseOptionalArrowTypeList(result.types)) 180 return failure(); 181 182 // Parse the region and add a terminator if elided. 183 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 184 return failure(); 185 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 186 187 // Parse the optional attribute list. 188 if (parser.parseOptionalAttrDict(result.attributes)) 189 return failure(); 190 return success(); 191 } 192 193 static void print(OpAsmPrinter &p, AssumingOp op) { 194 bool yieldsResults = !op.results().empty(); 195 196 p << AssumingOp::getOperationName() << " " << op.witness(); 197 if (yieldsResults) { 198 p << " -> (" << op.getResultTypes() << ")"; 199 } 200 p.printRegion(op.doRegion(), 201 /*printEntryBlockArgs=*/false, 202 /*printBlockTerminators=*/yieldsResults); 203 p.printOptionalAttrDict(op.getAttrs()); 204 } 205 206 namespace { 207 // Removes AssumingOp with a passing witness and inlines the region. 208 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 209 using OpRewritePattern<AssumingOp>::OpRewritePattern; 210 211 LogicalResult matchAndRewrite(AssumingOp op, 212 PatternRewriter &rewriter) const override { 213 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 214 if (!witness || !witness.passingAttr()) 215 return failure(); 216 217 AssumingOp::inlineRegionIntoParent(op, rewriter); 218 return success(); 219 } 220 }; 221 } // namespace 222 223 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 224 MLIRContext *context) { 225 // If taking a passing witness, inline region. 226 patterns.insert<AssumingWithTrue>(context); 227 } 228 229 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 230 void AssumingOp::getSuccessorRegions( 231 Optional<unsigned> index, ArrayRef<Attribute> operands, 232 SmallVectorImpl<RegionSuccessor> ®ions) { 233 // AssumingOp has unconditional control flow into the region and back to the 234 // parent, so return the correct RegionSuccessor purely based on the index 235 // being None or 0. 236 if (index.hasValue()) { 237 regions.push_back(RegionSuccessor(getResults())); 238 return; 239 } 240 241 regions.push_back(RegionSuccessor(&doRegion())); 242 } 243 244 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 245 PatternRewriter &rewriter) { 246 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 247 auto *assumingBlock = op.getBody(); 248 auto initPosition = rewriter.getInsertionPoint(); 249 auto *blockAfterAssuming = 250 rewriter.splitBlock(blockBeforeAssuming, initPosition); 251 252 // Remove the AssumingOp and AssumingYieldOp. 253 auto &yieldOp = assumingBlock->back(); 254 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 255 rewriter.replaceOp(op, yieldOp.getOperands()); 256 rewriter.eraseOp(&yieldOp); 257 258 // Merge blocks together as there was no branching behavior from the 259 // AssumingOp. 260 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 261 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // AssumingAllOp 266 //===----------------------------------------------------------------------===// 267 268 void AssumingAllOp::getCanonicalizationPatterns( 269 OwningRewritePatternList &patterns, MLIRContext *context) { 270 patterns.insert<AssumingAllOneOp>(context); 271 } 272 273 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 274 // Iterate in reverse to first handle all constant operands. They are 275 // guaranteed to be the tail of the inputs because this is commutative. 276 for (int idx = operands.size() - 1; idx >= 0; idx--) { 277 Attribute a = operands[idx]; 278 // Cannot fold if any inputs are not constant; 279 if (!a) 280 return nullptr; 281 282 // We do not need to keep statically known values after handling them in 283 // this method. 284 getOperation()->eraseOperand(idx); 285 286 // Always false if any input is statically known false 287 if (!a.cast<BoolAttr>().getValue()) 288 return a; 289 } 290 // If this is reached, all inputs were statically known passing. 291 return BoolAttr::get(true, getContext()); 292 } 293 294 static LogicalResult verify(AssumingAllOp op) { 295 // Ensure that AssumingAllOp contains at least one operand 296 if (op.getNumOperands() == 0) 297 return op.emitOpError("no operands specified"); 298 299 return success(); 300 } 301 302 //===----------------------------------------------------------------------===// 303 // BroadcastOp 304 //===----------------------------------------------------------------------===// 305 306 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 307 if (!operands[1]) 308 return nullptr; 309 310 auto rhsShape = llvm::to_vector<6>( 311 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 312 if (rhsShape.empty()) 313 return lhs(); 314 315 if (!operands[0]) 316 return nullptr; 317 318 auto lhsShape = llvm::to_vector<6>( 319 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 320 if (lhsShape.empty()) 321 return rhs(); 322 323 SmallVector<int64_t, 6> resultShape; 324 // If the shapes are not compatible, we can't fold it. 325 // TODO: Fold to an "error". 326 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 327 return nullptr; 328 Builder builder(getContext()); 329 return builder.getIndexTensorAttr(resultShape); 330 } 331 332 //===----------------------------------------------------------------------===// 333 // ConcatOp 334 //===----------------------------------------------------------------------===// 335 336 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 337 if (!operands[0] || !operands[1]) 338 return nullptr; 339 auto lhsShape = llvm::to_vector<6>( 340 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 341 auto rhsShape = llvm::to_vector<6>( 342 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 343 SmallVector<int64_t, 6> resultShape; 344 resultShape.append(lhsShape.begin(), lhsShape.end()); 345 resultShape.append(rhsShape.begin(), rhsShape.end()); 346 Builder builder(getContext()); 347 return builder.getIndexTensorAttr(resultShape); 348 } 349 350 //===----------------------------------------------------------------------===// 351 // ConstShapeOp 352 //===----------------------------------------------------------------------===// 353 354 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 355 p << "shape.const_shape "; 356 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); 357 p << "["; 358 interleaveComma(op.shape().getValues<int64_t>(), p, 359 [&](int64_t i) { p << i; }); 360 p << "] : "; 361 p.printType(op.getType()); 362 } 363 364 static ParseResult parseConstShapeOp(OpAsmParser &parser, 365 OperationState &result) { 366 if (parser.parseOptionalAttrDict(result.attributes)) 367 return failure(); 368 // We piggy-back on ArrayAttr parsing, though we don't internally store the 369 // shape as an ArrayAttr. 370 // TODO: Implement custom parser and maybe make syntax a bit more concise. 371 Attribute extentsRaw; 372 NamedAttrList dummy; 373 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 374 return failure(); 375 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 376 if (!extentsArray) 377 return failure(); 378 SmallVector<int64_t, 6> ints; 379 for (Attribute extent : extentsArray) { 380 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 381 if (!attr) 382 return failure(); 383 ints.push_back(attr.getInt()); 384 } 385 Builder &builder = parser.getBuilder(); 386 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 387 Type resultTy; 388 if (parser.parseColonType(resultTy)) 389 return failure(); 390 result.types.push_back(resultTy); 391 return success(); 392 } 393 394 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 395 396 void ConstShapeOp::getCanonicalizationPatterns( 397 OwningRewritePatternList &patterns, MLIRContext *context) { 398 patterns.insert<TensorCastConstShape>(context); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // CstrBroadcastableOp 403 //===----------------------------------------------------------------------===// 404 405 namespace { 406 // Given an input shape Value, try to obtain the shape's values. 407 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) { 408 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 409 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 410 if (!type.hasRank()) 411 return failure(); 412 shapeValues = llvm::to_vector<6>(type.getShape()); 413 return success(); 414 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 415 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 416 return success(); 417 } else { 418 return failure(); 419 } 420 } 421 } // namespace 422 423 void CstrBroadcastableOp::getCanonicalizationPatterns( 424 OwningRewritePatternList &patterns, MLIRContext *context) { 425 // Canonicalization patterns have overlap with the considerations during 426 // folding in case additional shape information is inferred at some point that 427 // does not result in folding. 428 patterns.insert<CstrBroadcastableEqOps>(context); 429 } 430 431 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 432 // Both operands are not needed if one is a scalar. 433 if (operands[0] && 434 operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0) 435 return BoolAttr::get(true, getContext()); 436 if (operands[1] && 437 operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0) 438 return BoolAttr::get(true, getContext()); 439 440 if (operands[0] && operands[1]) { 441 auto lhsShape = llvm::to_vector<6>( 442 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 443 auto rhsShape = llvm::to_vector<6>( 444 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 445 SmallVector<int64_t, 6> resultShape; 446 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 447 return BoolAttr::get(true, getContext()); 448 } 449 450 // Lastly, see if folding can be completed based on what constraints are known 451 // on the input shapes. 452 SmallVector<int64_t, 6> lhsShape, rhsShape; 453 if (failed(getShapeVec(lhs(), lhsShape))) 454 return nullptr; 455 if (failed(getShapeVec(rhs(), rhsShape))) 456 return nullptr; 457 458 if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) 459 return BoolAttr::get(true, getContext()); 460 461 // Because a failing witness result here represents an eventual assertion 462 // failure, we do not replace it with a constant witness. 463 return nullptr; 464 } 465 466 //===----------------------------------------------------------------------===// 467 // CstrEqOp 468 //===----------------------------------------------------------------------===// 469 470 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 471 MLIRContext *context) { 472 // If inputs are equal, return passing witness 473 patterns.insert<CstrEqEqOps>(context); 474 } 475 476 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 477 if (llvm::all_of(operands, 478 [&](Attribute a) { return a && a == operands[0]; })) 479 return BoolAttr::get(true, getContext()); 480 481 // Because a failing witness result here represents an eventual assertion 482 // failure, we do not try to replace it with a constant witness. Similarly, we 483 // cannot if there are any non-const inputs. 484 return nullptr; 485 } 486 487 //===----------------------------------------------------------------------===// 488 // ConstSizeOp 489 //===----------------------------------------------------------------------===// 490 491 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 492 int64_t value) { 493 build(builder, result, builder.getIndexAttr(value)); 494 } 495 496 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 497 498 void ConstSizeOp::getAsmResultNames( 499 llvm::function_ref<void(Value, StringRef)> setNameFn) { 500 SmallString<4> buffer; 501 llvm::raw_svector_ostream os(buffer); 502 os << "c" << value(); 503 setNameFn(getResult(), os.str()); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // ConstWitnessOp 508 //===----------------------------------------------------------------------===// 509 510 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 511 512 //===----------------------------------------------------------------------===// 513 // CstrRequireOp 514 //===----------------------------------------------------------------------===// 515 516 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 517 return operands[0]; 518 } 519 520 //===----------------------------------------------------------------------===// 521 // ShapeEqOp 522 //===----------------------------------------------------------------------===// 523 524 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 525 auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 526 if (lhs == nullptr) 527 return {}; 528 auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>(); 529 if (rhs == nullptr) 530 return {}; 531 return BoolAttr::get(lhs == rhs, getContext()); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // IndexToSizeOp 536 //===----------------------------------------------------------------------===// 537 538 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 539 // Constant values of both types, `shape.size` and `index`, are represented as 540 // `IntegerAttr`s which makes constant folding simple. 541 if (Attribute arg = operands[0]) 542 return arg; 543 return {}; 544 } 545 546 void IndexToSizeOp::getCanonicalizationPatterns( 547 OwningRewritePatternList &patterns, MLIRContext *context) { 548 patterns.insert<SizeToIndexToSizeCanonicalization>(context); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // FromExtentsOp 553 //===----------------------------------------------------------------------===// 554 555 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 556 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 557 return nullptr; 558 SmallVector<int64_t, 6> extents; 559 for (auto attr : operands) 560 extents.push_back(attr.cast<IntegerAttr>().getInt()); 561 Builder builder(getContext()); 562 return builder.getIndexTensorAttr(extents); 563 } 564 565 //===----------------------------------------------------------------------===// 566 // FunctionLibraryOp 567 //===----------------------------------------------------------------------===// 568 569 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 570 StringRef name) { 571 ensureTerminator(*result.addRegion(), builder, result.location); 572 result.attributes.push_back(builder.getNamedAttr( 573 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 574 } 575 576 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 577 auto attr = mapping() 578 .get(op->getName().getIdentifier()) 579 .dyn_cast_or_null<FlatSymbolRefAttr>(); 580 if (!attr) 581 return nullptr; 582 return lookupSymbol<FuncOp>(attr); 583 } 584 585 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 586 OperationState &result) { 587 // Parse the op name. 588 StringAttr nameAttr; 589 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 590 result.attributes)) 591 return failure(); 592 593 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 594 return failure(); 595 596 auto *bodyRegion = result.addRegion(); 597 if (parser.parseRegion(*bodyRegion)) 598 return failure(); 599 600 FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 601 result.location); 602 if (parser.parseKeyword("mapping")) 603 return failure(); 604 605 DictionaryAttr mappingAttr; 606 if (parser.parseAttribute(mappingAttr, 607 parser.getBuilder().getType<NoneType>(), "mapping", 608 result.attributes)) 609 return failure(); 610 return success(); 611 } 612 613 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 614 p << op.getOperationName() << ' '; 615 p.printSymbolName(op.getName()); 616 p.printOptionalAttrDictWithKeyword( 617 op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 618 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 619 /*printBlockTerminators=*/false); 620 p << " mapping "; 621 p.printAttributeWithoutType(op.mappingAttr()); 622 } 623 624 //===----------------------------------------------------------------------===// 625 // GetExtentOp 626 //===----------------------------------------------------------------------===// 627 628 Optional<int64_t> GetExtentOp::getConstantDim() { 629 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 630 return constSizeOp.value().getLimitedValue(); 631 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 632 return constantOp.value().cast<IntegerAttr>().getInt(); 633 return llvm::None; 634 } 635 636 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 637 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 638 if (!elements) 639 return nullptr; 640 Optional<int64_t> dim = getConstantDim(); 641 if (!dim.hasValue()) 642 return nullptr; 643 if (dim.getValue() >= elements.getNumElements()) 644 return nullptr; 645 return elements.getValue({(uint64_t)dim.getValue()}); 646 } 647 648 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 649 int64_t dim) { 650 auto loc = result.location; 651 auto dimAttr = builder.getIndexAttr(dim); 652 if (shape.getType().isa<ShapeType>()) { 653 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 654 build(builder, result, builder.getType<SizeType>(), shape, dim); 655 } else { 656 Value dim = 657 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 658 build(builder, result, builder.getIndexType(), shape, dim); 659 } 660 } 661 662 //===----------------------------------------------------------------------===// 663 // RankOp 664 //===----------------------------------------------------------------------===// 665 666 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 667 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 668 if (!shape) 669 return {}; 670 int64_t rank = shape.getNumElements(); 671 Builder builder(getContext()); 672 return builder.getIndexAttr(rank); 673 } 674 675 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 676 /// Constant folding fails in cases where only the rank is constant, not the 677 /// shape itself. 678 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 679 /// 680 /// Example: 681 /// 682 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 683 /// %rank = shape.rank %shape 684 /// 685 /// becomes 686 /// 687 /// %rank = shape.const_size 3 688 689 namespace { 690 struct RankShapeOfCanonicalizationPattern 691 : public OpRewritePattern<shape::RankOp> { 692 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 693 694 LogicalResult matchAndRewrite(shape::RankOp op, 695 PatternRewriter &rewriter) const override { 696 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 697 if (!shapeOfOp) 698 return failure(); 699 auto rankedTensorType = 700 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 701 if (!rankedTensorType) 702 return failure(); 703 int64_t rank = rankedTensorType.getRank(); 704 if (op.getType().isa<IndexType>()) { 705 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 706 } else if (op.getType().isa<shape::SizeType>()) { 707 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 708 } else { 709 return failure(); 710 } 711 return success(); 712 } 713 }; 714 } // namespace 715 716 void shape::RankOp::getCanonicalizationPatterns( 717 OwningRewritePatternList &patterns, MLIRContext *context) { 718 patterns.insert<RankShapeOfCanonicalizationPattern>(context); 719 } 720 721 //===----------------------------------------------------------------------===// 722 // NumElementsOp 723 //===----------------------------------------------------------------------===// 724 725 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 726 727 // Fold only when argument constant. 728 Attribute shape = operands[0]; 729 if (!shape) 730 return {}; 731 732 APInt product(64, 1); 733 for (auto value : shape.cast<DenseIntElementsAttr>()) 734 product *= value; 735 Builder builder(getContext()); 736 return builder.getIndexAttr(product.getLimitedValue()); 737 } 738 739 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 740 Value shape) { 741 if (shape.getType().isa<ShapedType>()) { 742 auto type = builder.getIndexType(); 743 return build(builder, result, type, shape); 744 } 745 auto type = SizeType::get(builder.getContext()); 746 return build(builder, result, type, shape); 747 } 748 749 //===----------------------------------------------------------------------===// 750 // MulOp 751 //===----------------------------------------------------------------------===// 752 753 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 754 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 755 if (!lhs) 756 return nullptr; 757 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 758 if (!rhs) 759 return nullptr; 760 APInt folded = lhs.getValue() * rhs.getValue(); 761 Type indexTy = IndexType::get(getContext()); 762 return IntegerAttr::get(indexTy, folded); 763 } 764 765 //===----------------------------------------------------------------------===// 766 // ShapeOfOp 767 //===----------------------------------------------------------------------===// 768 769 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 770 auto type = getOperand().getType().dyn_cast<ShapedType>(); 771 if (!type || !type.hasStaticShape()) 772 return nullptr; 773 Builder builder(getContext()); 774 return builder.getIndexTensorAttr(type.getShape()); 775 } 776 777 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 778 Type type = arg.getType().isa<ShapedType>() 779 ? (Type)getExtentTensorType(builder.getContext()) 780 : (Type)builder.getType<ShapeType>(); 781 return ShapeOfOp::build(builder, result, type, arg); 782 } 783 784 namespace { 785 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 786 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 787 788 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 789 PatternRewriter &rewriter) const override { 790 if (!op.arg().getType().isa<ShapedType>()) 791 return failure(); 792 if (op.getType().isa<ShapedType>()) 793 return failure(); 794 795 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 796 return success(); 797 } 798 }; 799 } // namespace 800 801 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, 802 MLIRContext *context) { 803 patterns.insert<ShapeOfWithTensor>(context); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // SizeToIndexOp 808 //===----------------------------------------------------------------------===// 809 810 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 811 // Constant values of both types, `shape.size` and `index`, are represented as 812 // `IntegerAttr`s which makes constant folding simple. 813 if (Attribute arg = operands[0]) 814 return arg; 815 return impl::foldCastOp(*this); 816 } 817 818 void SizeToIndexOp::getCanonicalizationPatterns( 819 OwningRewritePatternList &patterns, MLIRContext *context) { 820 patterns.insert<IndexToSizeToIndexCanonicalization>(context); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // YieldOp 825 //===----------------------------------------------------------------------===// 826 827 static LogicalResult verify(shape::YieldOp op) { 828 auto *parentOp = op->getParentOp(); 829 auto results = parentOp->getResults(); 830 auto operands = op.getOperands(); 831 832 if (parentOp->getNumResults() != op.getNumOperands()) 833 return op.emitOpError() << "number of operands does not match number of " 834 "results of its parent"; 835 for (auto e : llvm::zip(results, operands)) 836 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 837 return op.emitOpError() 838 << "types mismatch between yield op and its parent"; 839 840 return success(); 841 } 842 843 //===----------------------------------------------------------------------===// 844 // SplitAtOp 845 //===----------------------------------------------------------------------===// 846 847 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 848 SmallVectorImpl<OpFoldResult> &results) { 849 if (!operands[0] || !operands[1]) 850 return failure(); 851 auto shapeVec = llvm::to_vector<6>( 852 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 853 auto shape = llvm::makeArrayRef(shapeVec); 854 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 855 // Verify that the split point is in the correct range. 856 // TODO: Constant fold to an "error". 857 int64_t rank = shape.size(); 858 if (!(-rank <= splitPoint && splitPoint <= rank)) 859 return failure(); 860 if (splitPoint < 0) 861 splitPoint += shape.size(); 862 Builder builder(operands[0].getContext()); 863 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 864 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 865 return success(); 866 } 867 868 //===----------------------------------------------------------------------===// 869 // ToExtentTensorOp 870 //===----------------------------------------------------------------------===// 871 872 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 873 if (!operands[0]) 874 return impl::foldCastOp(*this); 875 Builder builder(getContext()); 876 auto shape = llvm::to_vector<6>( 877 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 878 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 879 builder.getIndexType()); 880 return DenseIntElementsAttr::get(type, shape); 881 } 882 883 //===----------------------------------------------------------------------===// 884 // ReduceOp 885 //===----------------------------------------------------------------------===// 886 887 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 888 ValueRange initVals) { 889 result.addOperands(shape); 890 result.addOperands(initVals); 891 892 Region *bodyRegion = result.addRegion(); 893 bodyRegion->push_back(new Block); 894 Block &bodyBlock = bodyRegion->front(); 895 bodyBlock.addArgument(builder.getIndexType()); 896 897 Type elementType; 898 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 899 elementType = tensorType.getElementType(); 900 else 901 elementType = SizeType::get(builder.getContext()); 902 bodyBlock.addArgument(elementType); 903 904 for (Type initValType : initVals.getTypes()) { 905 bodyBlock.addArgument(initValType); 906 result.addTypes(initValType); 907 } 908 } 909 910 static LogicalResult verify(ReduceOp op) { 911 // Verify block arg types. 912 Block &block = op.region().front(); 913 914 // The block takes index, extent, and aggregated values as arguments. 915 auto blockArgsCount = op.initVals().size() + 2; 916 if (block.getNumArguments() != blockArgsCount) 917 return op.emitOpError() << "ReduceOp body is expected to have " 918 << blockArgsCount << " arguments"; 919 920 // The first block argument is the index and must always be of type `index`. 921 if (!block.getArgument(0).getType().isa<IndexType>()) 922 return op.emitOpError( 923 "argument 0 of ReduceOp body is expected to be of IndexType"); 924 925 // The second block argument is the extent and must be of type `size` or 926 // `index`, depending on whether the reduce operation is applied to a shape or 927 // to an extent tensor. 928 Type extentTy = block.getArgument(1).getType(); 929 if (op.shape().getType().isa<ShapeType>()) { 930 if (!extentTy.isa<SizeType>()) 931 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 932 "SizeType if the ReduceOp operates on a ShapeType"); 933 } else { 934 if (!extentTy.isa<IndexType>()) 935 return op.emitOpError( 936 "argument 1 of ReduceOp body is expected to be of IndexType if the " 937 "ReduceOp operates on an extent tensor"); 938 } 939 940 for (auto type : llvm::enumerate(op.initVals())) 941 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 942 return op.emitOpError() 943 << "type mismatch between argument " << type.index() + 2 944 << " of ReduceOp body and initial value " << type.index(); 945 return success(); 946 } 947 948 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 949 // Parse operands. 950 SmallVector<OpAsmParser::OperandType, 3> operands; 951 Type shapeOrExtentTensorType; 952 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 953 OpAsmParser::Delimiter::Paren) || 954 parser.parseColonType(shapeOrExtentTensorType) || 955 parser.parseOptionalArrowTypeList(result.types)) 956 return failure(); 957 958 // Resolve operands. 959 auto initVals = llvm::makeArrayRef(operands).drop_front(); 960 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 961 result.operands) || 962 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 963 result.operands)) 964 return failure(); 965 966 // Parse the body. 967 Region *body = result.addRegion(); 968 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 969 return failure(); 970 971 // Parse attributes. 972 if (parser.parseOptionalAttrDict(result.attributes)) 973 return failure(); 974 975 return success(); 976 } 977 978 static void print(OpAsmPrinter &p, ReduceOp op) { 979 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 980 << ") : " << op.shape().getType(); 981 p.printOptionalArrowTypeList(op.getResultTypes()); 982 p.printRegion(op.region()); 983 p.printOptionalAttrDict(op.getAttrs()); 984 } 985 986 #define GET_OP_CLASSES 987 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 988