1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/InliningUtils.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 using namespace mlir; 25 using namespace mlir::shape; 26 27 namespace { 28 #include "ShapeCanonicalization.inc" 29 } 30 31 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 32 return RankedTensorType::get({rank}, IndexType::get(ctx)); 33 } 34 35 bool shape::isExtentTensorType(Type type) { 36 auto ranked = type.dyn_cast<RankedTensorType>(); 37 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 38 } 39 40 LogicalResult shape::getShapeVec(Value input, 41 SmallVectorImpl<int64_t> &shapeValues) { 42 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 43 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 44 if (!type.hasRank()) 45 return failure(); 46 shapeValues = llvm::to_vector<6>(type.getShape()); 47 return success(); 48 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 49 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 50 return success(); 51 } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) { 52 shapeValues = llvm::to_vector<6>( 53 inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 54 return success(); 55 } else { 56 return failure(); 57 } 58 } 59 60 static bool isErrorPropagationPossible(TypeRange operandTypes) { 61 return llvm::any_of(operandTypes, [](Type ty) { 62 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 63 }); 64 } 65 66 static LogicalResult verifySizeOrIndexOp(Operation *op) { 67 assert(op != nullptr && op->getNumResults() == 1); 68 Type resultTy = op->getResultTypes().front(); 69 if (isErrorPropagationPossible(op->getOperandTypes())) { 70 if (!resultTy.isa<SizeType>()) 71 return op->emitOpError() 72 << "if at least one of the operands can hold error values then " 73 "the result must be of type `size` to propagate them"; 74 } 75 return success(); 76 } 77 78 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 79 assert(op != nullptr && op->getNumResults() == 1); 80 Type resultTy = op->getResultTypes().front(); 81 if (isErrorPropagationPossible(op->getOperandTypes())) { 82 if (!resultTy.isa<ShapeType>()) 83 return op->emitOpError() 84 << "if at least one of the operands can hold error values then " 85 "the result must be of type `shape` to propagate them"; 86 } 87 return success(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // InlinerInterface 92 //===----------------------------------------------------------------------===// 93 94 namespace { 95 /// This class defines the interface for inlining shape dialect ops. 96 struct ShapeInlinerInterface : public DialectInlinerInterface { 97 using DialectInlinerInterface::DialectInlinerInterface; 98 99 // Returns true if the given region 'src' can be inlined into the region 100 // 'dest' that is attached to an operation registered to the current dialect. 101 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 102 BlockAndValueMapping &) const final { 103 return true; 104 } 105 106 // Returns true if the given operation 'op', that is registered to this 107 // dialect, can be inlined into the region 'dest' that is attached to an 108 // operation registered to the current dialect. 109 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 110 BlockAndValueMapping &) const final { 111 return true; 112 } 113 }; 114 } // namespace 115 116 void ShapeDialect::initialize() { 117 addOperations< 118 #define GET_OP_LIST 119 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 120 >(); 121 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 122 addInterfaces<ShapeInlinerInterface>(); 123 // Allow unknown operations during prototyping and testing. As the dialect is 124 // still evolving it makes it simple to start with an unregistered ops and 125 // try different variants before actually defining the op. 126 allowUnknownOperations(); 127 } 128 129 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 130 Attribute value, Type type, 131 Location loc) { 132 if (type.isa<ShapeType>() || isExtentTensorType(type)) 133 return builder.create<ConstShapeOp>(loc, type, 134 value.cast<DenseIntElementsAttr>()); 135 if (type.isa<SizeType>()) 136 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 137 if (type.isa<WitnessType>()) 138 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 139 if (ConstantOp::isBuildableWith(value, type)) 140 return builder.create<ConstantOp>(loc, type, value); 141 return nullptr; 142 } 143 144 /// Parse a type registered to this dialect. 145 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 146 StringRef keyword; 147 if (parser.parseKeyword(&keyword)) 148 return Type(); 149 150 if (keyword == "shape") 151 return ShapeType::get(getContext()); 152 if (keyword == "size") 153 return SizeType::get(getContext()); 154 if (keyword == "value_shape") 155 return ValueShapeType::get(getContext()); 156 if (keyword == "witness") 157 return WitnessType::get(getContext()); 158 159 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 160 return Type(); 161 } 162 163 /// Print a type registered to this dialect. 164 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 165 TypeSwitch<Type>(type) 166 .Case<ShapeType>([&](Type) { os << "shape"; }) 167 .Case<SizeType>([&](Type) { os << "size"; }) 168 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 169 .Case<WitnessType>([&](Type) { os << "witness"; }) 170 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 171 } 172 173 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 174 NamedAttribute attribute) { 175 // Verify shape.lib attribute. 176 if (attribute.first == "shape.lib") { 177 if (!op->hasTrait<OpTrait::SymbolTable>()) 178 return op->emitError( 179 "shape.lib attribute may only be on op implementing SymbolTable"); 180 181 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 182 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 183 if (!symbol) 184 return op->emitError("shape function library ") 185 << symbolRef << " not found"; 186 return isa<shape::FunctionLibraryOp>(symbol) 187 ? success() 188 : op->emitError() 189 << symbolRef << " required to be shape function library"; 190 } 191 192 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 193 // Verify all entries are function libraries and mappings in libraries 194 // refer to unique ops. 195 DenseSet<Identifier> key; 196 for (auto it : arr) { 197 if (!it.isa<SymbolRefAttr>()) 198 return op->emitError( 199 "only SymbolRefAttr allowed in shape.lib attribute array"); 200 201 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 202 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 203 if (!shapeFnLib) 204 return op->emitError() 205 << it << " does not refer to FunctionLibraryOp"; 206 for (auto mapping : shapeFnLib.mapping()) { 207 if (!key.insert(mapping.first).second) { 208 return op->emitError("only one op to shape mapping allowed, found " 209 "multiple for `") 210 << mapping.first << "`"; 211 } 212 } 213 } 214 return success(); 215 } 216 217 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 218 "allowed as shape.lib attribute"); 219 } 220 return success(); 221 } 222 223 //===----------------------------------------------------------------------===// 224 // AnyOp 225 //===----------------------------------------------------------------------===// 226 227 // TODO: Canonicalization should be implemented for shapes that can be 228 // determined through mixtures of the known dimensions of the inputs. 229 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 230 // Only the last operand is checked because AnyOp is commutative. 231 if (operands.back()) 232 return operands.back(); 233 234 return nullptr; 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AssumingOp 239 //===----------------------------------------------------------------------===// 240 241 static ParseResult parseAssumingOp(OpAsmParser &parser, 242 OperationState &result) { 243 result.regions.reserve(1); 244 Region *doRegion = result.addRegion(); 245 246 auto &builder = parser.getBuilder(); 247 OpAsmParser::OperandType cond; 248 if (parser.parseOperand(cond) || 249 parser.resolveOperand(cond, builder.getType<WitnessType>(), 250 result.operands)) 251 return failure(); 252 253 // Parse optional results type list. 254 if (parser.parseOptionalArrowTypeList(result.types)) 255 return failure(); 256 257 // Parse the region and add a terminator if elided. 258 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 259 return failure(); 260 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 261 262 // Parse the optional attribute list. 263 if (parser.parseOptionalAttrDict(result.attributes)) 264 return failure(); 265 return success(); 266 } 267 268 static void print(OpAsmPrinter &p, AssumingOp op) { 269 bool yieldsResults = !op.results().empty(); 270 271 p << AssumingOp::getOperationName() << " " << op.witness(); 272 if (yieldsResults) { 273 p << " -> (" << op.getResultTypes() << ")"; 274 } 275 p.printRegion(op.doRegion(), 276 /*printEntryBlockArgs=*/false, 277 /*printBlockTerminators=*/yieldsResults); 278 p.printOptionalAttrDict(op->getAttrs()); 279 } 280 281 namespace { 282 // Removes AssumingOp with a passing witness and inlines the region. 283 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 284 using OpRewritePattern<AssumingOp>::OpRewritePattern; 285 286 LogicalResult matchAndRewrite(AssumingOp op, 287 PatternRewriter &rewriter) const override { 288 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 289 if (!witness || !witness.passingAttr()) 290 return failure(); 291 292 AssumingOp::inlineRegionIntoParent(op, rewriter); 293 return success(); 294 } 295 }; 296 297 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 298 using OpRewritePattern<AssumingOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(AssumingOp op, 301 PatternRewriter &rewriter) const override { 302 Block *body = op.getBody(); 303 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 304 305 // Find used values. 306 SmallVector<Value, 4> newYieldOperands; 307 Value opResult, yieldOperand; 308 for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { 309 std::tie(opResult, yieldOperand) = it; 310 if (!opResult.getUses().empty()) { 311 newYieldOperands.push_back(yieldOperand); 312 } 313 } 314 315 // Rewrite only if redundant results exist. 316 if (newYieldOperands.size() == yieldOp->getNumOperands()) 317 return failure(); 318 319 // Replace yield op in the old assuming op's body and move the entire region 320 // to the new assuming op. 321 rewriter.setInsertionPointToEnd(body); 322 auto newYieldOp = 323 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 324 rewriter.setInsertionPoint(op); 325 auto newOp = rewriter.create<AssumingOp>( 326 op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); 327 newOp.doRegion().takeBody(op.doRegion()); 328 329 // Use the new results to replace the previously used ones. 330 SmallVector<Value, 4> replacementValues; 331 auto src = newOp.getResults().begin(); 332 for (auto it : op.getResults()) { 333 if (it.getUses().empty()) 334 replacementValues.push_back(nullptr); 335 else 336 replacementValues.push_back(*src++); 337 } 338 rewriter.replaceOp(op, replacementValues); 339 return success(); 340 } 341 }; 342 } // namespace 343 344 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 345 MLIRContext *context) { 346 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 347 } 348 349 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 350 void AssumingOp::getSuccessorRegions( 351 Optional<unsigned> index, ArrayRef<Attribute> operands, 352 SmallVectorImpl<RegionSuccessor> ®ions) { 353 // AssumingOp has unconditional control flow into the region and back to the 354 // parent, so return the correct RegionSuccessor purely based on the index 355 // being None or 0. 356 if (index.hasValue()) { 357 regions.push_back(RegionSuccessor(getResults())); 358 return; 359 } 360 361 regions.push_back(RegionSuccessor(&doRegion())); 362 } 363 364 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 365 PatternRewriter &rewriter) { 366 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 367 auto *assumingBlock = op.getBody(); 368 auto initPosition = rewriter.getInsertionPoint(); 369 auto *blockAfterAssuming = 370 rewriter.splitBlock(blockBeforeAssuming, initPosition); 371 372 // Remove the AssumingOp and AssumingYieldOp. 373 auto &yieldOp = assumingBlock->back(); 374 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 375 rewriter.replaceOp(op, yieldOp.getOperands()); 376 rewriter.eraseOp(&yieldOp); 377 378 // Merge blocks together as there was no branching behavior from the 379 // AssumingOp. 380 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 381 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 382 } 383 384 void AssumingOp::build( 385 OpBuilder &builder, OperationState &result, Value witness, 386 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 387 388 result.addOperands(witness); 389 Region *bodyRegion = result.addRegion(); 390 bodyRegion->push_back(new Block); 391 Block &bodyBlock = bodyRegion->front(); 392 393 // Build body. 394 OpBuilder::InsertionGuard guard(builder); 395 builder.setInsertionPointToStart(&bodyBlock); 396 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 397 builder.create<AssumingYieldOp>(result.location, yieldValues); 398 399 SmallVector<Type, 2> assumingTypes; 400 for (Value v : yieldValues) 401 assumingTypes.push_back(v.getType()); 402 result.addTypes(assumingTypes); 403 } 404 405 //===----------------------------------------------------------------------===// 406 // AssumingAllOp 407 //===----------------------------------------------------------------------===// 408 409 namespace { 410 struct AssumingAllToCstrEqCanonicalization 411 : public OpRewritePattern<AssumingAllOp> { 412 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 413 414 LogicalResult matchAndRewrite(AssumingAllOp op, 415 PatternRewriter &rewriter) const override { 416 SmallVector<Value, 8> shapes; 417 for (Value w : op.inputs()) { 418 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 419 if (!cstrEqOp) 420 return failure(); 421 bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { 422 return llvm::is_contained(shapes, s); 423 }); 424 if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) 425 return failure(); 426 shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); 427 } 428 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 429 return success(); 430 } 431 }; 432 433 template <typename OpTy> 434 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 435 using OpRewritePattern<OpTy>::OpRewritePattern; 436 437 LogicalResult matchAndRewrite(OpTy op, 438 PatternRewriter &rewriter) const override { 439 // Find unique operands. 440 SmallVector<Value, 2> unique; 441 for (Value v : op.getOperands()) { 442 if (!llvm::is_contained(unique, v)) 443 unique.push_back(v); 444 } 445 446 // Reduce op to equivalent with unique operands. 447 if (unique.size() < op.getNumOperands()) { 448 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 449 op->getAttrs()); 450 return success(); 451 } 452 453 return failure(); 454 } 455 }; 456 } // namespace 457 458 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 459 MLIRContext *context) { 460 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 461 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 462 } 463 464 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 465 // Iterate in reverse to first handle all constant operands. They are 466 // guaranteed to be the tail of the inputs because this is commutative. 467 for (int idx = operands.size() - 1; idx >= 0; idx--) { 468 Attribute a = operands[idx]; 469 // Cannot fold if any inputs are not constant; 470 if (!a) 471 return nullptr; 472 473 // We do not need to keep statically known values after handling them in 474 // this method. 475 getOperation()->eraseOperand(idx); 476 477 // Always false if any input is statically known false 478 if (!a.cast<BoolAttr>().getValue()) 479 return a; 480 } 481 // If this is reached, all inputs were statically known passing. 482 return BoolAttr::get(getContext(), true); 483 } 484 485 static LogicalResult verify(AssumingAllOp op) { 486 // Ensure that AssumingAllOp contains at least one operand 487 if (op.getNumOperands() == 0) 488 return op.emitOpError("no operands specified"); 489 490 return success(); 491 } 492 493 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 494 ValueRange inputs) { 495 build(b, state, b.getType<WitnessType>(), inputs); 496 } 497 498 //===----------------------------------------------------------------------===// 499 // BroadcastOp 500 //===----------------------------------------------------------------------===// 501 502 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 503 if (shapes().size() == 1) { 504 // Otherwise, we need a cast which would be a canonicalization, not folding. 505 if (shapes().front().getType() != getType()) 506 return nullptr; 507 return shapes().front(); 508 } 509 510 // TODO: Support folding with more than 2 input shapes 511 if (shapes().size() > 2) 512 return nullptr; 513 514 if (!operands[0] || !operands[1]) 515 return nullptr; 516 auto lhsShape = llvm::to_vector<6>( 517 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 518 auto rhsShape = llvm::to_vector<6>( 519 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 520 SmallVector<int64_t, 6> resultShape; 521 522 // If the shapes are not compatible, we can't fold it. 523 // TODO: Fold to an "error". 524 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 525 return nullptr; 526 527 Builder builder(getContext()); 528 return builder.getIndexTensorAttr(resultShape); 529 } 530 531 static LogicalResult verify(BroadcastOp op) { 532 return verifyShapeOrExtentTensorOp(op); 533 } 534 535 namespace { 536 template <typename OpTy> 537 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 538 using OpRewritePattern<OpTy>::OpRewritePattern; 539 540 LogicalResult matchAndRewrite(OpTy op, 541 PatternRewriter &rewriter) const override { 542 auto isPotentiallyNonEmptyShape = [](Value shape) { 543 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 544 if (extentTensorTy.getDimSize(0) == 0) 545 return false; 546 } 547 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 548 if (constShape.shape().empty()) 549 return false; 550 } 551 return true; 552 }; 553 auto newOperands = llvm::to_vector<8>( 554 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 555 556 // Reduce op to equivalent without empty shape operands. 557 if (newOperands.size() < op.getNumOperands()) { 558 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 559 op->getAttrs()); 560 return success(); 561 } 562 563 return failure(); 564 } 565 }; 566 567 struct BroadcastForwardSingleOperandPattern 568 : public OpRewritePattern<BroadcastOp> { 569 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 570 571 LogicalResult matchAndRewrite(BroadcastOp op, 572 PatternRewriter &rewriter) const override { 573 if (op.getNumOperands() != 1) 574 return failure(); 575 Value replacement = op.shapes().front(); 576 577 // Insert cast if needed. 578 if (replacement.getType() != op.getType()) { 579 auto loc = op.getLoc(); 580 if (op.getType().isa<ShapeType>()) { 581 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 582 } else { 583 assert(!op.getType().isa<ShapeType>() && 584 !replacement.getType().isa<ShapeType>() && 585 "expect extent tensor cast"); 586 replacement = 587 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 588 } 589 } 590 591 rewriter.replaceOp(op, replacement); 592 return success(); 593 } 594 }; 595 596 struct BroadcastFoldConstantOperandsPattern 597 : public OpRewritePattern<BroadcastOp> { 598 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 599 600 LogicalResult matchAndRewrite(BroadcastOp op, 601 PatternRewriter &rewriter) const override { 602 SmallVector<int64_t, 8> foldedConstantShape; 603 SmallVector<Value, 8> newShapeOperands; 604 for (Value shape : op.shapes()) { 605 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 606 SmallVector<int64_t, 8> newFoldedConstantShape; 607 if (OpTrait::util::getBroadcastedShape( 608 foldedConstantShape, 609 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()), 610 newFoldedConstantShape)) { 611 foldedConstantShape = newFoldedConstantShape; 612 continue; 613 } 614 } 615 newShapeOperands.push_back(shape); 616 } 617 618 // Need at least two constant operands to fold anything. 619 if (op.getNumOperands() - newShapeOperands.size() < 2) 620 return failure(); 621 622 auto foldedConstantOperandsTy = RankedTensorType::get( 623 {static_cast<int64_t>(foldedConstantShape.size())}, 624 rewriter.getIndexType()); 625 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 626 op.getLoc(), foldedConstantOperandsTy, 627 rewriter.getIndexTensorAttr(foldedConstantShape))); 628 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 629 newShapeOperands); 630 return success(); 631 } 632 }; 633 634 template <typename OpTy> 635 struct CanonicalizeCastExtentTensorOperandsPattern 636 : public OpRewritePattern<OpTy> { 637 using OpRewritePattern<OpTy>::OpRewritePattern; 638 639 LogicalResult matchAndRewrite(OpTy op, 640 PatternRewriter &rewriter) const override { 641 // Canonicalize operands. 642 bool anyChange = false; 643 auto canonicalizeOperand = [&](Value operand) { 644 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 645 // Only eliminate the cast if it holds no shape information. 646 bool isInformationLoosingCast = 647 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 648 if (isInformationLoosingCast) { 649 anyChange = true; 650 return castOp.source(); 651 } 652 } 653 return operand; 654 }; 655 auto newOperands = llvm::to_vector<8>( 656 llvm::map_range(op.getOperands(), canonicalizeOperand)); 657 658 // Rewrite op if any change required. 659 if (!anyChange) 660 return failure(); 661 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 662 return success(); 663 } 664 }; 665 666 struct BroadcastConcretizeResultTypePattern 667 : public OpRewritePattern<BroadcastOp> { 668 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 669 670 LogicalResult matchAndRewrite(BroadcastOp op, 671 PatternRewriter &rewriter) const override { 672 // Only concretize dynamic extent tensor result types. 673 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 674 if (!resultTy || !resultTy.isDynamicDim(0)) 675 return failure(); 676 677 // Infer resulting shape rank if possible. 678 int64_t maxRank = 0; 679 for (Value shape : op.shapes()) { 680 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 681 // Cannot infer resulting shape rank if any operand is dynamically 682 // ranked. 683 if (extentTensorTy.isDynamicDim(0)) 684 return failure(); 685 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 686 } 687 } 688 689 auto newOp = rewriter.create<BroadcastOp>( 690 op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); 691 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 692 return success(); 693 } 694 }; 695 } // namespace 696 697 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 698 MLIRContext *context) { 699 patterns.add<BroadcastConcretizeResultTypePattern, 700 BroadcastFoldConstantOperandsPattern, 701 BroadcastForwardSingleOperandPattern, 702 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 703 RemoveDuplicateOperandsPattern<BroadcastOp>, 704 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 705 } 706 707 //===----------------------------------------------------------------------===// 708 // ConcatOp 709 //===----------------------------------------------------------------------===// 710 711 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 712 if (!operands[0] || !operands[1]) 713 return nullptr; 714 auto lhsShape = llvm::to_vector<6>( 715 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 716 auto rhsShape = llvm::to_vector<6>( 717 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 718 SmallVector<int64_t, 6> resultShape; 719 resultShape.append(lhsShape.begin(), lhsShape.end()); 720 resultShape.append(rhsShape.begin(), rhsShape.end()); 721 Builder builder(getContext()); 722 return builder.getIndexTensorAttr(resultShape); 723 } 724 725 //===----------------------------------------------------------------------===// 726 // ConstShapeOp 727 //===----------------------------------------------------------------------===// 728 729 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 730 p << "shape.const_shape "; 731 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 732 p << "["; 733 interleaveComma(op.shape().getValues<int64_t>(), p, 734 [&](int64_t i) { p << i; }); 735 p << "] : "; 736 p.printType(op.getType()); 737 } 738 739 static ParseResult parseConstShapeOp(OpAsmParser &parser, 740 OperationState &result) { 741 if (parser.parseOptionalAttrDict(result.attributes)) 742 return failure(); 743 // We piggy-back on ArrayAttr parsing, though we don't internally store the 744 // shape as an ArrayAttr. 745 // TODO: Implement custom parser and maybe make syntax a bit more concise. 746 Attribute extentsRaw; 747 NamedAttrList dummy; 748 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 749 return failure(); 750 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 751 if (!extentsArray) 752 return failure(); 753 SmallVector<int64_t, 6> ints; 754 for (Attribute extent : extentsArray) { 755 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 756 if (!attr) 757 return failure(); 758 ints.push_back(attr.getInt()); 759 } 760 Builder &builder = parser.getBuilder(); 761 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 762 Type resultTy; 763 if (parser.parseColonType(resultTy)) 764 return failure(); 765 result.types.push_back(resultTy); 766 return success(); 767 } 768 769 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 770 771 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 772 MLIRContext *context) { 773 patterns.add<TensorCastConstShape>(context); 774 } 775 776 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 777 MLIRContext *context, Optional<Location> location, ValueRange operands, 778 DictionaryAttr attributes, RegionRange regions, 779 SmallVectorImpl<Type> &inferredReturnTypes) { 780 Builder b(context); 781 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 782 if (!shape) 783 return emitOptionalError(location, "missing shape attribute"); 784 inferredReturnTypes.assign({RankedTensorType::get( 785 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 786 return success(); 787 } 788 789 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 790 TypeRange r) { 791 if (l.size() != 1 || r.size() != 1) 792 return false; 793 794 Type lhs = l.front(); 795 Type rhs = r.front(); 796 797 if (lhs == rhs) 798 return true; 799 800 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 801 // Shape type is compatible with all other valid return types. 802 return true; 803 804 return succeeded(verifyCompatibleShapes(lhs, rhs)); 805 } 806 807 //===----------------------------------------------------------------------===// 808 // CstrBroadcastableOp 809 //===----------------------------------------------------------------------===// 810 811 void CstrBroadcastableOp::getCanonicalizationPatterns( 812 RewritePatternSet &patterns, MLIRContext *context) { 813 // Canonicalization patterns have overlap with the considerations during 814 // folding in case additional shape information is inferred at some point that 815 // does not result in folding. 816 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 817 CstrBroadcastableEqOps, 818 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 819 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 820 } 821 822 // Return true if there is exactly one attribute not representing a scalar 823 // broadcast. 824 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 825 bool nonScalarSeen = false; 826 for (Attribute a : attributes) { 827 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 828 if (nonScalarSeen) 829 return false; 830 nonScalarSeen = true; 831 } 832 } 833 return true; 834 } 835 836 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 837 // No broadcasting is needed if all operands but one are scalar. 838 if (hasAtMostSingleNonScalar(operands)) 839 return BoolAttr::get(getContext(), true); 840 841 if ([&] { 842 SmallVector<SmallVector<int64_t, 6>, 6> extents; 843 for (const auto &operand : operands) { 844 if (!operand) 845 return false; 846 extents.push_back(llvm::to_vector<6>( 847 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 848 } 849 return OpTrait::util::staticallyKnownBroadcastable(extents); 850 }()) 851 return BoolAttr::get(getContext(), true); 852 853 // Lastly, see if folding can be completed based on what constraints are known 854 // on the input shapes. 855 if ([&] { 856 SmallVector<SmallVector<int64_t, 6>, 6> extents; 857 for (auto shapeValue : shapes()) { 858 extents.emplace_back(); 859 if (failed(getShapeVec(shapeValue, extents.back()))) 860 return false; 861 } 862 return OpTrait::util::staticallyKnownBroadcastable(extents); 863 }()) 864 return BoolAttr::get(getContext(), true); 865 866 // Because a failing witness result here represents an eventual assertion 867 // failure, we do not replace it with a constant witness. 868 return nullptr; 869 } 870 871 static LogicalResult verify(CstrBroadcastableOp op) { 872 // Ensure that AssumingAllOp contains at least one operand 873 if (op.getNumOperands() < 2) 874 return op.emitOpError("required at least 2 input shapes"); 875 return success(); 876 } 877 878 //===----------------------------------------------------------------------===// 879 // CstrEqOp 880 //===----------------------------------------------------------------------===// 881 882 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 883 MLIRContext *context) { 884 // If inputs are equal, return passing witness 885 patterns.add<CstrEqEqOps>(context); 886 } 887 888 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 889 if (llvm::all_of(operands, 890 [&](Attribute a) { return a && a == operands[0]; })) 891 return BoolAttr::get(getContext(), true); 892 893 // Because a failing witness result here represents an eventual assertion 894 // failure, we do not try to replace it with a constant witness. Similarly, we 895 // cannot if there are any non-const inputs. 896 return nullptr; 897 } 898 899 //===----------------------------------------------------------------------===// 900 // ConstSizeOp 901 //===----------------------------------------------------------------------===// 902 903 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 904 int64_t value) { 905 build(builder, result, builder.getIndexAttr(value)); 906 } 907 908 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 909 910 void ConstSizeOp::getAsmResultNames( 911 llvm::function_ref<void(Value, StringRef)> setNameFn) { 912 SmallString<4> buffer; 913 llvm::raw_svector_ostream os(buffer); 914 os << "c" << value(); 915 setNameFn(getResult(), os.str()); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // ConstWitnessOp 920 //===----------------------------------------------------------------------===// 921 922 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 923 924 //===----------------------------------------------------------------------===// 925 // CstrRequireOp 926 //===----------------------------------------------------------------------===// 927 928 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 929 return operands[0]; 930 } 931 932 //===----------------------------------------------------------------------===// 933 // DivOp 934 //===----------------------------------------------------------------------===// 935 936 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 937 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 938 if (!lhs) 939 return nullptr; 940 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 941 if (!rhs) 942 return nullptr; 943 944 // Division in APInt does not follow floor(lhs, rhs) when the result is 945 // negative. Rather, APInt rounds toward zero. 946 APInt quotient, remainder; 947 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 948 if (quotient.isNegative() && !remainder.isNullValue()) { 949 quotient -= 1; 950 } 951 952 Type indexTy = IndexType::get(getContext()); 953 return IntegerAttr::get(indexTy, quotient); 954 } 955 956 //===----------------------------------------------------------------------===// 957 // ShapeEqOp 958 //===----------------------------------------------------------------------===// 959 960 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 961 bool allSame = true; 962 if (!operands.empty() && !operands[0]) 963 return {}; 964 for (Attribute operand : operands.drop_front(1)) { 965 if (!operand) 966 return {}; 967 allSame = allSame && operand == operands[0]; 968 } 969 return BoolAttr::get(getContext(), allSame); 970 } 971 972 //===----------------------------------------------------------------------===// 973 // IndexToSizeOp 974 //===----------------------------------------------------------------------===// 975 976 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 977 // Constant values of both types, `shape.size` and `index`, are represented as 978 // `IntegerAttr`s which makes constant folding simple. 979 if (Attribute arg = operands[0]) 980 return arg; 981 return {}; 982 } 983 984 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 985 MLIRContext *context) { 986 patterns.add<SizeToIndexToSizeCanonicalization>(context); 987 } 988 989 //===----------------------------------------------------------------------===// 990 // FromExtentsOp 991 //===----------------------------------------------------------------------===// 992 993 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 994 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 995 return nullptr; 996 SmallVector<int64_t, 6> extents; 997 for (auto attr : operands) 998 extents.push_back(attr.cast<IntegerAttr>().getInt()); 999 Builder builder(getContext()); 1000 return builder.getIndexTensorAttr(extents); 1001 } 1002 1003 //===----------------------------------------------------------------------===// 1004 // FunctionLibraryOp 1005 //===----------------------------------------------------------------------===// 1006 1007 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1008 StringRef name) { 1009 result.attributes.push_back(builder.getNamedAttr( 1010 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1011 } 1012 1013 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1014 auto attr = mapping() 1015 .get(op->getName().getIdentifier()) 1016 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1017 if (!attr) 1018 return nullptr; 1019 return lookupSymbol<FuncOp>(attr); 1020 } 1021 1022 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1023 OperationState &result) { 1024 // Parse the op name. 1025 StringAttr nameAttr; 1026 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1027 result.attributes)) 1028 return failure(); 1029 1030 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1031 return failure(); 1032 1033 auto *bodyRegion = result.addRegion(); 1034 if (parser.parseRegion(*bodyRegion)) 1035 return failure(); 1036 1037 if (parser.parseKeyword("mapping")) 1038 return failure(); 1039 1040 DictionaryAttr mappingAttr; 1041 if (parser.parseAttribute(mappingAttr, 1042 parser.getBuilder().getType<NoneType>(), "mapping", 1043 result.attributes)) 1044 return failure(); 1045 return success(); 1046 } 1047 1048 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1049 p << op.getOperationName() << ' '; 1050 p.printSymbolName(op.getName()); 1051 p.printOptionalAttrDictWithKeyword( 1052 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1053 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1054 /*printBlockTerminators=*/false); 1055 p << " mapping "; 1056 p.printAttributeWithoutType(op.mappingAttr()); 1057 } 1058 1059 //===----------------------------------------------------------------------===// 1060 // GetExtentOp 1061 //===----------------------------------------------------------------------===// 1062 1063 Optional<int64_t> GetExtentOp::getConstantDim() { 1064 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1065 return constSizeOp.value().getLimitedValue(); 1066 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 1067 return constantOp.value().cast<IntegerAttr>().getInt(); 1068 return llvm::None; 1069 } 1070 1071 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1072 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1073 if (!elements) 1074 return nullptr; 1075 Optional<int64_t> dim = getConstantDim(); 1076 if (!dim.hasValue()) 1077 return nullptr; 1078 if (dim.getValue() >= elements.getNumElements()) 1079 return nullptr; 1080 return elements.getValue({(uint64_t)dim.getValue()}); 1081 } 1082 1083 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1084 int64_t dim) { 1085 auto loc = result.location; 1086 auto dimAttr = builder.getIndexAttr(dim); 1087 if (shape.getType().isa<ShapeType>()) { 1088 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1089 build(builder, result, builder.getType<SizeType>(), shape, dim); 1090 } else { 1091 Value dim = 1092 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 1093 build(builder, result, builder.getIndexType(), shape, dim); 1094 } 1095 } 1096 1097 //===----------------------------------------------------------------------===// 1098 // IsBroadcastableOp 1099 //===----------------------------------------------------------------------===// 1100 1101 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1102 MLIRContext *context) { 1103 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1104 } 1105 1106 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1107 // Can always broadcast fewer than two shapes. 1108 if (operands.size() < 2) { 1109 return BoolAttr::get(getContext(), true); 1110 } 1111 1112 return nullptr; 1113 } 1114 1115 //===----------------------------------------------------------------------===// 1116 // RankOp 1117 //===----------------------------------------------------------------------===// 1118 1119 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1120 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1121 if (!shape) 1122 return {}; 1123 int64_t rank = shape.getNumElements(); 1124 Builder builder(getContext()); 1125 return builder.getIndexAttr(rank); 1126 } 1127 1128 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1129 /// Constant folding fails in cases where only the rank is constant, not the 1130 /// shape itself. 1131 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1132 /// 1133 /// Example: 1134 /// 1135 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1136 /// %rank = shape.rank %shape 1137 /// 1138 /// becomes 1139 /// 1140 /// %rank = shape.const_size 3 1141 1142 namespace { 1143 struct RankShapeOfCanonicalizationPattern 1144 : public OpRewritePattern<shape::RankOp> { 1145 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1146 1147 LogicalResult matchAndRewrite(shape::RankOp op, 1148 PatternRewriter &rewriter) const override { 1149 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 1150 if (!shapeOfOp) 1151 return failure(); 1152 auto rankedTensorType = 1153 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1154 if (!rankedTensorType) 1155 return failure(); 1156 int64_t rank = rankedTensorType.getRank(); 1157 if (op.getType().isa<IndexType>()) { 1158 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 1159 } else if (op.getType().isa<shape::SizeType>()) { 1160 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1161 } else { 1162 return failure(); 1163 } 1164 return success(); 1165 } 1166 }; 1167 } // namespace 1168 1169 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1170 MLIRContext *context) { 1171 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1172 } 1173 1174 //===----------------------------------------------------------------------===// 1175 // NumElementsOp 1176 //===----------------------------------------------------------------------===// 1177 1178 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1179 1180 // Fold only when argument constant. 1181 Attribute shape = operands[0]; 1182 if (!shape) 1183 return {}; 1184 1185 APInt product(64, 1); 1186 for (auto value : shape.cast<DenseIntElementsAttr>()) 1187 product *= value; 1188 Builder builder(getContext()); 1189 return builder.getIndexAttr(product.getLimitedValue()); 1190 } 1191 1192 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 1193 Value shape) { 1194 if (shape.getType().isa<ShapedType>()) { 1195 auto type = builder.getIndexType(); 1196 return build(builder, result, type, shape); 1197 } 1198 auto type = SizeType::get(builder.getContext()); 1199 return build(builder, result, type, shape); 1200 } 1201 1202 //===----------------------------------------------------------------------===// 1203 // MaxOp 1204 //===----------------------------------------------------------------------===// 1205 1206 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1207 // If operands are equal, just propagate one. 1208 if (lhs() == rhs()) 1209 return lhs(); 1210 return nullptr; 1211 } 1212 1213 //===----------------------------------------------------------------------===// 1214 // MinOp 1215 //===----------------------------------------------------------------------===// 1216 1217 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1218 // If operands are equal, just propagate one. 1219 if (lhs() == rhs()) 1220 return lhs(); 1221 return nullptr; 1222 } 1223 1224 //===----------------------------------------------------------------------===// 1225 // MulOp 1226 //===----------------------------------------------------------------------===// 1227 1228 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1229 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1230 if (!lhs) 1231 return nullptr; 1232 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1233 if (!rhs) 1234 return nullptr; 1235 APInt folded = lhs.getValue() * rhs.getValue(); 1236 Type indexTy = IndexType::get(getContext()); 1237 return IntegerAttr::get(indexTy, folded); 1238 } 1239 1240 //===----------------------------------------------------------------------===// 1241 // ShapeOfOp 1242 //===----------------------------------------------------------------------===// 1243 1244 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1245 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1246 if (!type || !type.hasStaticShape()) 1247 return nullptr; 1248 Builder builder(getContext()); 1249 return builder.getIndexTensorAttr(type.getShape()); 1250 } 1251 1252 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 1253 if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) { 1254 int64_t rank = 1255 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1256 Type indexTy = builder.getIndexType(); 1257 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1258 return ShapeOfOp::build(builder, result, extentTensorTy, arg); 1259 } 1260 Type shapeTy = builder.getType<ShapeType>(); 1261 return ShapeOfOp::build(builder, result, shapeTy, arg); 1262 } 1263 1264 namespace { 1265 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1266 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1267 1268 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1269 PatternRewriter &rewriter) const override { 1270 if (!op.arg().getType().isa<ShapedType>()) 1271 return failure(); 1272 if (op.getType().isa<ShapedType>()) 1273 return failure(); 1274 1275 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 1276 return success(); 1277 } 1278 }; 1279 1280 // Canonicalize 1281 // ``` 1282 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1283 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1284 // ``` 1285 // to 1286 // ``` 1287 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1288 // ``` 1289 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1290 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1291 1292 LogicalResult matchAndRewrite(tensor::CastOp op, 1293 PatternRewriter &rewriter) const override { 1294 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1295 if (!ty || ty.getRank() != 1) 1296 return failure(); 1297 1298 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1299 if (!shapeOfOp) 1300 return failure(); 1301 1302 // Argument type must be ranked and must not conflict. 1303 auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1304 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1305 return failure(); 1306 1307 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg()); 1308 return success(); 1309 } 1310 }; 1311 } // namespace 1312 1313 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1314 MLIRContext *context) { 1315 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context); 1316 } 1317 1318 //===----------------------------------------------------------------------===// 1319 // SizeToIndexOp 1320 //===----------------------------------------------------------------------===// 1321 1322 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1323 // Constant values of both types, `shape.size` and `index`, are represented as 1324 // `IntegerAttr`s which makes constant folding simple. 1325 if (Attribute arg = operands[0]) 1326 return arg; 1327 return impl::foldCastOp(*this); 1328 } 1329 1330 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1331 MLIRContext *context) { 1332 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1333 } 1334 1335 //===----------------------------------------------------------------------===// 1336 // YieldOp 1337 //===----------------------------------------------------------------------===// 1338 1339 static LogicalResult verify(shape::YieldOp op) { 1340 auto *parentOp = op->getParentOp(); 1341 auto results = parentOp->getResults(); 1342 auto operands = op.getOperands(); 1343 1344 if (parentOp->getNumResults() != op.getNumOperands()) 1345 return op.emitOpError() << "number of operands does not match number of " 1346 "results of its parent"; 1347 for (auto e : llvm::zip(results, operands)) 1348 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1349 return op.emitOpError() 1350 << "types mismatch between yield op and its parent"; 1351 1352 return success(); 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // SplitAtOp 1357 //===----------------------------------------------------------------------===// 1358 1359 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1360 SmallVectorImpl<OpFoldResult> &results) { 1361 if (!operands[0] || !operands[1]) 1362 return failure(); 1363 auto shapeVec = llvm::to_vector<6>( 1364 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1365 auto shape = llvm::makeArrayRef(shapeVec); 1366 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1367 // Verify that the split point is in the correct range. 1368 // TODO: Constant fold to an "error". 1369 int64_t rank = shape.size(); 1370 if (!(-rank <= splitPoint && splitPoint <= rank)) 1371 return failure(); 1372 if (splitPoint < 0) 1373 splitPoint += shape.size(); 1374 Builder builder(operands[0].getContext()); 1375 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1376 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1377 return success(); 1378 } 1379 1380 //===----------------------------------------------------------------------===// 1381 // ToExtentTensorOp 1382 //===----------------------------------------------------------------------===// 1383 1384 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1385 if (!operands[0]) 1386 return impl::foldCastOp(*this); 1387 Builder builder(getContext()); 1388 auto shape = llvm::to_vector<6>( 1389 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1390 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1391 builder.getIndexType()); 1392 return DenseIntElementsAttr::get(type, shape); 1393 } 1394 1395 //===----------------------------------------------------------------------===// 1396 // ReduceOp 1397 //===----------------------------------------------------------------------===// 1398 1399 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1400 ValueRange initVals) { 1401 result.addOperands(shape); 1402 result.addOperands(initVals); 1403 1404 Region *bodyRegion = result.addRegion(); 1405 bodyRegion->push_back(new Block); 1406 Block &bodyBlock = bodyRegion->front(); 1407 bodyBlock.addArgument(builder.getIndexType()); 1408 1409 Type elementType; 1410 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1411 elementType = tensorType.getElementType(); 1412 else 1413 elementType = SizeType::get(builder.getContext()); 1414 bodyBlock.addArgument(elementType); 1415 1416 for (Type initValType : initVals.getTypes()) { 1417 bodyBlock.addArgument(initValType); 1418 result.addTypes(initValType); 1419 } 1420 } 1421 1422 static LogicalResult verify(ReduceOp op) { 1423 // Verify block arg types. 1424 Block &block = op.region().front(); 1425 1426 // The block takes index, extent, and aggregated values as arguments. 1427 auto blockArgsCount = op.initVals().size() + 2; 1428 if (block.getNumArguments() != blockArgsCount) 1429 return op.emitOpError() << "ReduceOp body is expected to have " 1430 << blockArgsCount << " arguments"; 1431 1432 // The first block argument is the index and must always be of type `index`. 1433 if (!block.getArgument(0).getType().isa<IndexType>()) 1434 return op.emitOpError( 1435 "argument 0 of ReduceOp body is expected to be of IndexType"); 1436 1437 // The second block argument is the extent and must be of type `size` or 1438 // `index`, depending on whether the reduce operation is applied to a shape or 1439 // to an extent tensor. 1440 Type extentTy = block.getArgument(1).getType(); 1441 if (op.shape().getType().isa<ShapeType>()) { 1442 if (!extentTy.isa<SizeType>()) 1443 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1444 "SizeType if the ReduceOp operates on a ShapeType"); 1445 } else { 1446 if (!extentTy.isa<IndexType>()) 1447 return op.emitOpError( 1448 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1449 "ReduceOp operates on an extent tensor"); 1450 } 1451 1452 for (auto type : llvm::enumerate(op.initVals())) 1453 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1454 return op.emitOpError() 1455 << "type mismatch between argument " << type.index() + 2 1456 << " of ReduceOp body and initial value " << type.index(); 1457 return success(); 1458 } 1459 1460 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1461 // Parse operands. 1462 SmallVector<OpAsmParser::OperandType, 3> operands; 1463 Type shapeOrExtentTensorType; 1464 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1465 OpAsmParser::Delimiter::Paren) || 1466 parser.parseColonType(shapeOrExtentTensorType) || 1467 parser.parseOptionalArrowTypeList(result.types)) 1468 return failure(); 1469 1470 // Resolve operands. 1471 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1472 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1473 result.operands) || 1474 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1475 result.operands)) 1476 return failure(); 1477 1478 // Parse the body. 1479 Region *body = result.addRegion(); 1480 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1481 return failure(); 1482 1483 // Parse attributes. 1484 if (parser.parseOptionalAttrDict(result.attributes)) 1485 return failure(); 1486 1487 return success(); 1488 } 1489 1490 static void print(OpAsmPrinter &p, ReduceOp op) { 1491 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 1492 << ") : " << op.shape().getType(); 1493 p.printOptionalArrowTypeList(op.getResultTypes()); 1494 p.printRegion(op.region()); 1495 p.printOptionalAttrDict(op->getAttrs()); 1496 } 1497 1498 #define GET_OP_CLASSES 1499 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1500