1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Transforms/InliningUtils.h" 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 using namespace mlir::shape; 25 26 namespace { 27 #include "ShapeCanonicalization.inc" 28 } 29 30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 31 return RankedTensorType::get({rank}, IndexType::get(ctx)); 32 } 33 34 bool shape::isExtentTensorType(Type type) { 35 auto ranked = type.dyn_cast<RankedTensorType>(); 36 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 37 } 38 39 LogicalResult shape::getShapeVec(Value input, 40 SmallVectorImpl<int64_t> &shapeValues) { 41 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 42 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 43 if (!type.hasRank()) 44 return failure(); 45 shapeValues = llvm::to_vector<6>(type.getShape()); 46 return success(); 47 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 48 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 49 return success(); 50 } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) { 51 shapeValues = llvm::to_vector<6>( 52 
inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 53 return success(); 54 } else { 55 return failure(); 56 } 57 } 58 59 static bool isErrorPropagationPossible(TypeRange operandTypes) { 60 return llvm::any_of(operandTypes, [](Type ty) { 61 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 62 }); 63 } 64 65 static LogicalResult verifySizeOrIndexOp(Operation *op) { 66 assert(op != nullptr && op->getNumResults() == 1); 67 Type resultTy = op->getResultTypes().front(); 68 if (isErrorPropagationPossible(op->getOperandTypes())) { 69 if (!resultTy.isa<SizeType>()) 70 return op->emitOpError() 71 << "if at least one of the operands can hold error values then " 72 "the result must be of type `size` to propagate them"; 73 } 74 return success(); 75 } 76 77 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 78 assert(op != nullptr && op->getNumResults() == 1); 79 Type resultTy = op->getResultTypes().front(); 80 if (isErrorPropagationPossible(op->getOperandTypes())) { 81 if (!resultTy.isa<ShapeType>()) 82 return op->emitOpError() 83 << "if at least one of the operands can hold error values then " 84 "the result must be of type `shape` to propagate them"; 85 } 86 return success(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // InlinerInterface 91 //===----------------------------------------------------------------------===// 92 93 namespace { 94 /// This class defines the interface for inlining shape dialect ops. 95 struct ShapeInlinerInterface : public DialectInlinerInterface { 96 using DialectInlinerInterface::DialectInlinerInterface; 97 98 // Returns true if the given region 'src' can be inlined into the region 99 // 'dest' that is attached to an operation registered to the current dialect. 
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}

/// Materializes a single constant operation for the given attribute/type
/// pair, choosing the shape-dialect constant op matching the requested type,
/// and falling back to std.constant where that is buildable. Returns null if
/// no constant can be materialized.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}

/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}

/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<ShapeType>([&](Type) { os << "shape"; })
      .Case<SizeType>([&](Type) { os << "size"; })
      .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
      .Case<WitnessType>([&](Type) { os << "witness"; })
      .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}

/// Verifies the dialect-specific `shape.lib` attribute: it may only appear on
/// ops implementing SymbolTable and must be a SymbolRefAttr (or an array of
/// them) referring to shape function libraries whose op-to-function mappings
/// do not conflict.
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.first == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      // Track op names already claimed by a library so duplicates across
      // libraries are rejected.
      DenseSet<Identifier> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.mapping()) {
          if (!key.insert(mapping.first).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.first << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative.
  if (operands.back())
    return operands.back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

/// Custom parser: `shape.assuming` takes a witness operand, an optional arrow
/// result-type list, a region (whose terminator may be elided), and an
/// optional attribute dictionary.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

/// Custom printer matching parseAssumingOp; block terminators are printed
/// only when the op yields results (they are elided otherwise).
static void print(OpAsmPrinter &p, AssumingOp op) {
  bool yieldsResults = !op.results().empty();

  p << AssumingOp::getOperationName() << " " << op.witness();
  if (yieldsResults) {
    p << " -> (" << op.getResultTypes() << ")";
  }
  p.printRegion(op.doRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict(op->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.passingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

// Rebuilds an AssumingOp so that it only yields values that actually have
// uses, dropping dead results.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
    newOp.doRegion().takeBody(op.doRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results map to a null value, which replaceOp accepts for
    // results without uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&doRegion()));
}

/// Replaces `op` with the contents of its body, splicing the body's block
/// into the parent block at the op's position and forwarding the yielded
/// values to the op's result uses.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

/// Builds an AssumingOp whose body is produced by `bodyBuilder`; the result
/// types of the op are inferred from the values yielded by the body.
void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {
// Merges an `assuming_all` of `cstr_eq` witnesses into a single `cstr_eq`
// when the constrained shape sets overlap pairwise, so that the equalities
// are transitively connected.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.inputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      // Bail if this cstr_eq shares no shape with those collected so far; the
      // equalities would not be transitively connected then.
      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method. Note: this mutates the op in place during folding.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

static LogicalResult verify(AssumingAllOp op) {
  // Ensure that AssumingAllOp contains at least one operand
  if (op.getNumOperands() == 0)
    return op.emitOpError("no operands specified");

  return success();
}

void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (shapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (shapes().front().getType() != getType())
      return nullptr;
    return shapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (shapes().size() > 2)
    return nullptr;

  // Both shapes must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}

namespace {
// Rewrites an op with duplicate operands into an equivalent op with the
// duplicates removed. Valid for ops whose semantics depend only on the set of
// operands (e.g. broadcast, cstr_broadcastable).
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SmallVector<Value, 2> unique;
    for (Value v : op.getOperands()) {
      if (!llvm::is_contained(unique, v))
        unique.push_back(v);
    }

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

// Removes operands that are statically known to be empty (rank-0) shapes, as
// they never affect broadcasting.
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.shape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

// Replaces a single-operand broadcast with its operand, inserting a cast when
// the operand and result types differ.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.shapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        // Neither side is a !shape.shape here, so an extent tensor cast
        // suffices.
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

// Folds all constant operands of a broadcast into one constant operand that
// holds their combined broadcasted shape.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.shapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      // Non-constant (or incompatible constant) operands are kept as-is.
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    // (continued) only rewrite when the fold actually removed operands.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

// Drops tensor.cast ops on operands when the cast only loses information,
// i.e. casts a statically sized extent tensor to a dynamically sized one.
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLoosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLoosingCast) {
          anyChange = true;
          return castOp.source();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

// Concretizes a broadcast's dynamic extent tensor result type when the result
// rank can be inferred from the operands' static types.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.shapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when both operand shapes are constant.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

/// Custom printer: `shape.const_shape [attr-dict] [e0, e1, ...] : type`.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}

/// Custom parser matching the printer above.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast. (Scalar broadcasts always succeed, so broadcastability is
// trivially true when no more than one non-scalar shape participates.)
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    // Unknown (null) attributes are conservatively treated as non-scalar.
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Fold to a passing witness if all constant extents are statically known to
  // broadcast.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  // Fold to a passing witness when all shapes are constant and identical.
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
863 return nullptr; 864 } 865 866 //===----------------------------------------------------------------------===// 867 // ConstSizeOp 868 //===----------------------------------------------------------------------===// 869 870 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 871 int64_t value) { 872 build(builder, result, builder.getIndexAttr(value)); 873 } 874 875 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 876 877 void ConstSizeOp::getAsmResultNames( 878 llvm::function_ref<void(Value, StringRef)> setNameFn) { 879 SmallString<4> buffer; 880 llvm::raw_svector_ostream os(buffer); 881 os << "c" << value(); 882 setNameFn(getResult(), os.str()); 883 } 884 885 //===----------------------------------------------------------------------===// 886 // ConstWitnessOp 887 //===----------------------------------------------------------------------===// 888 889 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 890 891 //===----------------------------------------------------------------------===// 892 // CstrRequireOp 893 //===----------------------------------------------------------------------===// 894 895 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 896 return operands[0]; 897 } 898 899 //===----------------------------------------------------------------------===// 900 // DivOp 901 //===----------------------------------------------------------------------===// 902 903 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 904 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 905 if (!lhs) 906 return nullptr; 907 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 908 if (!rhs) 909 return nullptr; 910 911 // Division in APInt does not follow floor(lhs, rhs) when the result is 912 // negative. Rather, APInt rounds toward zero. 
913 APInt quotient, remainder; 914 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 915 if (quotient.isNegative() && !remainder.isNullValue()) { 916 quotient -= 1; 917 } 918 919 Type indexTy = IndexType::get(getContext()); 920 return IntegerAttr::get(indexTy, quotient); 921 } 922 923 //===----------------------------------------------------------------------===// 924 // ShapeEqOp 925 //===----------------------------------------------------------------------===// 926 927 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 928 bool allSame = true; 929 if (!operands.empty() && !operands[0]) 930 return {}; 931 for (Attribute operand : operands.drop_front(1)) { 932 if (!operand) 933 return {}; 934 allSame = allSame && operand == operands[0]; 935 } 936 return BoolAttr::get(getContext(), allSame); 937 } 938 939 //===----------------------------------------------------------------------===// 940 // IndexToSizeOp 941 //===----------------------------------------------------------------------===// 942 943 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 944 // Constant values of both types, `shape.size` and `index`, are represented as 945 // `IntegerAttr`s which makes constant folding simple. 
946 if (Attribute arg = operands[0]) 947 return arg; 948 return {}; 949 } 950 951 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 952 MLIRContext *context) { 953 patterns.add<SizeToIndexToSizeCanonicalization>(context); 954 } 955 956 //===----------------------------------------------------------------------===// 957 // FromExtentsOp 958 //===----------------------------------------------------------------------===// 959 960 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 961 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 962 return nullptr; 963 SmallVector<int64_t, 6> extents; 964 for (auto attr : operands) 965 extents.push_back(attr.cast<IntegerAttr>().getInt()); 966 Builder builder(getContext()); 967 return builder.getIndexTensorAttr(extents); 968 } 969 970 //===----------------------------------------------------------------------===// 971 // FunctionLibraryOp 972 //===----------------------------------------------------------------------===// 973 974 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 975 StringRef name) { 976 result.attributes.push_back(builder.getNamedAttr( 977 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 978 } 979 980 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 981 auto attr = mapping() 982 .get(op->getName().getIdentifier()) 983 .dyn_cast_or_null<FlatSymbolRefAttr>(); 984 if (!attr) 985 return nullptr; 986 return lookupSymbol<FuncOp>(attr); 987 } 988 989 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 990 OperationState &result) { 991 // Parse the op name. 
992 StringAttr nameAttr; 993 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 994 result.attributes)) 995 return failure(); 996 997 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 998 return failure(); 999 1000 auto *bodyRegion = result.addRegion(); 1001 if (parser.parseRegion(*bodyRegion)) 1002 return failure(); 1003 1004 if (parser.parseKeyword("mapping")) 1005 return failure(); 1006 1007 DictionaryAttr mappingAttr; 1008 if (parser.parseAttribute(mappingAttr, 1009 parser.getBuilder().getType<NoneType>(), "mapping", 1010 result.attributes)) 1011 return failure(); 1012 return success(); 1013 } 1014 1015 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1016 p << op.getOperationName() << ' '; 1017 p.printSymbolName(op.getName()); 1018 p.printOptionalAttrDictWithKeyword( 1019 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1020 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1021 /*printBlockTerminators=*/false); 1022 p << " mapping "; 1023 p.printAttributeWithoutType(op.mappingAttr()); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // GetExtentOp 1028 //===----------------------------------------------------------------------===// 1029 1030 Optional<int64_t> GetExtentOp::getConstantDim() { 1031 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1032 return constSizeOp.value().getLimitedValue(); 1033 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 1034 return constantOp.value().cast<IntegerAttr>().getInt(); 1035 return llvm::None; 1036 } 1037 1038 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1039 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1040 if (!elements) 1041 return nullptr; 1042 Optional<int64_t> dim = getConstantDim(); 1043 if (!dim.hasValue()) 1044 return nullptr; 1045 if (dim.getValue() >= elements.getNumElements()) 1046 return nullptr; 1047 return 
elements.getValue({(uint64_t)dim.getValue()}); 1048 } 1049 1050 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1051 int64_t dim) { 1052 auto loc = result.location; 1053 auto dimAttr = builder.getIndexAttr(dim); 1054 if (shape.getType().isa<ShapeType>()) { 1055 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1056 build(builder, result, builder.getType<SizeType>(), shape, dim); 1057 } else { 1058 Value dim = 1059 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 1060 build(builder, result, builder.getIndexType(), shape, dim); 1061 } 1062 } 1063 1064 //===----------------------------------------------------------------------===// 1065 // IsBroadcastableOp 1066 //===----------------------------------------------------------------------===// 1067 1068 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1069 MLIRContext *context) { 1070 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1071 } 1072 1073 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1074 // Can always broadcast fewer than two shapes. 1075 if (operands.size() < 2) { 1076 return BoolAttr::get(getContext(), true); 1077 } 1078 1079 return nullptr; 1080 } 1081 1082 //===----------------------------------------------------------------------===// 1083 // RankOp 1084 //===----------------------------------------------------------------------===// 1085 1086 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1087 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1088 if (!shape) 1089 return {}; 1090 int64_t rank = shape.getNumElements(); 1091 Builder builder(getContext()); 1092 return builder.getIndexAttr(rank); 1093 } 1094 1095 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1096 /// Constant folding fails in cases where only the rank is constant, not the 1097 /// shape itself. 
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the operand is produced by shape.shape_of of a
    // statically ranked tensor.
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank with a constant op matching the result type.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  // The number of elements is the product of all extents.
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

// The result type follows the operand: extent-tensor operands produce `index`
// results, `!shape.shape` operands produce `!shape.size` results.
void NumElementsOp::build(OpBuilder &builder, OperationState &result,
                          Value shape) {
  if (shape.getType().isa<ShapedType>()) {
    auto type = builder.getIndexType();
    return build(builder, result, type, shape);
  }
  auto type = SizeType::get(builder.getContext());
  return build(builder, result, type, shape);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

/// Constant-folds the product of two constant sizes/indices.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

/// Folds to a constant extent tensor when the operand type is fully static.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

// Infers the result type from the argument: shaped arguments produce a
// rank-1 extent tensor (static extent when the rank is known), anything else
// produces `!shape.shape`.
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = builder.getIndexType();
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
  }
  Type shapeTy = builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, shapeTy, arg);
}

namespace {
// Rebuilds a shape.shape_of whose operand is a tensor but whose result is not
// an extent tensor, so the result type gets re-inferred by the builder above.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The cast target must be a rank-1 ranked tensor (an extent tensor shape).
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return impl::foldCastOp(*this);
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

// The yielded operands must match the parent op's results in count and type.
static LogicalResult verify(shape::YieldOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "number of operands does not match number of "
                               "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

// Folds to the head/tail pair of a constant shape around a constant split
// point (which may be negative, counting from the back).
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  // Both the shape and the split point must be constant to fold.
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // Normalize a negative split point to its from-the-front equivalent.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
  // Without a constant operand the op can at best fold as a trivial cast.
  if (!operands[0])
    return impl::foldCastOp(*this);
  Builder builder(getContext());
  // Rebuild the constant as a statically shaped rank-1 index tensor.
  auto shape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Builds a reduction over `shape` with a single body block taking the
// iteration index, the current extent, and the running accumulator values.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // First block argument: the iteration index.
  bodyBlock.addArgument(builder.getIndexType());

  // Second block argument: the extent, typed after the shape operand
  // (the tensor element type for extent tensors, `!shape.size` otherwise).
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // Remaining block arguments carry the accumulated values; each init value
  // type is also a result type of the op.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}

static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // Each accumulator block argument must match its init value's type.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}

static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first is the shape; the remaining init values are
  // typed by the parsed result types.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"