1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/InliningUtils.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 using namespace mlir; 25 using namespace mlir::shape; 26 27 namespace { 28 #include "ShapeCanonicalization.inc" 29 } 30 31 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 32 return RankedTensorType::get({rank}, IndexType::get(ctx)); 33 } 34 35 bool shape::isExtentTensorType(Type type) { 36 auto ranked = type.dyn_cast<RankedTensorType>(); 37 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 38 } 39 40 LogicalResult shape::getShapeVec(Value input, 41 SmallVectorImpl<int64_t> &shapeValues) { 42 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 43 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 44 if (!type.hasRank()) 45 return failure(); 46 shapeValues = llvm::to_vector<6>(type.getShape()); 47 return success(); 48 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 49 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 50 return success(); 51 } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) { 52 shapeValues = llvm::to_vector<6>( 53 inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 54 return success(); 55 } else { 56 return failure(); 57 } 58 } 59 60 static bool isErrorPropagationPossible(TypeRange operandTypes) { 61 return llvm::any_of(operandTypes, [](Type ty) { 62 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 63 }); 64 } 65 66 static LogicalResult verifySizeOrIndexOp(Operation *op) { 67 assert(op != nullptr && op->getNumResults() == 1); 68 Type resultTy = op->getResultTypes().front(); 69 if (isErrorPropagationPossible(op->getOperandTypes())) { 70 if (!resultTy.isa<SizeType>()) 71 return op->emitOpError() 72 << "if at least one of the operands can hold error values then " 73 "the result must be of type `size` to propagate them"; 74 } 75 return success(); 76 } 77 78 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 79 assert(op != nullptr && op->getNumResults() == 1); 80 Type resultTy = op->getResultTypes().front(); 81 if (isErrorPropagationPossible(op->getOperandTypes())) { 82 if (!resultTy.isa<ShapeType>()) 83 return op->emitOpError() 84 << "if at least one of the operands can hold error values then " 85 "the result must be of type `shape` to propagate them"; 86 } 87 return success(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // InlinerInterface 92 //===----------------------------------------------------------------------===// 93 94 namespace { 95 /// This class defines the interface for inlining shape dialect ops. 96 struct ShapeInlinerInterface : public DialectInlinerInterface { 97 using DialectInlinerInterface::DialectInlinerInterface; 98 99 // Returns true if the given region 'src' can be inlined into the region 100 // 'dest' that is attached to an operation registered to the current dialect. 101 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 102 BlockAndValueMapping &) const final { 103 return true; 104 } 105 106 // Returns true if the given operation 'op', that is registered to this 107 // dialect, can be inlined into the region 'dest' that is attached to an 108 // operation registered to the current dialect. 109 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 110 BlockAndValueMapping &) const final { 111 return true; 112 } 113 }; 114 } // namespace 115 116 void ShapeDialect::initialize() { 117 addOperations< 118 #define GET_OP_LIST 119 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 120 >(); 121 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 122 addInterfaces<ShapeInlinerInterface>(); 123 // Allow unknown operations during prototyping and testing. As the dialect is 124 // still evolving it makes it simple to start with an unregistered ops and 125 // try different variants before actually defining the op. 126 allowUnknownOperations(); 127 } 128 129 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 130 Attribute value, Type type, 131 Location loc) { 132 if (type.isa<ShapeType>() || isExtentTensorType(type)) 133 return builder.create<ConstShapeOp>(loc, type, 134 value.cast<DenseIntElementsAttr>()); 135 if (type.isa<SizeType>()) 136 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 137 if (type.isa<WitnessType>()) 138 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 139 if (ConstantOp::isBuildableWith(value, type)) 140 return builder.create<ConstantOp>(loc, type, value); 141 return nullptr; 142 } 143 144 /// Parse a type registered to this dialect. 145 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 146 StringRef keyword; 147 if (parser.parseKeyword(&keyword)) 148 return Type(); 149 150 if (keyword == "shape") 151 return ShapeType::get(getContext()); 152 if (keyword == "size") 153 return SizeType::get(getContext()); 154 if (keyword == "value_shape") 155 return ValueShapeType::get(getContext()); 156 if (keyword == "witness") 157 return WitnessType::get(getContext()); 158 159 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 160 return Type(); 161 } 162 163 /// Print a type registered to this dialect. 164 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 165 TypeSwitch<Type>(type) 166 .Case<ShapeType>([&](Type) { os << "shape"; }) 167 .Case<SizeType>([&](Type) { os << "size"; }) 168 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 169 .Case<WitnessType>([&](Type) { os << "witness"; }) 170 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 171 } 172 173 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 174 NamedAttribute attribute) { 175 // Verify shape.lib attribute. 176 if (attribute.first == "shape.lib") { 177 if (!op->hasTrait<OpTrait::SymbolTable>()) 178 return op->emitError( 179 "shape.lib attribute may only be on op implementing SymbolTable"); 180 181 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 182 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 183 if (!symbol) 184 return op->emitError("shape function library ") 185 << symbolRef << " not found"; 186 return isa<shape::FunctionLibraryOp>(symbol) 187 ? success() 188 : op->emitError() 189 << symbolRef << " required to be shape function library"; 190 } 191 192 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 193 // Verify all entries are function libraries and mappings in libraries 194 // refer to unique ops. 195 DenseSet<Identifier> key; 196 for (auto it : arr) { 197 if (!it.isa<SymbolRefAttr>()) 198 return op->emitError( 199 "only SymbolRefAttr allowed in shape.lib attribute array"); 200 201 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 202 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 203 if (!shapeFnLib) 204 return op->emitError() 205 << it << " does not refer to FunctionLibraryOp"; 206 for (auto mapping : shapeFnLib.mapping()) { 207 if (!key.insert(mapping.first).second) { 208 return op->emitError("only one op to shape mapping allowed, found " 209 "multiple for `") 210 << mapping.first << "`"; 211 } 212 } 213 } 214 return success(); 215 } 216 217 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 218 "allowed as shape.lib attribute"); 219 } 220 return success(); 221 } 222 223 //===----------------------------------------------------------------------===// 224 // AnyOp 225 //===----------------------------------------------------------------------===// 226 227 // TODO: Canonicalization should be implemented for shapes that can be 228 // determined through mixtures of the known dimensions of the inputs. 229 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 230 // Only the last operand is checked because AnyOp is commutative. 231 if (operands.back()) 232 return operands.back(); 233 234 return nullptr; 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AssumingOp 239 //===----------------------------------------------------------------------===// 240 241 static ParseResult parseAssumingOp(OpAsmParser &parser, 242 OperationState &result) { 243 result.regions.reserve(1); 244 Region *doRegion = result.addRegion(); 245 246 auto &builder = parser.getBuilder(); 247 OpAsmParser::OperandType cond; 248 if (parser.parseOperand(cond) || 249 parser.resolveOperand(cond, builder.getType<WitnessType>(), 250 result.operands)) 251 return failure(); 252 253 // Parse optional results type list. 254 if (parser.parseOptionalArrowTypeList(result.types)) 255 return failure(); 256 257 // Parse the region and add a terminator if elided. 258 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 259 return failure(); 260 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 261 262 // Parse the optional attribute list. 263 if (parser.parseOptionalAttrDict(result.attributes)) 264 return failure(); 265 return success(); 266 } 267 268 static void print(OpAsmPrinter &p, AssumingOp op) { 269 bool yieldsResults = !op.results().empty(); 270 271 p << AssumingOp::getOperationName() << " " << op.witness(); 272 if (yieldsResults) { 273 p << " -> (" << op.getResultTypes() << ")"; 274 } 275 p.printRegion(op.doRegion(), 276 /*printEntryBlockArgs=*/false, 277 /*printBlockTerminators=*/yieldsResults); 278 p.printOptionalAttrDict(op->getAttrs()); 279 } 280 281 namespace { 282 // Removes AssumingOp with a passing witness and inlines the region. 283 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 284 using OpRewritePattern<AssumingOp>::OpRewritePattern; 285 286 LogicalResult matchAndRewrite(AssumingOp op, 287 PatternRewriter &rewriter) const override { 288 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 289 if (!witness || !witness.passingAttr()) 290 return failure(); 291 292 AssumingOp::inlineRegionIntoParent(op, rewriter); 293 return success(); 294 } 295 }; 296 297 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 298 using OpRewritePattern<AssumingOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(AssumingOp op, 301 PatternRewriter &rewriter) const override { 302 Block *body = op.getBody(); 303 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 304 305 // Find used values. 306 SmallVector<Value, 4> newYieldOperands; 307 Value opResult, yieldOperand; 308 for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { 309 std::tie(opResult, yieldOperand) = it; 310 if (!opResult.getUses().empty()) { 311 newYieldOperands.push_back(yieldOperand); 312 } 313 } 314 315 // Rewrite only if redundant results exist. 316 if (newYieldOperands.size() == yieldOp->getNumOperands()) 317 return failure(); 318 319 // Replace yield op in the old assuming op's body and move the entire region 320 // to the new assuming op. 321 rewriter.setInsertionPointToEnd(body); 322 auto newYieldOp = 323 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 324 rewriter.setInsertionPoint(op); 325 auto newOp = rewriter.create<AssumingOp>( 326 op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); 327 newOp.doRegion().takeBody(op.doRegion()); 328 329 // Use the new results to replace the previously used ones. 330 SmallVector<Value, 4> replacementValues; 331 auto src = newOp.getResults().begin(); 332 for (auto it : op.getResults()) { 333 if (it.getUses().empty()) 334 replacementValues.push_back(nullptr); 335 else 336 replacementValues.push_back(*src++); 337 } 338 rewriter.replaceOp(op, replacementValues); 339 return success(); 340 } 341 }; 342 } // namespace 343 344 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 345 MLIRContext *context) { 346 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 347 } 348 349 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 350 void AssumingOp::getSuccessorRegions( 351 Optional<unsigned> index, ArrayRef<Attribute> operands, 352 SmallVectorImpl<RegionSuccessor> ®ions) { 353 // AssumingOp has unconditional control flow into the region and back to the 354 // parent, so return the correct RegionSuccessor purely based on the index 355 // being None or 0. 356 if (index.hasValue()) { 357 regions.push_back(RegionSuccessor(getResults())); 358 return; 359 } 360 361 regions.push_back(RegionSuccessor(&doRegion())); 362 } 363 364 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 365 PatternRewriter &rewriter) { 366 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 367 auto *assumingBlock = op.getBody(); 368 auto initPosition = rewriter.getInsertionPoint(); 369 auto *blockAfterAssuming = 370 rewriter.splitBlock(blockBeforeAssuming, initPosition); 371 372 // Remove the AssumingOp and AssumingYieldOp. 373 auto &yieldOp = assumingBlock->back(); 374 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 375 rewriter.replaceOp(op, yieldOp.getOperands()); 376 rewriter.eraseOp(&yieldOp); 377 378 // Merge blocks together as there was no branching behavior from the 379 // AssumingOp. 380 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 381 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 382 } 383 384 void AssumingOp::build( 385 OpBuilder &builder, OperationState &result, Value witness, 386 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 387 388 result.addOperands(witness); 389 Region *bodyRegion = result.addRegion(); 390 bodyRegion->push_back(new Block); 391 Block &bodyBlock = bodyRegion->front(); 392 393 // Build body. 394 OpBuilder::InsertionGuard guard(builder); 395 builder.setInsertionPointToStart(&bodyBlock); 396 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 397 builder.create<AssumingYieldOp>(result.location, yieldValues); 398 399 SmallVector<Type, 2> assumingTypes; 400 for (Value v : yieldValues) 401 assumingTypes.push_back(v.getType()); 402 result.addTypes(assumingTypes); 403 } 404 405 //===----------------------------------------------------------------------===// 406 // AssumingAllOp 407 //===----------------------------------------------------------------------===// 408 409 namespace { 410 struct AssumingAllToCstrEqCanonicalization 411 : public OpRewritePattern<AssumingAllOp> { 412 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 413 414 LogicalResult matchAndRewrite(AssumingAllOp op, 415 PatternRewriter &rewriter) const override { 416 SmallVector<Value, 8> shapes; 417 for (Value w : op.inputs()) { 418 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 419 if (!cstrEqOp) 420 return failure(); 421 bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { 422 return llvm::is_contained(shapes, s); 423 }); 424 if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) 425 return failure(); 426 shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); 427 } 428 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 429 return success(); 430 } 431 }; 432 } // namespace 433 434 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 435 MLIRContext *context) { 436 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context); 437 } 438 439 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 440 // Iterate in reverse to first handle all constant operands. They are 441 // guaranteed to be the tail of the inputs because this is commutative. 442 for (int idx = operands.size() - 1; idx >= 0; idx--) { 443 Attribute a = operands[idx]; 444 // Cannot fold if any inputs are not constant; 445 if (!a) 446 return nullptr; 447 448 // We do not need to keep statically known values after handling them in 449 // this method. 450 getOperation()->eraseOperand(idx); 451 452 // Always false if any input is statically known false 453 if (!a.cast<BoolAttr>().getValue()) 454 return a; 455 } 456 // If this is reached, all inputs were statically known passing. 457 return BoolAttr::get(getContext(), true); 458 } 459 460 static LogicalResult verify(AssumingAllOp op) { 461 // Ensure that AssumingAllOp contains at least one operand 462 if (op.getNumOperands() == 0) 463 return op.emitOpError("no operands specified"); 464 465 return success(); 466 } 467 468 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 469 ValueRange inputs) { 470 build(b, state, b.getType<WitnessType>(), inputs); 471 } 472 473 //===----------------------------------------------------------------------===// 474 // BroadcastOp 475 //===----------------------------------------------------------------------===// 476 477 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 478 if (shapes().size() == 1) { 479 // Otherwise, we need a cast which would be a canonicalization, not folding. 480 if (shapes().front().getType() != getType()) 481 return nullptr; 482 return shapes().front(); 483 } 484 485 // TODO: Support folding with more than 2 input shapes 486 if (shapes().size() > 2) 487 return nullptr; 488 489 if (!operands[0] || !operands[1]) 490 return nullptr; 491 auto lhsShape = llvm::to_vector<6>( 492 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 493 auto rhsShape = llvm::to_vector<6>( 494 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 495 SmallVector<int64_t, 6> resultShape; 496 497 // If the shapes are not compatible, we can't fold it. 498 // TODO: Fold to an "error". 499 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 500 return nullptr; 501 502 Builder builder(getContext()); 503 return builder.getIndexTensorAttr(resultShape); 504 } 505 506 static LogicalResult verify(BroadcastOp op) { 507 return verifyShapeOrExtentTensorOp(op); 508 } 509 510 namespace { 511 template <typename OpTy> 512 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 513 using OpRewritePattern<OpTy>::OpRewritePattern; 514 515 LogicalResult matchAndRewrite(OpTy op, 516 PatternRewriter &rewriter) const override { 517 // Find unique operands. 518 SmallVector<Value, 2> unique; 519 for (Value v : op.getOperands()) { 520 if (!llvm::is_contained(unique, v)) 521 unique.push_back(v); 522 } 523 524 // Reduce op to equivalent with unique operands. 525 if (unique.size() < op.getNumOperands()) { 526 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 527 op->getAttrs()); 528 return success(); 529 } 530 531 return failure(); 532 } 533 }; 534 535 template <typename OpTy> 536 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 537 using OpRewritePattern<OpTy>::OpRewritePattern; 538 539 LogicalResult matchAndRewrite(OpTy op, 540 PatternRewriter &rewriter) const override { 541 auto isPotentiallyNonEmptyShape = [](Value shape) { 542 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 543 if (extentTensorTy.getDimSize(0) == 0) 544 return false; 545 } 546 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 547 if (constShape.shape().empty()) 548 return false; 549 } 550 return true; 551 }; 552 auto newOperands = llvm::to_vector<8>( 553 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 554 555 // Reduce op to equivalent without empty shape operands. 556 if (newOperands.size() < op.getNumOperands()) { 557 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 558 op->getAttrs()); 559 return success(); 560 } 561 562 return failure(); 563 } 564 }; 565 566 struct BroadcastForwardSingleOperandPattern 567 : public OpRewritePattern<BroadcastOp> { 568 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 569 570 LogicalResult matchAndRewrite(BroadcastOp op, 571 PatternRewriter &rewriter) const override { 572 if (op.getNumOperands() != 1) 573 return failure(); 574 Value replacement = op.shapes().front(); 575 576 // Insert cast if needed. 577 if (replacement.getType() != op.getType()) { 578 auto loc = op.getLoc(); 579 if (op.getType().isa<ShapeType>()) { 580 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 581 } else { 582 assert(!op.getType().isa<ShapeType>() && 583 !replacement.getType().isa<ShapeType>() && 584 "expect extent tensor cast"); 585 replacement = 586 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 587 } 588 } 589 590 rewriter.replaceOp(op, replacement); 591 return success(); 592 } 593 }; 594 595 struct BroadcastFoldConstantOperandsPattern 596 : public OpRewritePattern<BroadcastOp> { 597 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 598 599 LogicalResult matchAndRewrite(BroadcastOp op, 600 PatternRewriter &rewriter) const override { 601 SmallVector<int64_t, 8> foldedConstantShape; 602 SmallVector<Value, 8> newShapeOperands; 603 for (Value shape : op.shapes()) { 604 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 605 SmallVector<int64_t, 8> newFoldedConstantShape; 606 if (OpTrait::util::getBroadcastedShape( 607 foldedConstantShape, 608 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()), 609 newFoldedConstantShape)) { 610 foldedConstantShape = newFoldedConstantShape; 611 continue; 612 } 613 } 614 newShapeOperands.push_back(shape); 615 } 616 617 // Need at least two constant operands to fold anything. 618 if (op.getNumOperands() - newShapeOperands.size() < 2) 619 return failure(); 620 621 auto foldedConstantOperandsTy = RankedTensorType::get( 622 {static_cast<int64_t>(foldedConstantShape.size())}, 623 rewriter.getIndexType()); 624 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 625 op.getLoc(), foldedConstantOperandsTy, 626 rewriter.getIndexTensorAttr(foldedConstantShape))); 627 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 628 newShapeOperands); 629 return success(); 630 } 631 }; 632 633 template <typename OpTy> 634 struct CanonicalizeCastExtentTensorOperandsPattern 635 : public OpRewritePattern<OpTy> { 636 using OpRewritePattern<OpTy>::OpRewritePattern; 637 638 LogicalResult matchAndRewrite(OpTy op, 639 PatternRewriter &rewriter) const override { 640 // Canonicalize operands. 641 bool anyChange = false; 642 auto canonicalizeOperand = [&](Value operand) { 643 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 644 // Only eliminate the cast if it holds no shape information. 645 bool isInformationLoosingCast = 646 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 647 if (isInformationLoosingCast) { 648 anyChange = true; 649 return castOp.source(); 650 } 651 } 652 return operand; 653 }; 654 auto newOperands = llvm::to_vector<8>( 655 llvm::map_range(op.getOperands(), canonicalizeOperand)); 656 657 // Rewrite op if any change required. 658 if (!anyChange) 659 return failure(); 660 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 661 return success(); 662 } 663 }; 664 665 struct BroadcastConcretizeResultTypePattern 666 : public OpRewritePattern<BroadcastOp> { 667 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 668 669 LogicalResult matchAndRewrite(BroadcastOp op, 670 PatternRewriter &rewriter) const override { 671 // Only concretize dynamic extent tensor result types. 672 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 673 if (!resultTy || !resultTy.isDynamicDim(0)) 674 return failure(); 675 676 // Infer resulting shape rank if possible. 677 int64_t maxRank = 0; 678 for (Value shape : op.shapes()) { 679 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 680 // Cannot infer resulting shape rank if any operand is dynamically 681 // ranked. 682 if (extentTensorTy.isDynamicDim(0)) 683 return failure(); 684 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 685 } 686 } 687 688 auto newOp = rewriter.create<BroadcastOp>( 689 op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); 690 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 691 return success(); 692 } 693 }; 694 } // namespace 695 696 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 697 MLIRContext *context) { 698 patterns.add<BroadcastConcretizeResultTypePattern, 699 BroadcastFoldConstantOperandsPattern, 700 BroadcastForwardSingleOperandPattern, 701 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 702 RemoveDuplicateOperandsPattern<BroadcastOp>, 703 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // ConcatOp 708 //===----------------------------------------------------------------------===// 709 710 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 711 if (!operands[0] || !operands[1]) 712 return nullptr; 713 auto lhsShape = llvm::to_vector<6>( 714 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 715 auto rhsShape = llvm::to_vector<6>( 716 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 717 SmallVector<int64_t, 6> resultShape; 718 resultShape.append(lhsShape.begin(), lhsShape.end()); 719 resultShape.append(rhsShape.begin(), rhsShape.end()); 720 Builder builder(getContext()); 721 return builder.getIndexTensorAttr(resultShape); 722 } 723 724 //===----------------------------------------------------------------------===// 725 // ConstShapeOp 726 //===----------------------------------------------------------------------===// 727 728 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 729 p << "shape.const_shape "; 730 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 731 p << "["; 732 interleaveComma(op.shape().getValues<int64_t>(), p, 733 [&](int64_t i) { p << i; }); 734 p << "] : "; 735 p.printType(op.getType()); 736 } 737 738 static ParseResult parseConstShapeOp(OpAsmParser &parser, 739 OperationState &result) { 740 if (parser.parseOptionalAttrDict(result.attributes)) 741 return failure(); 742 // We piggy-back on ArrayAttr parsing, though we don't internally store the 743 // shape as an ArrayAttr. 744 // TODO: Implement custom parser and maybe make syntax a bit more concise. 745 Attribute extentsRaw; 746 NamedAttrList dummy; 747 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 748 return failure(); 749 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 750 if (!extentsArray) 751 return failure(); 752 SmallVector<int64_t, 6> ints; 753 for (Attribute extent : extentsArray) { 754 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 755 if (!attr) 756 return failure(); 757 ints.push_back(attr.getInt()); 758 } 759 Builder &builder = parser.getBuilder(); 760 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 761 Type resultTy; 762 if (parser.parseColonType(resultTy)) 763 return failure(); 764 result.types.push_back(resultTy); 765 return success(); 766 } 767 768 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 769 770 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 771 MLIRContext *context) { 772 patterns.add<TensorCastConstShape>(context); 773 } 774 775 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 776 MLIRContext *context, Optional<Location> location, ValueRange operands, 777 DictionaryAttr attributes, RegionRange regions, 778 SmallVectorImpl<Type> &inferredReturnTypes) { 779 Builder b(context); 780 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 781 if (!shape) 782 return emitOptionalError(location, "missing shape attribute"); 783 inferredReturnTypes.assign({RankedTensorType::get( 784 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 785 return success(); 786 } 787 788 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 789 TypeRange r) { 790 if (l.size() != 1 || r.size() != 1) 791 return false; 792 793 Type lhs = l.front(); 794 Type rhs = r.front(); 795 796 if (lhs == rhs) 797 return true; 798 799 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 800 // Shape type is compatible with all other valid return types. 801 return true; 802 803 return succeeded(verifyCompatibleShapes(lhs, rhs)); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // CstrBroadcastableOp 808 //===----------------------------------------------------------------------===// 809 810 void CstrBroadcastableOp::getCanonicalizationPatterns( 811 RewritePatternSet &patterns, MLIRContext *context) { 812 // Canonicalization patterns have overlap with the considerations during 813 // folding in case additional shape information is inferred at some point that 814 // does not result in folding. 815 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 816 CstrBroadcastableEqOps, 817 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 818 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 819 } 820 821 // Return true if there is exactly one attribute not representing a scalar 822 // broadcast. 823 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 824 bool nonScalarSeen = false; 825 for (Attribute a : attributes) { 826 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 827 if (nonScalarSeen) 828 return false; 829 nonScalarSeen = true; 830 } 831 } 832 return true; 833 } 834 835 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 836 // No broadcasting is needed if all operands but one are scalar. 837 if (hasAtMostSingleNonScalar(operands)) 838 return BoolAttr::get(getContext(), true); 839 840 if ([&] { 841 SmallVector<SmallVector<int64_t, 6>, 6> extents; 842 for (const auto &operand : operands) { 843 if (!operand) 844 return false; 845 extents.push_back(llvm::to_vector<6>( 846 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 847 } 848 return OpTrait::util::staticallyKnownBroadcastable(extents); 849 }()) 850 return BoolAttr::get(getContext(), true); 851 852 // Lastly, see if folding can be completed based on what constraints are known 853 // on the input shapes. 854 if ([&] { 855 SmallVector<SmallVector<int64_t, 6>, 6> extents; 856 for (auto shapeValue : shapes()) { 857 extents.emplace_back(); 858 if (failed(getShapeVec(shapeValue, extents.back()))) 859 return false; 860 } 861 return OpTrait::util::staticallyKnownBroadcastable(extents); 862 }()) 863 return BoolAttr::get(getContext(), true); 864 865 // Because a failing witness result here represents an eventual assertion 866 // failure, we do not replace it with a constant witness. 867 return nullptr; 868 } 869 870 static LogicalResult verify(CstrBroadcastableOp op) { 871 // Ensure that AssumingAllOp contains at least one operand 872 if (op.getNumOperands() < 2) 873 return op.emitOpError("required at least 2 input shapes"); 874 return success(); 875 } 876 877 //===----------------------------------------------------------------------===// 878 // CstrEqOp 879 //===----------------------------------------------------------------------===// 880 881 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 882 MLIRContext *context) { 883 // If inputs are equal, return passing witness 884 patterns.add<CstrEqEqOps>(context); 885 } 886 887 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 888 if (llvm::all_of(operands, 889 [&](Attribute a) { return a && a == operands[0]; })) 890 return BoolAttr::get(getContext(), true); 891 892 // Because a failing witness result here represents an eventual assertion 893 // failure, we do not try to replace it with a constant witness. Similarly, we 894 // cannot if there are any non-const inputs. 895 return nullptr; 896 } 897 898 //===----------------------------------------------------------------------===// 899 // ConstSizeOp 900 //===----------------------------------------------------------------------===// 901 902 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 903 int64_t value) { 904 build(builder, result, builder.getIndexAttr(value)); 905 } 906 907 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 908 909 void ConstSizeOp::getAsmResultNames( 910 llvm::function_ref<void(Value, StringRef)> setNameFn) { 911 SmallString<4> buffer; 912 llvm::raw_svector_ostream os(buffer); 913 os << "c" << value(); 914 setNameFn(getResult(), os.str()); 915 } 916 917 //===----------------------------------------------------------------------===// 918 // ConstWitnessOp 919 //===----------------------------------------------------------------------===// 920 921 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 922 923 //===----------------------------------------------------------------------===// 924 // CstrRequireOp 925 //===----------------------------------------------------------------------===// 926 927 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 928 return operands[0]; 929 } 930 931 //===----------------------------------------------------------------------===// 932 // DivOp 933 //===----------------------------------------------------------------------===// 934 935 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 936 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 937 if (!lhs) 938 return nullptr; 939 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 940 if (!rhs) 941 return nullptr; 942 943 // Division in APInt does not follow floor(lhs, rhs) when the result is 944 // negative. Rather, APInt rounds toward zero. 945 APInt quotient, remainder; 946 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 947 if (quotient.isNegative() && !remainder.isNullValue()) { 948 quotient -= 1; 949 } 950 951 Type indexTy = IndexType::get(getContext()); 952 return IntegerAttr::get(indexTy, quotient); 953 } 954 955 //===----------------------------------------------------------------------===// 956 // ShapeEqOp 957 //===----------------------------------------------------------------------===// 958 959 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 960 bool allSame = true; 961 if (!operands.empty() && !operands[0]) 962 return {}; 963 for (Attribute operand : operands.drop_front(1)) { 964 if (!operand) 965 return {}; 966 allSame = allSame && operand == operands[0]; 967 } 968 return BoolAttr::get(getContext(), allSame); 969 } 970 971 //===----------------------------------------------------------------------===// 972 // IndexToSizeOp 973 //===----------------------------------------------------------------------===// 974 975 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 976 // Constant values of both types, `shape.size` and `index`, are represented as 977 // `IntegerAttr`s which makes constant folding simple. 978 if (Attribute arg = operands[0]) 979 return arg; 980 return {}; 981 } 982 983 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 984 MLIRContext *context) { 985 patterns.add<SizeToIndexToSizeCanonicalization>(context); 986 } 987 988 //===----------------------------------------------------------------------===// 989 // FromExtentsOp 990 //===----------------------------------------------------------------------===// 991 992 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 993 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 994 return nullptr; 995 SmallVector<int64_t, 6> extents; 996 for (auto attr : operands) 997 extents.push_back(attr.cast<IntegerAttr>().getInt()); 998 Builder builder(getContext()); 999 return builder.getIndexTensorAttr(extents); 1000 } 1001 1002 //===----------------------------------------------------------------------===// 1003 // FunctionLibraryOp 1004 //===----------------------------------------------------------------------===// 1005 1006 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1007 StringRef name) { 1008 result.attributes.push_back(builder.getNamedAttr( 1009 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1010 } 1011 1012 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1013 auto attr = mapping() 1014 .get(op->getName().getIdentifier()) 1015 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1016 if (!attr) 1017 return nullptr; 1018 return lookupSymbol<FuncOp>(attr); 1019 } 1020 1021 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1022 OperationState &result) { 1023 // Parse the op name. 1024 StringAttr nameAttr; 1025 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1026 result.attributes)) 1027 return failure(); 1028 1029 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1030 return failure(); 1031 1032 auto *bodyRegion = result.addRegion(); 1033 if (parser.parseRegion(*bodyRegion)) 1034 return failure(); 1035 1036 if (parser.parseKeyword("mapping")) 1037 return failure(); 1038 1039 DictionaryAttr mappingAttr; 1040 if (parser.parseAttribute(mappingAttr, 1041 parser.getBuilder().getType<NoneType>(), "mapping", 1042 result.attributes)) 1043 return failure(); 1044 return success(); 1045 } 1046 1047 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1048 p << op.getOperationName() << ' '; 1049 p.printSymbolName(op.getName()); 1050 p.printOptionalAttrDictWithKeyword( 1051 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1052 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1053 /*printBlockTerminators=*/false); 1054 p << " mapping "; 1055 p.printAttributeWithoutType(op.mappingAttr()); 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // GetExtentOp 1060 //===----------------------------------------------------------------------===// 1061 1062 Optional<int64_t> GetExtentOp::getConstantDim() { 1063 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1064 return constSizeOp.value().getLimitedValue(); 1065 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 1066 return constantOp.value().cast<IntegerAttr>().getInt(); 1067 return llvm::None; 1068 } 1069 1070 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1071 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1072 if (!elements) 1073 return nullptr; 1074 Optional<int64_t> dim = getConstantDim(); 1075 if (!dim.hasValue()) 1076 return nullptr; 1077 if (dim.getValue() >= elements.getNumElements()) 1078 return nullptr; 1079 return elements.getValue({(uint64_t)dim.getValue()}); 1080 } 1081 1082 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1083 int64_t dim) { 1084 auto loc = result.location; 1085 auto dimAttr = builder.getIndexAttr(dim); 1086 if (shape.getType().isa<ShapeType>()) { 1087 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1088 build(builder, result, builder.getType<SizeType>(), shape, dim); 1089 } else { 1090 Value dim = 1091 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 1092 build(builder, result, builder.getIndexType(), shape, dim); 1093 } 1094 } 1095 1096 //===----------------------------------------------------------------------===// 1097 // IsBroadcastableOp 1098 //===----------------------------------------------------------------------===// 1099 1100 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1101 MLIRContext *context) { 1102 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1103 } 1104 1105 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1106 // Can always broadcast fewer than two shapes. 1107 if (operands.size() < 2) { 1108 return BoolAttr::get(getContext(), true); 1109 } 1110 1111 return nullptr; 1112 } 1113 1114 //===----------------------------------------------------------------------===// 1115 // RankOp 1116 //===----------------------------------------------------------------------===// 1117 1118 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1119 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1120 if (!shape) 1121 return {}; 1122 int64_t rank = shape.getNumElements(); 1123 Builder builder(getContext()); 1124 return builder.getIndexAttr(rank); 1125 } 1126 1127 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1128 /// Constant folding fails in cases where only the rank is constant, not the 1129 /// shape itself. 1130 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1131 /// 1132 /// Example: 1133 /// 1134 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1135 /// %rank = shape.rank %shape 1136 /// 1137 /// becomes 1138 /// 1139 /// %rank = shape.const_size 3 1140 1141 namespace { 1142 struct RankShapeOfCanonicalizationPattern 1143 : public OpRewritePattern<shape::RankOp> { 1144 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1145 1146 LogicalResult matchAndRewrite(shape::RankOp op, 1147 PatternRewriter &rewriter) const override { 1148 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 1149 if (!shapeOfOp) 1150 return failure(); 1151 auto rankedTensorType = 1152 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1153 if (!rankedTensorType) 1154 return failure(); 1155 int64_t rank = rankedTensorType.getRank(); 1156 if (op.getType().isa<IndexType>()) { 1157 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 1158 } else if (op.getType().isa<shape::SizeType>()) { 1159 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1160 } else { 1161 return failure(); 1162 } 1163 return success(); 1164 } 1165 }; 1166 } // namespace 1167 1168 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1169 MLIRContext *context) { 1170 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1171 } 1172 1173 //===----------------------------------------------------------------------===// 1174 // NumElementsOp 1175 //===----------------------------------------------------------------------===// 1176 1177 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1178 1179 // Fold only when argument constant. 1180 Attribute shape = operands[0]; 1181 if (!shape) 1182 return {}; 1183 1184 APInt product(64, 1); 1185 for (auto value : shape.cast<DenseIntElementsAttr>()) 1186 product *= value; 1187 Builder builder(getContext()); 1188 return builder.getIndexAttr(product.getLimitedValue()); 1189 } 1190 1191 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 1192 Value shape) { 1193 if (shape.getType().isa<ShapedType>()) { 1194 auto type = builder.getIndexType(); 1195 return build(builder, result, type, shape); 1196 } 1197 auto type = SizeType::get(builder.getContext()); 1198 return build(builder, result, type, shape); 1199 } 1200 1201 //===----------------------------------------------------------------------===// 1202 // MaxOp 1203 //===----------------------------------------------------------------------===// 1204 1205 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1206 // If operands are equal, just propagate one. 1207 if (lhs() == rhs()) 1208 return lhs(); 1209 return nullptr; 1210 } 1211 1212 //===----------------------------------------------------------------------===// 1213 // MinOp 1214 //===----------------------------------------------------------------------===// 1215 1216 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1217 // If operands are equal, just propagate one. 1218 if (lhs() == rhs()) 1219 return lhs(); 1220 return nullptr; 1221 } 1222 1223 //===----------------------------------------------------------------------===// 1224 // MulOp 1225 //===----------------------------------------------------------------------===// 1226 1227 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1228 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1229 if (!lhs) 1230 return nullptr; 1231 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1232 if (!rhs) 1233 return nullptr; 1234 APInt folded = lhs.getValue() * rhs.getValue(); 1235 Type indexTy = IndexType::get(getContext()); 1236 return IntegerAttr::get(indexTy, folded); 1237 } 1238 1239 //===----------------------------------------------------------------------===// 1240 // ShapeOfOp 1241 //===----------------------------------------------------------------------===// 1242 1243 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1244 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1245 if (!type || !type.hasStaticShape()) 1246 return nullptr; 1247 Builder builder(getContext()); 1248 return builder.getIndexTensorAttr(type.getShape()); 1249 } 1250 1251 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 1252 if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) { 1253 int64_t rank = 1254 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1255 Type indexTy = builder.getIndexType(); 1256 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1257 return ShapeOfOp::build(builder, result, extentTensorTy, arg); 1258 } 1259 Type shapeTy = builder.getType<ShapeType>(); 1260 return ShapeOfOp::build(builder, result, shapeTy, arg); 1261 } 1262 1263 namespace { 1264 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1265 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1266 1267 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1268 PatternRewriter &rewriter) const override { 1269 if (!op.arg().getType().isa<ShapedType>()) 1270 return failure(); 1271 if (op.getType().isa<ShapedType>()) 1272 return failure(); 1273 1274 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 1275 return success(); 1276 } 1277 }; 1278 1279 // Canonicalize 1280 // ``` 1281 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1282 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1283 // ``` 1284 // to 1285 // ``` 1286 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1287 // ``` 1288 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1289 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1290 1291 LogicalResult matchAndRewrite(tensor::CastOp op, 1292 PatternRewriter &rewriter) const override { 1293 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1294 if (!ty || ty.getRank() != 1) 1295 return failure(); 1296 1297 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1298 if (!shapeOfOp) 1299 return failure(); 1300 1301 // Argument type must be ranked and must not conflict. 1302 auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1303 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1304 return failure(); 1305 1306 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg()); 1307 return success(); 1308 } 1309 }; 1310 } // namespace 1311 1312 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1313 MLIRContext *context) { 1314 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context); 1315 } 1316 1317 //===----------------------------------------------------------------------===// 1318 // SizeToIndexOp 1319 //===----------------------------------------------------------------------===// 1320 1321 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1322 // Constant values of both types, `shape.size` and `index`, are represented as 1323 // `IntegerAttr`s which makes constant folding simple. 1324 if (Attribute arg = operands[0]) 1325 return arg; 1326 return impl::foldCastOp(*this); 1327 } 1328 1329 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1330 MLIRContext *context) { 1331 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1332 } 1333 1334 //===----------------------------------------------------------------------===// 1335 // YieldOp 1336 //===----------------------------------------------------------------------===// 1337 1338 static LogicalResult verify(shape::YieldOp op) { 1339 auto *parentOp = op->getParentOp(); 1340 auto results = parentOp->getResults(); 1341 auto operands = op.getOperands(); 1342 1343 if (parentOp->getNumResults() != op.getNumOperands()) 1344 return op.emitOpError() << "number of operands does not match number of " 1345 "results of its parent"; 1346 for (auto e : llvm::zip(results, operands)) 1347 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1348 return op.emitOpError() 1349 << "types mismatch between yield op and its parent"; 1350 1351 return success(); 1352 } 1353 1354 //===----------------------------------------------------------------------===// 1355 // SplitAtOp 1356 //===----------------------------------------------------------------------===// 1357 1358 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1359 SmallVectorImpl<OpFoldResult> &results) { 1360 if (!operands[0] || !operands[1]) 1361 return failure(); 1362 auto shapeVec = llvm::to_vector<6>( 1363 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1364 auto shape = llvm::makeArrayRef(shapeVec); 1365 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1366 // Verify that the split point is in the correct range. 1367 // TODO: Constant fold to an "error". 1368 int64_t rank = shape.size(); 1369 if (!(-rank <= splitPoint && splitPoint <= rank)) 1370 return failure(); 1371 if (splitPoint < 0) 1372 splitPoint += shape.size(); 1373 Builder builder(operands[0].getContext()); 1374 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1375 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1376 return success(); 1377 } 1378 1379 //===----------------------------------------------------------------------===// 1380 // ToExtentTensorOp 1381 //===----------------------------------------------------------------------===// 1382 1383 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1384 if (!operands[0]) 1385 return impl::foldCastOp(*this); 1386 Builder builder(getContext()); 1387 auto shape = llvm::to_vector<6>( 1388 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1389 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1390 builder.getIndexType()); 1391 return DenseIntElementsAttr::get(type, shape); 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // ReduceOp 1396 //===----------------------------------------------------------------------===// 1397 1398 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1399 ValueRange initVals) { 1400 result.addOperands(shape); 1401 result.addOperands(initVals); 1402 1403 Region *bodyRegion = result.addRegion(); 1404 bodyRegion->push_back(new Block); 1405 Block &bodyBlock = bodyRegion->front(); 1406 bodyBlock.addArgument(builder.getIndexType()); 1407 1408 Type elementType; 1409 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1410 elementType = tensorType.getElementType(); 1411 else 1412 elementType = SizeType::get(builder.getContext()); 1413 bodyBlock.addArgument(elementType); 1414 1415 for (Type initValType : initVals.getTypes()) { 1416 bodyBlock.addArgument(initValType); 1417 result.addTypes(initValType); 1418 } 1419 } 1420 1421 static LogicalResult verify(ReduceOp op) { 1422 // Verify block arg types. 1423 Block &block = op.region().front(); 1424 1425 // The block takes index, extent, and aggregated values as arguments. 1426 auto blockArgsCount = op.initVals().size() + 2; 1427 if (block.getNumArguments() != blockArgsCount) 1428 return op.emitOpError() << "ReduceOp body is expected to have " 1429 << blockArgsCount << " arguments"; 1430 1431 // The first block argument is the index and must always be of type `index`. 1432 if (!block.getArgument(0).getType().isa<IndexType>()) 1433 return op.emitOpError( 1434 "argument 0 of ReduceOp body is expected to be of IndexType"); 1435 1436 // The second block argument is the extent and must be of type `size` or 1437 // `index`, depending on whether the reduce operation is applied to a shape or 1438 // to an extent tensor. 1439 Type extentTy = block.getArgument(1).getType(); 1440 if (op.shape().getType().isa<ShapeType>()) { 1441 if (!extentTy.isa<SizeType>()) 1442 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1443 "SizeType if the ReduceOp operates on a ShapeType"); 1444 } else { 1445 if (!extentTy.isa<IndexType>()) 1446 return op.emitOpError( 1447 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1448 "ReduceOp operates on an extent tensor"); 1449 } 1450 1451 for (auto type : llvm::enumerate(op.initVals())) 1452 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1453 return op.emitOpError() 1454 << "type mismatch between argument " << type.index() + 2 1455 << " of ReduceOp body and initial value " << type.index(); 1456 return success(); 1457 } 1458 1459 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1460 // Parse operands. 1461 SmallVector<OpAsmParser::OperandType, 3> operands; 1462 Type shapeOrExtentTensorType; 1463 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1464 OpAsmParser::Delimiter::Paren) || 1465 parser.parseColonType(shapeOrExtentTensorType) || 1466 parser.parseOptionalArrowTypeList(result.types)) 1467 return failure(); 1468 1469 // Resolve operands. 1470 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1471 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1472 result.operands) || 1473 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1474 result.operands)) 1475 return failure(); 1476 1477 // Parse the body. 1478 Region *body = result.addRegion(); 1479 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1480 return failure(); 1481 1482 // Parse attributes. 1483 if (parser.parseOptionalAttrDict(result.attributes)) 1484 return failure(); 1485 1486 return success(); 1487 } 1488 1489 static void print(OpAsmPrinter &p, ReduceOp op) { 1490 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 1491 << ") : " << op.shape().getType(); 1492 p.printOptionalArrowTypeList(op.getResultTypes()); 1493 p.printRegion(op.region()); 1494 p.printOptionalAttrDict(op->getAttrs()); 1495 } 1496 1497 #define GET_OP_CLASSES 1498 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1499