1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/InliningUtils.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 using namespace mlir; 25 using namespace mlir::shape; 26 27 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 28 29 namespace { 30 #include "ShapeCanonicalization.inc" 31 } 32 33 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 34 return RankedTensorType::get({rank}, IndexType::get(ctx)); 35 } 36 37 bool shape::isExtentTensorType(Type type) { 38 auto ranked = type.dyn_cast<RankedTensorType>(); 39 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 40 } 41 42 LogicalResult shape::getShapeVec(Value input, 43 SmallVectorImpl<int64_t> &shapeValues) { 44 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 45 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 46 if (!type.hasRank()) 47 return failure(); 48 shapeValues = llvm::to_vector<6>(type.getShape()); 49 return success(); 50 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 51 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 52 return success(); 53 } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) { 54 shapeValues = llvm::to_vector<6>( 55 inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 56 return success(); 57 } else { 58 return failure(); 59 } 60 } 61 62 static bool isErrorPropagationPossible(TypeRange operandTypes) { 63 return llvm::any_of(operandTypes, [](Type ty) { 64 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 65 }); 66 } 67 68 static LogicalResult verifySizeOrIndexOp(Operation *op) { 69 assert(op != nullptr && op->getNumResults() == 1); 70 Type resultTy = op->getResultTypes().front(); 71 if (isErrorPropagationPossible(op->getOperandTypes())) { 72 if (!resultTy.isa<SizeType>()) 73 return op->emitOpError() 74 << "if at least one of the operands can hold error values then " 75 "the result must be of type `size` to propagate them"; 76 } 77 return success(); 78 } 79 80 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 81 assert(op != nullptr && op->getNumResults() == 1); 82 Type resultTy = op->getResultTypes().front(); 83 if (isErrorPropagationPossible(op->getOperandTypes())) { 84 if (!resultTy.isa<ShapeType>()) 85 return op->emitOpError() 86 << "if at least one of the operands can hold error values then " 87 "the result must be of type `shape` to propagate them"; 88 } 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // InlinerInterface 94 //===----------------------------------------------------------------------===// 95 96 namespace { 97 /// This class defines the interface for inlining shape dialect ops. 98 struct ShapeInlinerInterface : public DialectInlinerInterface { 99 using DialectInlinerInterface::DialectInlinerInterface; 100 101 // Returns true if the given region 'src' can be inlined into the region 102 // 'dest' that is attached to an operation registered to the current dialect. 103 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 104 BlockAndValueMapping &) const final { 105 return true; 106 } 107 108 // Returns true if the given operation 'op', that is registered to this 109 // dialect, can be inlined into the region 'dest' that is attached to an 110 // operation registered to the current dialect. 111 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 112 BlockAndValueMapping &) const final { 113 return true; 114 } 115 }; 116 } // namespace 117 118 void ShapeDialect::initialize() { 119 addOperations< 120 #define GET_OP_LIST 121 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 122 >(); 123 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 124 addInterfaces<ShapeInlinerInterface>(); 125 // Allow unknown operations during prototyping and testing. As the dialect is 126 // still evolving it makes it simple to start with an unregistered ops and 127 // try different variants before actually defining the op. 128 allowUnknownOperations(); 129 } 130 131 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 132 Attribute value, Type type, 133 Location loc) { 134 if (type.isa<ShapeType>() || isExtentTensorType(type)) 135 return builder.create<ConstShapeOp>(loc, type, 136 value.cast<DenseIntElementsAttr>()); 137 if (type.isa<SizeType>()) 138 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 139 if (type.isa<WitnessType>()) 140 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 141 if (ConstantOp::isBuildableWith(value, type)) 142 return builder.create<ConstantOp>(loc, type, value); 143 return nullptr; 144 } 145 146 /// Parse a type registered to this dialect. 147 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 148 StringRef keyword; 149 if (parser.parseKeyword(&keyword)) 150 return Type(); 151 152 if (keyword == "shape") 153 return ShapeType::get(getContext()); 154 if (keyword == "size") 155 return SizeType::get(getContext()); 156 if (keyword == "value_shape") 157 return ValueShapeType::get(getContext()); 158 if (keyword == "witness") 159 return WitnessType::get(getContext()); 160 161 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 162 return Type(); 163 } 164 165 /// Print a type registered to this dialect. 166 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 167 TypeSwitch<Type>(type) 168 .Case<ShapeType>([&](Type) { os << "shape"; }) 169 .Case<SizeType>([&](Type) { os << "size"; }) 170 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 171 .Case<WitnessType>([&](Type) { os << "witness"; }) 172 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 173 } 174 175 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 176 NamedAttribute attribute) { 177 // Verify shape.lib attribute. 178 if (attribute.first == "shape.lib") { 179 if (!op->hasTrait<OpTrait::SymbolTable>()) 180 return op->emitError( 181 "shape.lib attribute may only be on op implementing SymbolTable"); 182 183 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 184 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 185 if (!symbol) 186 return op->emitError("shape function library ") 187 << symbolRef << " not found"; 188 return isa<shape::FunctionLibraryOp>(symbol) 189 ? success() 190 : op->emitError() 191 << symbolRef << " required to be shape function library"; 192 } 193 194 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 195 // Verify all entries are function libraries and mappings in libraries 196 // refer to unique ops. 197 DenseSet<Identifier> key; 198 for (auto it : arr) { 199 if (!it.isa<SymbolRefAttr>()) 200 return op->emitError( 201 "only SymbolRefAttr allowed in shape.lib attribute array"); 202 203 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 204 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 205 if (!shapeFnLib) 206 return op->emitError() 207 << it << " does not refer to FunctionLibraryOp"; 208 for (auto mapping : shapeFnLib.mapping()) { 209 if (!key.insert(mapping.first).second) { 210 return op->emitError("only one op to shape mapping allowed, found " 211 "multiple for `") 212 << mapping.first << "`"; 213 } 214 } 215 } 216 return success(); 217 } 218 219 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 220 "allowed as shape.lib attribute"); 221 } 222 return success(); 223 } 224 225 //===----------------------------------------------------------------------===// 226 // AnyOp 227 //===----------------------------------------------------------------------===// 228 229 // TODO: Canonicalization should be implemented for shapes that can be 230 // determined through mixtures of the known dimensions of the inputs. 231 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 232 // Only the last operand is checked because AnyOp is commutative. 233 if (operands.back()) 234 return operands.back(); 235 236 return nullptr; 237 } 238 239 //===----------------------------------------------------------------------===// 240 // AssumingOp 241 //===----------------------------------------------------------------------===// 242 243 static ParseResult parseAssumingOp(OpAsmParser &parser, 244 OperationState &result) { 245 result.regions.reserve(1); 246 Region *doRegion = result.addRegion(); 247 248 auto &builder = parser.getBuilder(); 249 OpAsmParser::OperandType cond; 250 if (parser.parseOperand(cond) || 251 parser.resolveOperand(cond, builder.getType<WitnessType>(), 252 result.operands)) 253 return failure(); 254 255 // Parse optional results type list. 256 if (parser.parseOptionalArrowTypeList(result.types)) 257 return failure(); 258 259 // Parse the region and add a terminator if elided. 260 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 261 return failure(); 262 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 263 264 // Parse the optional attribute list. 265 if (parser.parseOptionalAttrDict(result.attributes)) 266 return failure(); 267 return success(); 268 } 269 270 static void print(OpAsmPrinter &p, AssumingOp op) { 271 bool yieldsResults = !op.results().empty(); 272 273 p << AssumingOp::getOperationName() << " " << op.witness(); 274 if (yieldsResults) { 275 p << " -> (" << op.getResultTypes() << ")"; 276 } 277 p.printRegion(op.doRegion(), 278 /*printEntryBlockArgs=*/false, 279 /*printBlockTerminators=*/yieldsResults); 280 p.printOptionalAttrDict(op->getAttrs()); 281 } 282 283 namespace { 284 // Removes AssumingOp with a passing witness and inlines the region. 285 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 286 using OpRewritePattern<AssumingOp>::OpRewritePattern; 287 288 LogicalResult matchAndRewrite(AssumingOp op, 289 PatternRewriter &rewriter) const override { 290 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 291 if (!witness || !witness.passingAttr()) 292 return failure(); 293 294 AssumingOp::inlineRegionIntoParent(op, rewriter); 295 return success(); 296 } 297 }; 298 299 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 300 using OpRewritePattern<AssumingOp>::OpRewritePattern; 301 302 LogicalResult matchAndRewrite(AssumingOp op, 303 PatternRewriter &rewriter) const override { 304 Block *body = op.getBody(); 305 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 306 307 // Find used values. 308 SmallVector<Value, 4> newYieldOperands; 309 Value opResult, yieldOperand; 310 for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { 311 std::tie(opResult, yieldOperand) = it; 312 if (!opResult.getUses().empty()) { 313 newYieldOperands.push_back(yieldOperand); 314 } 315 } 316 317 // Rewrite only if redundant results exist. 318 if (newYieldOperands.size() == yieldOp->getNumOperands()) 319 return failure(); 320 321 // Replace yield op in the old assuming op's body and move the entire region 322 // to the new assuming op. 323 rewriter.setInsertionPointToEnd(body); 324 auto newYieldOp = 325 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 326 rewriter.setInsertionPoint(op); 327 auto newOp = rewriter.create<AssumingOp>( 328 op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); 329 newOp.doRegion().takeBody(op.doRegion()); 330 331 // Use the new results to replace the previously used ones. 332 SmallVector<Value, 4> replacementValues; 333 auto src = newOp.getResults().begin(); 334 for (auto it : op.getResults()) { 335 if (it.getUses().empty()) 336 replacementValues.push_back(nullptr); 337 else 338 replacementValues.push_back(*src++); 339 } 340 rewriter.replaceOp(op, replacementValues); 341 return success(); 342 } 343 }; 344 } // namespace 345 346 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 347 MLIRContext *context) { 348 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 349 } 350 351 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 352 void AssumingOp::getSuccessorRegions( 353 Optional<unsigned> index, ArrayRef<Attribute> operands, 354 SmallVectorImpl<RegionSuccessor> ®ions) { 355 // AssumingOp has unconditional control flow into the region and back to the 356 // parent, so return the correct RegionSuccessor purely based on the index 357 // being None or 0. 358 if (index.hasValue()) { 359 regions.push_back(RegionSuccessor(getResults())); 360 return; 361 } 362 363 regions.push_back(RegionSuccessor(&doRegion())); 364 } 365 366 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 367 PatternRewriter &rewriter) { 368 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 369 auto *assumingBlock = op.getBody(); 370 auto initPosition = rewriter.getInsertionPoint(); 371 auto *blockAfterAssuming = 372 rewriter.splitBlock(blockBeforeAssuming, initPosition); 373 374 // Remove the AssumingOp and AssumingYieldOp. 375 auto &yieldOp = assumingBlock->back(); 376 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 377 rewriter.replaceOp(op, yieldOp.getOperands()); 378 rewriter.eraseOp(&yieldOp); 379 380 // Merge blocks together as there was no branching behavior from the 381 // AssumingOp. 382 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 383 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 384 } 385 386 void AssumingOp::build( 387 OpBuilder &builder, OperationState &result, Value witness, 388 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 389 390 result.addOperands(witness); 391 Region *bodyRegion = result.addRegion(); 392 bodyRegion->push_back(new Block); 393 Block &bodyBlock = bodyRegion->front(); 394 395 // Build body. 396 OpBuilder::InsertionGuard guard(builder); 397 builder.setInsertionPointToStart(&bodyBlock); 398 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 399 builder.create<AssumingYieldOp>(result.location, yieldValues); 400 401 SmallVector<Type, 2> assumingTypes; 402 for (Value v : yieldValues) 403 assumingTypes.push_back(v.getType()); 404 result.addTypes(assumingTypes); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // AssumingAllOp 409 //===----------------------------------------------------------------------===// 410 411 namespace { 412 struct AssumingAllToCstrEqCanonicalization 413 : public OpRewritePattern<AssumingAllOp> { 414 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 415 416 LogicalResult matchAndRewrite(AssumingAllOp op, 417 PatternRewriter &rewriter) const override { 418 SmallVector<Value, 8> shapes; 419 for (Value w : op.inputs()) { 420 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 421 if (!cstrEqOp) 422 return failure(); 423 bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { 424 return llvm::is_contained(shapes, s); 425 }); 426 if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) 427 return failure(); 428 shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); 429 } 430 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 431 return success(); 432 } 433 }; 434 435 template <typename OpTy> 436 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 437 using OpRewritePattern<OpTy>::OpRewritePattern; 438 439 LogicalResult matchAndRewrite(OpTy op, 440 PatternRewriter &rewriter) const override { 441 // Find unique operands. 442 SmallVector<Value, 2> unique; 443 for (Value v : op.getOperands()) { 444 if (!llvm::is_contained(unique, v)) 445 unique.push_back(v); 446 } 447 448 // Reduce op to equivalent with unique operands. 449 if (unique.size() < op.getNumOperands()) { 450 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 451 op->getAttrs()); 452 return success(); 453 } 454 455 return failure(); 456 } 457 }; 458 } // namespace 459 460 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 461 MLIRContext *context) { 462 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 463 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 464 } 465 466 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 467 // Iterate in reverse to first handle all constant operands. They are 468 // guaranteed to be the tail of the inputs because this is commutative. 469 for (int idx = operands.size() - 1; idx >= 0; idx--) { 470 Attribute a = operands[idx]; 471 // Cannot fold if any inputs are not constant; 472 if (!a) 473 return nullptr; 474 475 // We do not need to keep statically known values after handling them in 476 // this method. 477 getOperation()->eraseOperand(idx); 478 479 // Always false if any input is statically known false 480 if (!a.cast<BoolAttr>().getValue()) 481 return a; 482 } 483 // If this is reached, all inputs were statically known passing. 484 return BoolAttr::get(getContext(), true); 485 } 486 487 static LogicalResult verify(AssumingAllOp op) { 488 // Ensure that AssumingAllOp contains at least one operand 489 if (op.getNumOperands() == 0) 490 return op.emitOpError("no operands specified"); 491 492 return success(); 493 } 494 495 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 496 ValueRange inputs) { 497 build(b, state, b.getType<WitnessType>(), inputs); 498 } 499 500 //===----------------------------------------------------------------------===// 501 // BroadcastOp 502 //===----------------------------------------------------------------------===// 503 504 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 505 if (shapes().size() == 1) { 506 // Otherwise, we need a cast which would be a canonicalization, not folding. 507 if (shapes().front().getType() != getType()) 508 return nullptr; 509 return shapes().front(); 510 } 511 512 // TODO: Support folding with more than 2 input shapes 513 if (shapes().size() > 2) 514 return nullptr; 515 516 if (!operands[0] || !operands[1]) 517 return nullptr; 518 auto lhsShape = llvm::to_vector<6>( 519 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 520 auto rhsShape = llvm::to_vector<6>( 521 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 522 SmallVector<int64_t, 6> resultShape; 523 524 // If the shapes are not compatible, we can't fold it. 525 // TODO: Fold to an "error". 526 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 527 return nullptr; 528 529 Builder builder(getContext()); 530 return builder.getIndexTensorAttr(resultShape); 531 } 532 533 static LogicalResult verify(BroadcastOp op) { 534 return verifyShapeOrExtentTensorOp(op); 535 } 536 537 namespace { 538 template <typename OpTy> 539 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 540 using OpRewritePattern<OpTy>::OpRewritePattern; 541 542 LogicalResult matchAndRewrite(OpTy op, 543 PatternRewriter &rewriter) const override { 544 auto isPotentiallyNonEmptyShape = [](Value shape) { 545 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 546 if (extentTensorTy.getDimSize(0) == 0) 547 return false; 548 } 549 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 550 if (constShape.shape().empty()) 551 return false; 552 } 553 return true; 554 }; 555 auto newOperands = llvm::to_vector<8>( 556 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 557 558 // Reduce op to equivalent without empty shape operands. 559 if (newOperands.size() < op.getNumOperands()) { 560 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 561 op->getAttrs()); 562 return success(); 563 } 564 565 return failure(); 566 } 567 }; 568 569 struct BroadcastForwardSingleOperandPattern 570 : public OpRewritePattern<BroadcastOp> { 571 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 572 573 LogicalResult matchAndRewrite(BroadcastOp op, 574 PatternRewriter &rewriter) const override { 575 if (op.getNumOperands() != 1) 576 return failure(); 577 Value replacement = op.shapes().front(); 578 579 // Insert cast if needed. 580 if (replacement.getType() != op.getType()) { 581 auto loc = op.getLoc(); 582 if (op.getType().isa<ShapeType>()) { 583 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 584 } else { 585 assert(!op.getType().isa<ShapeType>() && 586 !replacement.getType().isa<ShapeType>() && 587 "expect extent tensor cast"); 588 replacement = 589 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 590 } 591 } 592 593 rewriter.replaceOp(op, replacement); 594 return success(); 595 } 596 }; 597 598 struct BroadcastFoldConstantOperandsPattern 599 : public OpRewritePattern<BroadcastOp> { 600 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 601 602 LogicalResult matchAndRewrite(BroadcastOp op, 603 PatternRewriter &rewriter) const override { 604 SmallVector<int64_t, 8> foldedConstantShape; 605 SmallVector<Value, 8> newShapeOperands; 606 for (Value shape : op.shapes()) { 607 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 608 SmallVector<int64_t, 8> newFoldedConstantShape; 609 if (OpTrait::util::getBroadcastedShape( 610 foldedConstantShape, 611 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()), 612 newFoldedConstantShape)) { 613 foldedConstantShape = newFoldedConstantShape; 614 continue; 615 } 616 } 617 newShapeOperands.push_back(shape); 618 } 619 620 // Need at least two constant operands to fold anything. 621 if (op.getNumOperands() - newShapeOperands.size() < 2) 622 return failure(); 623 624 auto foldedConstantOperandsTy = RankedTensorType::get( 625 {static_cast<int64_t>(foldedConstantShape.size())}, 626 rewriter.getIndexType()); 627 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 628 op.getLoc(), foldedConstantOperandsTy, 629 rewriter.getIndexTensorAttr(foldedConstantShape))); 630 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 631 newShapeOperands); 632 return success(); 633 } 634 }; 635 636 template <typename OpTy> 637 struct CanonicalizeCastExtentTensorOperandsPattern 638 : public OpRewritePattern<OpTy> { 639 using OpRewritePattern<OpTy>::OpRewritePattern; 640 641 LogicalResult matchAndRewrite(OpTy op, 642 PatternRewriter &rewriter) const override { 643 // Canonicalize operands. 644 bool anyChange = false; 645 auto canonicalizeOperand = [&](Value operand) { 646 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 647 // Only eliminate the cast if it holds no shape information. 648 bool isInformationLoosingCast = 649 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 650 if (isInformationLoosingCast) { 651 anyChange = true; 652 return castOp.source(); 653 } 654 } 655 return operand; 656 }; 657 auto newOperands = llvm::to_vector<8>( 658 llvm::map_range(op.getOperands(), canonicalizeOperand)); 659 660 // Rewrite op if any change required. 661 if (!anyChange) 662 return failure(); 663 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 664 return success(); 665 } 666 }; 667 668 struct BroadcastConcretizeResultTypePattern 669 : public OpRewritePattern<BroadcastOp> { 670 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 671 672 LogicalResult matchAndRewrite(BroadcastOp op, 673 PatternRewriter &rewriter) const override { 674 // Only concretize dynamic extent tensor result types. 675 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 676 if (!resultTy || !resultTy.isDynamicDim(0)) 677 return failure(); 678 679 // Infer resulting shape rank if possible. 680 int64_t maxRank = 0; 681 for (Value shape : op.shapes()) { 682 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 683 // Cannot infer resulting shape rank if any operand is dynamically 684 // ranked. 685 if (extentTensorTy.isDynamicDim(0)) 686 return failure(); 687 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 688 } 689 } 690 691 auto newOp = rewriter.create<BroadcastOp>( 692 op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); 693 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 694 return success(); 695 } 696 }; 697 } // namespace 698 699 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 700 MLIRContext *context) { 701 patterns.add<BroadcastConcretizeResultTypePattern, 702 BroadcastFoldConstantOperandsPattern, 703 BroadcastForwardSingleOperandPattern, 704 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 705 RemoveDuplicateOperandsPattern<BroadcastOp>, 706 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 707 } 708 709 //===----------------------------------------------------------------------===// 710 // ConcatOp 711 //===----------------------------------------------------------------------===// 712 713 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 714 if (!operands[0] || !operands[1]) 715 return nullptr; 716 auto lhsShape = llvm::to_vector<6>( 717 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 718 auto rhsShape = llvm::to_vector<6>( 719 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 720 SmallVector<int64_t, 6> resultShape; 721 resultShape.append(lhsShape.begin(), lhsShape.end()); 722 resultShape.append(rhsShape.begin(), rhsShape.end()); 723 Builder builder(getContext()); 724 return builder.getIndexTensorAttr(resultShape); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // ConstShapeOp 729 //===----------------------------------------------------------------------===// 730 731 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 732 p << "shape.const_shape "; 733 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 734 p << "["; 735 interleaveComma(op.shape().getValues<int64_t>(), p, 736 [&](int64_t i) { p << i; }); 737 p << "] : "; 738 p.printType(op.getType()); 739 } 740 741 static ParseResult parseConstShapeOp(OpAsmParser &parser, 742 OperationState &result) { 743 if (parser.parseOptionalAttrDict(result.attributes)) 744 return failure(); 745 // We piggy-back on ArrayAttr parsing, though we don't internally store the 746 // shape as an ArrayAttr. 747 // TODO: Implement custom parser and maybe make syntax a bit more concise. 748 Attribute extentsRaw; 749 NamedAttrList dummy; 750 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 751 return failure(); 752 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 753 if (!extentsArray) 754 return failure(); 755 SmallVector<int64_t, 6> ints; 756 for (Attribute extent : extentsArray) { 757 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 758 if (!attr) 759 return failure(); 760 ints.push_back(attr.getInt()); 761 } 762 Builder &builder = parser.getBuilder(); 763 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 764 Type resultTy; 765 if (parser.parseColonType(resultTy)) 766 return failure(); 767 result.types.push_back(resultTy); 768 return success(); 769 } 770 771 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 772 773 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 774 MLIRContext *context) { 775 patterns.add<TensorCastConstShape>(context); 776 } 777 778 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 779 MLIRContext *context, Optional<Location> location, ValueRange operands, 780 DictionaryAttr attributes, RegionRange regions, 781 SmallVectorImpl<Type> &inferredReturnTypes) { 782 Builder b(context); 783 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 784 if (!shape) 785 return emitOptionalError(location, "missing shape attribute"); 786 inferredReturnTypes.assign({RankedTensorType::get( 787 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 788 return success(); 789 } 790 791 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 792 TypeRange r) { 793 if (l.size() != 1 || r.size() != 1) 794 return false; 795 796 Type lhs = l.front(); 797 Type rhs = r.front(); 798 799 if (lhs == rhs) 800 return true; 801 802 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 803 // Shape type is compatible with all other valid return types. 804 return true; 805 806 return succeeded(verifyCompatibleShapes(lhs, rhs)); 807 } 808 809 //===----------------------------------------------------------------------===// 810 // CstrBroadcastableOp 811 //===----------------------------------------------------------------------===// 812 813 void CstrBroadcastableOp::getCanonicalizationPatterns( 814 RewritePatternSet &patterns, MLIRContext *context) { 815 // Canonicalization patterns have overlap with the considerations during 816 // folding in case additional shape information is inferred at some point that 817 // does not result in folding. 818 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 819 CstrBroadcastableEqOps, 820 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 821 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 822 } 823 824 // Return true if there is exactly one attribute not representing a scalar 825 // broadcast. 826 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 827 bool nonScalarSeen = false; 828 for (Attribute a : attributes) { 829 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 830 if (nonScalarSeen) 831 return false; 832 nonScalarSeen = true; 833 } 834 } 835 return true; 836 } 837 838 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 839 // No broadcasting is needed if all operands but one are scalar. 840 if (hasAtMostSingleNonScalar(operands)) 841 return BoolAttr::get(getContext(), true); 842 843 if ([&] { 844 SmallVector<SmallVector<int64_t, 6>, 6> extents; 845 for (const auto &operand : operands) { 846 if (!operand) 847 return false; 848 extents.push_back(llvm::to_vector<6>( 849 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 850 } 851 return OpTrait::util::staticallyKnownBroadcastable(extents); 852 }()) 853 return BoolAttr::get(getContext(), true); 854 855 // Lastly, see if folding can be completed based on what constraints are known 856 // on the input shapes. 857 if ([&] { 858 SmallVector<SmallVector<int64_t, 6>, 6> extents; 859 for (auto shapeValue : shapes()) { 860 extents.emplace_back(); 861 if (failed(getShapeVec(shapeValue, extents.back()))) 862 return false; 863 } 864 return OpTrait::util::staticallyKnownBroadcastable(extents); 865 }()) 866 return BoolAttr::get(getContext(), true); 867 868 // Because a failing witness result here represents an eventual assertion 869 // failure, we do not replace it with a constant witness. 870 return nullptr; 871 } 872 873 static LogicalResult verify(CstrBroadcastableOp op) { 874 // Ensure that AssumingAllOp contains at least one operand 875 if (op.getNumOperands() < 2) 876 return op.emitOpError("required at least 2 input shapes"); 877 return success(); 878 } 879 880 //===----------------------------------------------------------------------===// 881 // CstrEqOp 882 //===----------------------------------------------------------------------===// 883 884 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 885 MLIRContext *context) { 886 // If inputs are equal, return passing witness 887 patterns.add<CstrEqEqOps>(context); 888 } 889 890 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 891 if (llvm::all_of(operands, 892 [&](Attribute a) { return a && a == operands[0]; })) 893 return BoolAttr::get(getContext(), true); 894 895 // Because a failing witness result here represents an eventual assertion 896 // failure, we do not try to replace it with a constant witness. Similarly, we 897 // cannot if there are any non-const inputs. 898 return nullptr; 899 } 900 901 //===----------------------------------------------------------------------===// 902 // ConstSizeOp 903 //===----------------------------------------------------------------------===// 904 905 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 906 int64_t value) { 907 build(builder, result, builder.getIndexAttr(value)); 908 } 909 910 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 911 912 void ConstSizeOp::getAsmResultNames( 913 llvm::function_ref<void(Value, StringRef)> setNameFn) { 914 SmallString<4> buffer; 915 llvm::raw_svector_ostream os(buffer); 916 os << "c" << value(); 917 setNameFn(getResult(), os.str()); 918 } 919 920 //===----------------------------------------------------------------------===// 921 // ConstWitnessOp 922 //===----------------------------------------------------------------------===// 923 924 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 925 926 //===----------------------------------------------------------------------===// 927 // CstrRequireOp 928 //===----------------------------------------------------------------------===// 929 930 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 931 return operands[0]; 932 } 933 934 //===----------------------------------------------------------------------===// 935 // DivOp 936 //===----------------------------------------------------------------------===// 937 938 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 939 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 940 if (!lhs) 941 return nullptr; 942 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 943 if (!rhs) 944 return nullptr; 945 946 // Division in APInt does not follow floor(lhs, rhs) when the result is 947 // negative. Rather, APInt rounds toward zero. 948 APInt quotient, remainder; 949 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 950 if (quotient.isNegative() && !remainder.isNullValue()) { 951 quotient -= 1; 952 } 953 954 Type indexTy = IndexType::get(getContext()); 955 return IntegerAttr::get(indexTy, quotient); 956 } 957 958 //===----------------------------------------------------------------------===// 959 // ShapeEqOp 960 //===----------------------------------------------------------------------===// 961 962 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 963 bool allSame = true; 964 if (!operands.empty() && !operands[0]) 965 return {}; 966 for (Attribute operand : operands.drop_front(1)) { 967 if (!operand) 968 return {}; 969 allSame = allSame && operand == operands[0]; 970 } 971 return BoolAttr::get(getContext(), allSame); 972 } 973 974 //===----------------------------------------------------------------------===// 975 // IndexToSizeOp 976 //===----------------------------------------------------------------------===// 977 978 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 979 // Constant values of both types, `shape.size` and `index`, are represented as 980 // `IntegerAttr`s which makes constant folding simple. 981 if (Attribute arg = operands[0]) 982 return arg; 983 return {}; 984 } 985 986 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 987 MLIRContext *context) { 988 patterns.add<SizeToIndexToSizeCanonicalization>(context); 989 } 990 991 //===----------------------------------------------------------------------===// 992 // FromExtentsOp 993 //===----------------------------------------------------------------------===// 994 995 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 996 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 997 return nullptr; 998 SmallVector<int64_t, 6> extents; 999 for (auto attr : operands) 1000 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1001 Builder builder(getContext()); 1002 return builder.getIndexTensorAttr(extents); 1003 } 1004 1005 //===----------------------------------------------------------------------===// 1006 // FunctionLibraryOp 1007 //===----------------------------------------------------------------------===// 1008 1009 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1010 StringRef name) { 1011 result.attributes.push_back(builder.getNamedAttr( 1012 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1013 } 1014 1015 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1016 auto attr = mapping() 1017 .get(op->getName().getIdentifier()) 1018 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1019 if (!attr) 1020 return nullptr; 1021 return lookupSymbol<FuncOp>(attr); 1022 } 1023 1024 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1025 OperationState &result) { 1026 // Parse the op name. 1027 StringAttr nameAttr; 1028 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1029 result.attributes)) 1030 return failure(); 1031 1032 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1033 return failure(); 1034 1035 auto *bodyRegion = result.addRegion(); 1036 if (parser.parseRegion(*bodyRegion)) 1037 return failure(); 1038 1039 if (parser.parseKeyword("mapping")) 1040 return failure(); 1041 1042 DictionaryAttr mappingAttr; 1043 if (parser.parseAttribute(mappingAttr, 1044 parser.getBuilder().getType<NoneType>(), "mapping", 1045 result.attributes)) 1046 return failure(); 1047 return success(); 1048 } 1049 1050 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1051 p << op.getOperationName() << ' '; 1052 p.printSymbolName(op.getName()); 1053 p.printOptionalAttrDictWithKeyword( 1054 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1055 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1056 /*printBlockTerminators=*/false); 1057 p << " mapping "; 1058 p.printAttributeWithoutType(op.mappingAttr()); 1059 } 1060 1061 //===----------------------------------------------------------------------===// 1062 // GetExtentOp 1063 //===----------------------------------------------------------------------===// 1064 1065 Optional<int64_t> GetExtentOp::getConstantDim() { 1066 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1067 return constSizeOp.value().getLimitedValue(); 1068 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 1069 return constantOp.value().cast<IntegerAttr>().getInt(); 1070 return llvm::None; 1071 } 1072 1073 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1074 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1075 if (!elements) 1076 return nullptr; 1077 Optional<int64_t> dim = getConstantDim(); 1078 if (!dim.hasValue()) 1079 return nullptr; 1080 if (dim.getValue() >= elements.getNumElements()) 1081 return nullptr; 1082 return elements.getValue({(uint64_t)dim.getValue()}); 1083 } 1084 1085 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1086 int64_t dim) { 1087 auto loc = result.location; 1088 auto dimAttr = builder.getIndexAttr(dim); 1089 if (shape.getType().isa<ShapeType>()) { 1090 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1091 build(builder, result, builder.getType<SizeType>(), shape, dim); 1092 } else { 1093 Value dim = 1094 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 1095 build(builder, result, builder.getIndexType(), shape, dim); 1096 } 1097 } 1098 1099 //===----------------------------------------------------------------------===// 1100 // IsBroadcastableOp 1101 //===----------------------------------------------------------------------===// 1102 1103 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1104 MLIRContext *context) { 1105 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1106 } 1107 1108 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1109 // Can always broadcast fewer than two shapes. 1110 if (operands.size() < 2) { 1111 return BoolAttr::get(getContext(), true); 1112 } 1113 1114 return nullptr; 1115 } 1116 1117 //===----------------------------------------------------------------------===// 1118 // RankOp 1119 //===----------------------------------------------------------------------===// 1120 1121 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1122 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1123 if (!shape) 1124 return {}; 1125 int64_t rank = shape.getNumElements(); 1126 Builder builder(getContext()); 1127 return builder.getIndexAttr(rank); 1128 } 1129 1130 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1131 /// Constant folding fails in cases where only the rank is constant, not the 1132 /// shape itself. 1133 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1134 /// 1135 /// Example: 1136 /// 1137 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1138 /// %rank = shape.rank %shape 1139 /// 1140 /// becomes 1141 /// 1142 /// %rank = shape.const_size 3 1143 1144 namespace { 1145 struct RankShapeOfCanonicalizationPattern 1146 : public OpRewritePattern<shape::RankOp> { 1147 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1148 1149 LogicalResult matchAndRewrite(shape::RankOp op, 1150 PatternRewriter &rewriter) const override { 1151 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 1152 if (!shapeOfOp) 1153 return failure(); 1154 auto rankedTensorType = 1155 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1156 if (!rankedTensorType) 1157 return failure(); 1158 int64_t rank = rankedTensorType.getRank(); 1159 if (op.getType().isa<IndexType>()) { 1160 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 1161 } else if (op.getType().isa<shape::SizeType>()) { 1162 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1163 } else { 1164 return failure(); 1165 } 1166 return success(); 1167 } 1168 }; 1169 } // namespace 1170 1171 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1172 MLIRContext *context) { 1173 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1174 } 1175 1176 //===----------------------------------------------------------------------===// 1177 // NumElementsOp 1178 //===----------------------------------------------------------------------===// 1179 1180 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1181 1182 // Fold only when argument constant. 1183 Attribute shape = operands[0]; 1184 if (!shape) 1185 return {}; 1186 1187 APInt product(64, 1); 1188 for (auto value : shape.cast<DenseIntElementsAttr>()) 1189 product *= value; 1190 Builder builder(getContext()); 1191 return builder.getIndexAttr(product.getLimitedValue()); 1192 } 1193 1194 void NumElementsOp::build(OpBuilder &builder, OperationState &result, 1195 Value shape) { 1196 if (shape.getType().isa<ShapedType>()) { 1197 auto type = builder.getIndexType(); 1198 return build(builder, result, type, shape); 1199 } 1200 auto type = SizeType::get(builder.getContext()); 1201 return build(builder, result, type, shape); 1202 } 1203 1204 //===----------------------------------------------------------------------===// 1205 // MaxOp 1206 //===----------------------------------------------------------------------===// 1207 1208 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1209 // If operands are equal, just propagate one. 1210 if (lhs() == rhs()) 1211 return lhs(); 1212 return nullptr; 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // MinOp 1217 //===----------------------------------------------------------------------===// 1218 1219 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1220 // If operands are equal, just propagate one. 1221 if (lhs() == rhs()) 1222 return lhs(); 1223 return nullptr; 1224 } 1225 1226 //===----------------------------------------------------------------------===// 1227 // MulOp 1228 //===----------------------------------------------------------------------===// 1229 1230 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1231 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1232 if (!lhs) 1233 return nullptr; 1234 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1235 if (!rhs) 1236 return nullptr; 1237 APInt folded = lhs.getValue() * rhs.getValue(); 1238 Type indexTy = IndexType::get(getContext()); 1239 return IntegerAttr::get(indexTy, folded); 1240 } 1241 1242 //===----------------------------------------------------------------------===// 1243 // ShapeOfOp 1244 //===----------------------------------------------------------------------===// 1245 1246 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1247 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1248 if (!type || !type.hasStaticShape()) 1249 return nullptr; 1250 Builder builder(getContext()); 1251 return builder.getIndexTensorAttr(type.getShape()); 1252 } 1253 1254 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { 1255 if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) { 1256 int64_t rank = 1257 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1258 Type indexTy = builder.getIndexType(); 1259 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1260 return ShapeOfOp::build(builder, result, extentTensorTy, arg); 1261 } 1262 Type shapeTy = builder.getType<ShapeType>(); 1263 return ShapeOfOp::build(builder, result, shapeTy, arg); 1264 } 1265 1266 namespace { 1267 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1268 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1269 1270 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1271 PatternRewriter &rewriter) const override { 1272 if (!op.arg().getType().isa<ShapedType>()) 1273 return failure(); 1274 if (op.getType().isa<ShapedType>()) 1275 return failure(); 1276 1277 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 1278 return success(); 1279 } 1280 }; 1281 1282 // Canonicalize 1283 // ``` 1284 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1285 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1286 // ``` 1287 // to 1288 // ``` 1289 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1290 // ``` 1291 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1292 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1293 1294 LogicalResult matchAndRewrite(tensor::CastOp op, 1295 PatternRewriter &rewriter) const override { 1296 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1297 if (!ty || ty.getRank() != 1) 1298 return failure(); 1299 1300 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1301 if (!shapeOfOp) 1302 return failure(); 1303 1304 // Argument type must be ranked and must not conflict. 1305 auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1306 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1307 return failure(); 1308 1309 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg()); 1310 return success(); 1311 } 1312 }; 1313 } // namespace 1314 1315 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1316 MLIRContext *context) { 1317 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context); 1318 } 1319 1320 //===----------------------------------------------------------------------===// 1321 // SizeToIndexOp 1322 //===----------------------------------------------------------------------===// 1323 1324 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1325 // Constant values of both types, `shape.size` and `index`, are represented as 1326 // `IntegerAttr`s which makes constant folding simple. 1327 if (Attribute arg = operands[0]) 1328 return arg; 1329 return impl::foldCastOp(*this); 1330 } 1331 1332 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1333 MLIRContext *context) { 1334 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1335 } 1336 1337 //===----------------------------------------------------------------------===// 1338 // YieldOp 1339 //===----------------------------------------------------------------------===// 1340 1341 static LogicalResult verify(shape::YieldOp op) { 1342 auto *parentOp = op->getParentOp(); 1343 auto results = parentOp->getResults(); 1344 auto operands = op.getOperands(); 1345 1346 if (parentOp->getNumResults() != op.getNumOperands()) 1347 return op.emitOpError() << "number of operands does not match number of " 1348 "results of its parent"; 1349 for (auto e : llvm::zip(results, operands)) 1350 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1351 return op.emitOpError() 1352 << "types mismatch between yield op and its parent"; 1353 1354 return success(); 1355 } 1356 1357 //===----------------------------------------------------------------------===// 1358 // SplitAtOp 1359 //===----------------------------------------------------------------------===// 1360 1361 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1362 SmallVectorImpl<OpFoldResult> &results) { 1363 if (!operands[0] || !operands[1]) 1364 return failure(); 1365 auto shapeVec = llvm::to_vector<6>( 1366 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1367 auto shape = llvm::makeArrayRef(shapeVec); 1368 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1369 // Verify that the split point is in the correct range. 1370 // TODO: Constant fold to an "error". 1371 int64_t rank = shape.size(); 1372 if (!(-rank <= splitPoint && splitPoint <= rank)) 1373 return failure(); 1374 if (splitPoint < 0) 1375 splitPoint += shape.size(); 1376 Builder builder(operands[0].getContext()); 1377 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1378 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1379 return success(); 1380 } 1381 1382 //===----------------------------------------------------------------------===// 1383 // ToExtentTensorOp 1384 //===----------------------------------------------------------------------===// 1385 1386 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1387 if (!operands[0]) 1388 return impl::foldCastOp(*this); 1389 Builder builder(getContext()); 1390 auto shape = llvm::to_vector<6>( 1391 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1392 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1393 builder.getIndexType()); 1394 return DenseIntElementsAttr::get(type, shape); 1395 } 1396 1397 //===----------------------------------------------------------------------===// 1398 // ReduceOp 1399 //===----------------------------------------------------------------------===// 1400 1401 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1402 ValueRange initVals) { 1403 result.addOperands(shape); 1404 result.addOperands(initVals); 1405 1406 Region *bodyRegion = result.addRegion(); 1407 bodyRegion->push_back(new Block); 1408 Block &bodyBlock = bodyRegion->front(); 1409 bodyBlock.addArgument(builder.getIndexType()); 1410 1411 Type elementType; 1412 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1413 elementType = tensorType.getElementType(); 1414 else 1415 elementType = SizeType::get(builder.getContext()); 1416 bodyBlock.addArgument(elementType); 1417 1418 for (Type initValType : initVals.getTypes()) { 1419 bodyBlock.addArgument(initValType); 1420 result.addTypes(initValType); 1421 } 1422 } 1423 1424 static LogicalResult verify(ReduceOp op) { 1425 // Verify block arg types. 1426 Block &block = op.region().front(); 1427 1428 // The block takes index, extent, and aggregated values as arguments. 1429 auto blockArgsCount = op.initVals().size() + 2; 1430 if (block.getNumArguments() != blockArgsCount) 1431 return op.emitOpError() << "ReduceOp body is expected to have " 1432 << blockArgsCount << " arguments"; 1433 1434 // The first block argument is the index and must always be of type `index`. 1435 if (!block.getArgument(0).getType().isa<IndexType>()) 1436 return op.emitOpError( 1437 "argument 0 of ReduceOp body is expected to be of IndexType"); 1438 1439 // The second block argument is the extent and must be of type `size` or 1440 // `index`, depending on whether the reduce operation is applied to a shape or 1441 // to an extent tensor. 1442 Type extentTy = block.getArgument(1).getType(); 1443 if (op.shape().getType().isa<ShapeType>()) { 1444 if (!extentTy.isa<SizeType>()) 1445 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1446 "SizeType if the ReduceOp operates on a ShapeType"); 1447 } else { 1448 if (!extentTy.isa<IndexType>()) 1449 return op.emitOpError( 1450 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1451 "ReduceOp operates on an extent tensor"); 1452 } 1453 1454 for (auto type : llvm::enumerate(op.initVals())) 1455 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1456 return op.emitOpError() 1457 << "type mismatch between argument " << type.index() + 2 1458 << " of ReduceOp body and initial value " << type.index(); 1459 return success(); 1460 } 1461 1462 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1463 // Parse operands. 1464 SmallVector<OpAsmParser::OperandType, 3> operands; 1465 Type shapeOrExtentTensorType; 1466 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1467 OpAsmParser::Delimiter::Paren) || 1468 parser.parseColonType(shapeOrExtentTensorType) || 1469 parser.parseOptionalArrowTypeList(result.types)) 1470 return failure(); 1471 1472 // Resolve operands. 1473 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1474 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1475 result.operands) || 1476 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1477 result.operands)) 1478 return failure(); 1479 1480 // Parse the body. 1481 Region *body = result.addRegion(); 1482 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1483 return failure(); 1484 1485 // Parse attributes. 1486 if (parser.parseOptionalAttrDict(result.attributes)) 1487 return failure(); 1488 1489 return success(); 1490 } 1491 1492 static void print(OpAsmPrinter &p, ReduceOp op) { 1493 p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() 1494 << ") : " << op.shape().getType(); 1495 p.printOptionalArrowTypeList(op.getResultTypes()); 1496 p.printRegion(op.region()); 1497 p.printOptionalAttrDict(op->getAttrs()); 1498 } 1499 1500 #define GET_OP_CLASSES 1501 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1502