1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/Dialect/Traits.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Transforms/InliningUtils.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 using namespace mlir; 25 using namespace mlir::shape; 26 27 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 28 29 namespace { 30 #include "ShapeCanonicalization.inc" 31 } 32 33 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 34 return RankedTensorType::get({rank}, IndexType::get(ctx)); 35 } 36 37 bool shape::isExtentTensorType(Type type) { 38 auto ranked = type.dyn_cast<RankedTensorType>(); 39 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 40 } 41 42 LogicalResult shape::getShapeVec(Value input, 43 SmallVectorImpl<int64_t> &shapeValues) { 44 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 45 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 46 if (!type.hasRank()) 47 return failure(); 48 shapeValues = llvm::to_vector<6>(type.getShape()); 49 return success(); 50 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 51 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 52 return success(); 53 } else if (auto inputOp = 
input.getDefiningOp<ConstantOp>()) { 54 shapeValues = llvm::to_vector<6>( 55 inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 56 return success(); 57 } else { 58 return failure(); 59 } 60 } 61 62 static bool isErrorPropagationPossible(TypeRange operandTypes) { 63 return llvm::any_of(operandTypes, [](Type ty) { 64 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 65 }); 66 } 67 68 static LogicalResult verifySizeOrIndexOp(Operation *op) { 69 assert(op != nullptr && op->getNumResults() == 1); 70 Type resultTy = op->getResultTypes().front(); 71 if (isErrorPropagationPossible(op->getOperandTypes())) { 72 if (!resultTy.isa<SizeType>()) 73 return op->emitOpError() 74 << "if at least one of the operands can hold error values then " 75 "the result must be of type `size` to propagate them"; 76 } 77 return success(); 78 } 79 80 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 81 assert(op != nullptr && op->getNumResults() == 1); 82 Type resultTy = op->getResultTypes().front(); 83 if (isErrorPropagationPossible(op->getOperandTypes())) { 84 if (!resultTy.isa<ShapeType>()) 85 return op->emitOpError() 86 << "if at least one of the operands can hold error values then " 87 "the result must be of type `shape` to propagate them"; 88 } 89 return success(); 90 } 91 92 template <typename... Ty> 93 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 94 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 95 } 96 97 template <typename... Ty, typename... ranges> 98 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 99 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 100 } 101 102 //===----------------------------------------------------------------------===// 103 // InlinerInterface 104 //===----------------------------------------------------------------------===// 105 106 namespace { 107 /// This class defines the interface for inlining shape dialect ops. 
108 struct ShapeInlinerInterface : public DialectInlinerInterface { 109 using DialectInlinerInterface::DialectInlinerInterface; 110 111 // Returns true if the given region 'src' can be inlined into the region 112 // 'dest' that is attached to an operation registered to the current dialect. 113 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 114 BlockAndValueMapping &) const final { 115 return true; 116 } 117 118 // Returns true if the given operation 'op', that is registered to this 119 // dialect, can be inlined into the region 'dest' that is attached to an 120 // operation registered to the current dialect. 121 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 122 BlockAndValueMapping &) const final { 123 return true; 124 } 125 }; 126 } // namespace 127 128 void ShapeDialect::initialize() { 129 addOperations< 130 #define GET_OP_LIST 131 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 132 >(); 133 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 134 addInterfaces<ShapeInlinerInterface>(); 135 // Allow unknown operations during prototyping and testing. As the dialect is 136 // still evolving it makes it simple to start with an unregistered ops and 137 // try different variants before actually defining the op. 
138 allowUnknownOperations(); 139 } 140 141 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 142 Attribute value, Type type, 143 Location loc) { 144 if (type.isa<ShapeType>() || isExtentTensorType(type)) 145 return builder.create<ConstShapeOp>(loc, type, 146 value.cast<DenseIntElementsAttr>()); 147 if (type.isa<SizeType>()) 148 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 149 if (type.isa<WitnessType>()) 150 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 151 if (ConstantOp::isBuildableWith(value, type)) 152 return builder.create<ConstantOp>(loc, type, value); 153 return nullptr; 154 } 155 156 /// Parse a type registered to this dialect. 157 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 158 StringRef keyword; 159 if (parser.parseKeyword(&keyword)) 160 return Type(); 161 162 if (keyword == "shape") 163 return ShapeType::get(getContext()); 164 if (keyword == "size") 165 return SizeType::get(getContext()); 166 if (keyword == "value_shape") 167 return ValueShapeType::get(getContext()); 168 if (keyword == "witness") 169 return WitnessType::get(getContext()); 170 171 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 172 return Type(); 173 } 174 175 /// Print a type registered to this dialect. 176 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 177 TypeSwitch<Type>(type) 178 .Case<ShapeType>([&](Type) { os << "shape"; }) 179 .Case<SizeType>([&](Type) { os << "size"; }) 180 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 181 .Case<WitnessType>([&](Type) { os << "witness"; }) 182 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 183 } 184 185 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 186 NamedAttribute attribute) { 187 // Verify shape.lib attribute. 
188 if (attribute.first == "shape.lib") { 189 if (!op->hasTrait<OpTrait::SymbolTable>()) 190 return op->emitError( 191 "shape.lib attribute may only be on op implementing SymbolTable"); 192 193 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 194 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 195 if (!symbol) 196 return op->emitError("shape function library ") 197 << symbolRef << " not found"; 198 return isa<shape::FunctionLibraryOp>(symbol) 199 ? success() 200 : op->emitError() 201 << symbolRef << " required to be shape function library"; 202 } 203 204 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 205 // Verify all entries are function libraries and mappings in libraries 206 // refer to unique ops. 207 DenseSet<Identifier> key; 208 for (auto it : arr) { 209 if (!it.isa<SymbolRefAttr>()) 210 return op->emitError( 211 "only SymbolRefAttr allowed in shape.lib attribute array"); 212 213 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 214 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 215 if (!shapeFnLib) 216 return op->emitError() 217 << it << " does not refer to FunctionLibraryOp"; 218 for (auto mapping : shapeFnLib.mapping()) { 219 if (!key.insert(mapping.first).second) { 220 return op->emitError("only one op to shape mapping allowed, found " 221 "multiple for `") 222 << mapping.first << "`"; 223 } 224 } 225 } 226 return success(); 227 } 228 229 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 230 "allowed as shape.lib attribute"); 231 } 232 return success(); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // AnyOp 237 //===----------------------------------------------------------------------===// 238 239 // TODO: Canonicalization should be implemented for shapes that can be 240 // determined through mixtures of the known dimensions of the inputs. 
241 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 242 // Only the last operand is checked because AnyOp is commutative. 243 if (operands.back()) 244 return operands.back(); 245 246 return nullptr; 247 } 248 249 //===----------------------------------------------------------------------===// 250 // AssumingOp 251 //===----------------------------------------------------------------------===// 252 253 static ParseResult parseAssumingOp(OpAsmParser &parser, 254 OperationState &result) { 255 result.regions.reserve(1); 256 Region *doRegion = result.addRegion(); 257 258 auto &builder = parser.getBuilder(); 259 OpAsmParser::OperandType cond; 260 if (parser.parseOperand(cond) || 261 parser.resolveOperand(cond, builder.getType<WitnessType>(), 262 result.operands)) 263 return failure(); 264 265 // Parse optional results type list. 266 if (parser.parseOptionalArrowTypeList(result.types)) 267 return failure(); 268 269 // Parse the region and add a terminator if elided. 270 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 271 return failure(); 272 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 273 274 // Parse the optional attribute list. 275 if (parser.parseOptionalAttrDict(result.attributes)) 276 return failure(); 277 return success(); 278 } 279 280 static void print(OpAsmPrinter &p, AssumingOp op) { 281 bool yieldsResults = !op.results().empty(); 282 283 p << " " << op.witness(); 284 if (yieldsResults) { 285 p << " -> (" << op.getResultTypes() << ")"; 286 } 287 p.printRegion(op.doRegion(), 288 /*printEntryBlockArgs=*/false, 289 /*printBlockTerminators=*/yieldsResults); 290 p.printOptionalAttrDict(op->getAttrs()); 291 } 292 293 namespace { 294 // Removes AssumingOp with a passing witness and inlines the region. 
295 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 296 using OpRewritePattern<AssumingOp>::OpRewritePattern; 297 298 LogicalResult matchAndRewrite(AssumingOp op, 299 PatternRewriter &rewriter) const override { 300 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 301 if (!witness || !witness.passingAttr()) 302 return failure(); 303 304 AssumingOp::inlineRegionIntoParent(op, rewriter); 305 return success(); 306 } 307 }; 308 309 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 310 using OpRewritePattern<AssumingOp>::OpRewritePattern; 311 312 LogicalResult matchAndRewrite(AssumingOp op, 313 PatternRewriter &rewriter) const override { 314 Block *body = op.getBody(); 315 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 316 317 // Find used values. 318 SmallVector<Value, 4> newYieldOperands; 319 Value opResult, yieldOperand; 320 for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { 321 std::tie(opResult, yieldOperand) = it; 322 if (!opResult.getUses().empty()) { 323 newYieldOperands.push_back(yieldOperand); 324 } 325 } 326 327 // Rewrite only if redundant results exist. 328 if (newYieldOperands.size() == yieldOp->getNumOperands()) 329 return failure(); 330 331 // Replace yield op in the old assuming op's body and move the entire region 332 // to the new assuming op. 333 rewriter.setInsertionPointToEnd(body); 334 auto newYieldOp = 335 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 336 rewriter.setInsertionPoint(op); 337 auto newOp = rewriter.create<AssumingOp>( 338 op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); 339 newOp.doRegion().takeBody(op.doRegion()); 340 341 // Use the new results to replace the previously used ones. 
342 SmallVector<Value, 4> replacementValues; 343 auto src = newOp.getResults().begin(); 344 for (auto it : op.getResults()) { 345 if (it.getUses().empty()) 346 replacementValues.push_back(nullptr); 347 else 348 replacementValues.push_back(*src++); 349 } 350 rewriter.replaceOp(op, replacementValues); 351 return success(); 352 } 353 }; 354 } // namespace 355 356 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 357 MLIRContext *context) { 358 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 359 } 360 361 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 362 void AssumingOp::getSuccessorRegions( 363 Optional<unsigned> index, ArrayRef<Attribute> operands, 364 SmallVectorImpl<RegionSuccessor> ®ions) { 365 // AssumingOp has unconditional control flow into the region and back to the 366 // parent, so return the correct RegionSuccessor purely based on the index 367 // being None or 0. 368 if (index.hasValue()) { 369 regions.push_back(RegionSuccessor(getResults())); 370 return; 371 } 372 373 regions.push_back(RegionSuccessor(&doRegion())); 374 } 375 376 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 377 PatternRewriter &rewriter) { 378 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 379 auto *assumingBlock = op.getBody(); 380 auto initPosition = rewriter.getInsertionPoint(); 381 auto *blockAfterAssuming = 382 rewriter.splitBlock(blockBeforeAssuming, initPosition); 383 384 // Remove the AssumingOp and AssumingYieldOp. 385 auto &yieldOp = assumingBlock->back(); 386 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 387 rewriter.replaceOp(op, yieldOp.getOperands()); 388 rewriter.eraseOp(&yieldOp); 389 390 // Merge blocks together as there was no branching behavior from the 391 // AssumingOp. 
392 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 393 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 394 } 395 396 void AssumingOp::build( 397 OpBuilder &builder, OperationState &result, Value witness, 398 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 399 400 result.addOperands(witness); 401 Region *bodyRegion = result.addRegion(); 402 bodyRegion->push_back(new Block); 403 Block &bodyBlock = bodyRegion->front(); 404 405 // Build body. 406 OpBuilder::InsertionGuard guard(builder); 407 builder.setInsertionPointToStart(&bodyBlock); 408 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 409 builder.create<AssumingYieldOp>(result.location, yieldValues); 410 411 SmallVector<Type, 2> assumingTypes; 412 for (Value v : yieldValues) 413 assumingTypes.push_back(v.getType()); 414 result.addTypes(assumingTypes); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // AddOp 419 //===----------------------------------------------------------------------===// 420 421 LogicalResult mlir::shape::AddOp::inferReturnTypes( 422 MLIRContext *context, Optional<Location> location, ValueRange operands, 423 DictionaryAttr attributes, RegionRange regions, 424 SmallVectorImpl<Type> &inferredReturnTypes) { 425 if (operands[0].getType().isa<SizeType>() || 426 operands[1].getType().isa<SizeType>()) 427 inferredReturnTypes.assign({SizeType::get(context)}); 428 else 429 inferredReturnTypes.assign({IndexType::get(context)}); 430 return success(); 431 } 432 433 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 434 // SizeType is compatible with IndexType. 
435 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // AssumingAllOp 440 //===----------------------------------------------------------------------===// 441 442 namespace { 443 struct AssumingAllToCstrEqCanonicalization 444 : public OpRewritePattern<AssumingAllOp> { 445 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 446 447 LogicalResult matchAndRewrite(AssumingAllOp op, 448 PatternRewriter &rewriter) const override { 449 SmallVector<Value, 8> shapes; 450 for (Value w : op.inputs()) { 451 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 452 if (!cstrEqOp) 453 return failure(); 454 bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { 455 return llvm::is_contained(shapes, s); 456 }); 457 if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) 458 return failure(); 459 shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); 460 } 461 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 462 return success(); 463 } 464 }; 465 466 template <typename OpTy> 467 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 468 using OpRewritePattern<OpTy>::OpRewritePattern; 469 470 LogicalResult matchAndRewrite(OpTy op, 471 PatternRewriter &rewriter) const override { 472 // Find unique operands. 473 SmallVector<Value, 2> unique; 474 for (Value v : op.getOperands()) { 475 if (!llvm::is_contained(unique, v)) 476 unique.push_back(v); 477 } 478 479 // Reduce op to equivalent with unique operands. 
480 if (unique.size() < op.getNumOperands()) { 481 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 482 op->getAttrs()); 483 return success(); 484 } 485 486 return failure(); 487 } 488 }; 489 } // namespace 490 491 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 492 MLIRContext *context) { 493 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 494 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 495 } 496 497 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 498 // Iterate in reverse to first handle all constant operands. They are 499 // guaranteed to be the tail of the inputs because this is commutative. 500 for (int idx = operands.size() - 1; idx >= 0; idx--) { 501 Attribute a = operands[idx]; 502 // Cannot fold if any inputs are not constant; 503 if (!a) 504 return nullptr; 505 506 // We do not need to keep statically known values after handling them in 507 // this method. 508 getOperation()->eraseOperand(idx); 509 510 // Always false if any input is statically known false 511 if (!a.cast<BoolAttr>().getValue()) 512 return a; 513 } 514 // If this is reached, all inputs were statically known passing. 
515 return BoolAttr::get(getContext(), true); 516 } 517 518 static LogicalResult verify(AssumingAllOp op) { 519 // Ensure that AssumingAllOp contains at least one operand 520 if (op.getNumOperands() == 0) 521 return op.emitOpError("no operands specified"); 522 523 return success(); 524 } 525 526 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 527 ValueRange inputs) { 528 build(b, state, b.getType<WitnessType>(), inputs); 529 } 530 531 //===----------------------------------------------------------------------===// 532 // BroadcastOp 533 //===----------------------------------------------------------------------===// 534 535 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 536 if (shapes().size() == 1) { 537 // Otherwise, we need a cast which would be a canonicalization, not folding. 538 if (shapes().front().getType() != getType()) 539 return nullptr; 540 return shapes().front(); 541 } 542 543 // TODO: Support folding with more than 2 input shapes 544 if (shapes().size() > 2) 545 return nullptr; 546 547 if (!operands[0] || !operands[1]) 548 return nullptr; 549 auto lhsShape = llvm::to_vector<6>( 550 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 551 auto rhsShape = llvm::to_vector<6>( 552 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 553 SmallVector<int64_t, 6> resultShape; 554 555 // If the shapes are not compatible, we can't fold it. 556 // TODO: Fold to an "error". 
557 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 558 return nullptr; 559 560 Builder builder(getContext()); 561 return builder.getIndexTensorAttr(resultShape); 562 } 563 564 static LogicalResult verify(BroadcastOp op) { 565 return verifyShapeOrExtentTensorOp(op); 566 } 567 568 namespace { 569 template <typename OpTy> 570 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 571 using OpRewritePattern<OpTy>::OpRewritePattern; 572 573 LogicalResult matchAndRewrite(OpTy op, 574 PatternRewriter &rewriter) const override { 575 auto isPotentiallyNonEmptyShape = [](Value shape) { 576 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 577 if (extentTensorTy.getDimSize(0) == 0) 578 return false; 579 } 580 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 581 if (constShape.shape().empty()) 582 return false; 583 } 584 return true; 585 }; 586 auto newOperands = llvm::to_vector<8>( 587 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 588 589 // Reduce op to equivalent without empty shape operands. 590 if (newOperands.size() < op.getNumOperands()) { 591 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 592 op->getAttrs()); 593 return success(); 594 } 595 596 return failure(); 597 } 598 }; 599 600 struct BroadcastForwardSingleOperandPattern 601 : public OpRewritePattern<BroadcastOp> { 602 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 603 604 LogicalResult matchAndRewrite(BroadcastOp op, 605 PatternRewriter &rewriter) const override { 606 if (op.getNumOperands() != 1) 607 return failure(); 608 Value replacement = op.shapes().front(); 609 610 // Insert cast if needed. 
611 if (replacement.getType() != op.getType()) { 612 auto loc = op.getLoc(); 613 if (op.getType().isa<ShapeType>()) { 614 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 615 } else { 616 assert(!op.getType().isa<ShapeType>() && 617 !replacement.getType().isa<ShapeType>() && 618 "expect extent tensor cast"); 619 replacement = 620 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 621 } 622 } 623 624 rewriter.replaceOp(op, replacement); 625 return success(); 626 } 627 }; 628 629 struct BroadcastFoldConstantOperandsPattern 630 : public OpRewritePattern<BroadcastOp> { 631 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 632 633 LogicalResult matchAndRewrite(BroadcastOp op, 634 PatternRewriter &rewriter) const override { 635 SmallVector<int64_t, 8> foldedConstantShape; 636 SmallVector<Value, 8> newShapeOperands; 637 for (Value shape : op.shapes()) { 638 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 639 SmallVector<int64_t, 8> newFoldedConstantShape; 640 if (OpTrait::util::getBroadcastedShape( 641 foldedConstantShape, 642 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()), 643 newFoldedConstantShape)) { 644 foldedConstantShape = newFoldedConstantShape; 645 continue; 646 } 647 } 648 newShapeOperands.push_back(shape); 649 } 650 651 // Need at least two constant operands to fold anything. 
652 if (op.getNumOperands() - newShapeOperands.size() < 2) 653 return failure(); 654 655 auto foldedConstantOperandsTy = RankedTensorType::get( 656 {static_cast<int64_t>(foldedConstantShape.size())}, 657 rewriter.getIndexType()); 658 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 659 op.getLoc(), foldedConstantOperandsTy, 660 rewriter.getIndexTensorAttr(foldedConstantShape))); 661 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 662 newShapeOperands); 663 return success(); 664 } 665 }; 666 667 template <typename OpTy> 668 struct CanonicalizeCastExtentTensorOperandsPattern 669 : public OpRewritePattern<OpTy> { 670 using OpRewritePattern<OpTy>::OpRewritePattern; 671 672 LogicalResult matchAndRewrite(OpTy op, 673 PatternRewriter &rewriter) const override { 674 // Canonicalize operands. 675 bool anyChange = false; 676 auto canonicalizeOperand = [&](Value operand) { 677 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 678 // Only eliminate the cast if it holds no shape information. 679 bool isInformationLoosingCast = 680 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 681 if (isInformationLoosingCast) { 682 anyChange = true; 683 return castOp.source(); 684 } 685 } 686 return operand; 687 }; 688 auto newOperands = llvm::to_vector<8>( 689 llvm::map_range(op.getOperands(), canonicalizeOperand)); 690 691 // Rewrite op if any change required. 692 if (!anyChange) 693 return failure(); 694 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 695 return success(); 696 } 697 }; 698 699 struct BroadcastConcretizeResultTypePattern 700 : public OpRewritePattern<BroadcastOp> { 701 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 702 703 LogicalResult matchAndRewrite(BroadcastOp op, 704 PatternRewriter &rewriter) const override { 705 // Only concretize dynamic extent tensor result types. 
706 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 707 if (!resultTy || !resultTy.isDynamicDim(0)) 708 return failure(); 709 710 // Infer resulting shape rank if possible. 711 int64_t maxRank = 0; 712 for (Value shape : op.shapes()) { 713 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 714 // Cannot infer resulting shape rank if any operand is dynamically 715 // ranked. 716 if (extentTensorTy.isDynamicDim(0)) 717 return failure(); 718 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 719 } 720 } 721 722 auto newOp = rewriter.create<BroadcastOp>( 723 op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); 724 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 725 return success(); 726 } 727 }; 728 } // namespace 729 730 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 731 MLIRContext *context) { 732 patterns.add<BroadcastConcretizeResultTypePattern, 733 BroadcastFoldConstantOperandsPattern, 734 BroadcastForwardSingleOperandPattern, 735 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 736 RemoveDuplicateOperandsPattern<BroadcastOp>, 737 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 738 } 739 740 //===----------------------------------------------------------------------===// 741 // ConcatOp 742 //===----------------------------------------------------------------------===// 743 744 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 745 if (!operands[0] || !operands[1]) 746 return nullptr; 747 auto lhsShape = llvm::to_vector<6>( 748 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 749 auto rhsShape = llvm::to_vector<6>( 750 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 751 SmallVector<int64_t, 6> resultShape; 752 resultShape.append(lhsShape.begin(), lhsShape.end()); 753 resultShape.append(rhsShape.begin(), rhsShape.end()); 754 Builder builder(getContext()); 755 return 
builder.getIndexTensorAttr(resultShape); 756 } 757 758 //===----------------------------------------------------------------------===// 759 // ConstShapeOp 760 //===----------------------------------------------------------------------===// 761 762 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 763 p << " "; 764 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 765 p << "["; 766 interleaveComma(op.shape().getValues<int64_t>(), p, 767 [&](int64_t i) { p << i; }); 768 p << "] : "; 769 p.printType(op.getType()); 770 } 771 772 static ParseResult parseConstShapeOp(OpAsmParser &parser, 773 OperationState &result) { 774 if (parser.parseOptionalAttrDict(result.attributes)) 775 return failure(); 776 // We piggy-back on ArrayAttr parsing, though we don't internally store the 777 // shape as an ArrayAttr. 778 // TODO: Implement custom parser and maybe make syntax a bit more concise. 779 Attribute extentsRaw; 780 NamedAttrList dummy; 781 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 782 return failure(); 783 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 784 if (!extentsArray) 785 return failure(); 786 SmallVector<int64_t, 6> ints; 787 for (Attribute extent : extentsArray) { 788 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 789 if (!attr) 790 return failure(); 791 ints.push_back(attr.getInt()); 792 } 793 Builder &builder = parser.getBuilder(); 794 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 795 Type resultTy; 796 if (parser.parseColonType(resultTy)) 797 return failure(); 798 result.types.push_back(resultTy); 799 return success(); 800 } 801 802 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 803 804 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 805 MLIRContext *context) { 806 patterns.add<TensorCastConstShape>(context); 807 } 808 809 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 810 MLIRContext *context, Optional<Location> location, 
ValueRange operands, 811 DictionaryAttr attributes, RegionRange regions, 812 SmallVectorImpl<Type> &inferredReturnTypes) { 813 Builder b(context); 814 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 815 if (!shape) 816 return emitOptionalError(location, "missing shape attribute"); 817 inferredReturnTypes.assign({RankedTensorType::get( 818 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 819 return success(); 820 } 821 822 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 823 TypeRange r) { 824 if (l.size() != 1 || r.size() != 1) 825 return false; 826 827 Type lhs = l.front(); 828 Type rhs = r.front(); 829 830 if (lhs == rhs) 831 return true; 832 833 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 834 // Shape type is compatible with all other valid return types. 835 return true; 836 837 return succeeded(verifyCompatibleShapes(lhs, rhs)); 838 } 839 840 //===----------------------------------------------------------------------===// 841 // CstrBroadcastableOp 842 //===----------------------------------------------------------------------===// 843 844 void CstrBroadcastableOp::getCanonicalizationPatterns( 845 RewritePatternSet &patterns, MLIRContext *context) { 846 // Canonicalization patterns have overlap with the considerations during 847 // folding in case additional shape information is inferred at some point that 848 // does not result in folding. 849 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 850 CstrBroadcastableEqOps, 851 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 852 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 853 } 854 855 // Return true if there is exactly one attribute not representing a scalar 856 // broadcast. 
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    // A null (non-constant) attribute or one with a non-zero number of
    // elements counts as "non-scalar"; a second such entry fails the check.
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

/// Folds to a passing witness when broadcastability is statically known;
/// otherwise returns nullptr (no fold).
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, try with the constant operand attributes: every operand must be
  // constant for this check to apply.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : shapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp has at least two operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

/// Folds to a passing witness when every operand is constant and equal to the
/// first operand's attribute.
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

/// Convenience builder taking the size as a plain integer; it is wrapped in
/// an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }

/// Suggests the result name "c<value>" for the printed IR.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << value();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

/// Folds to the first operand's constant attribute; this is null (no fold)
/// when that operand is not constant.
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

// NOTE(review): the remainder of this function lies beyond the visible
// source; only the operand checks below are shown.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
979 APInt quotient, remainder; 980 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 981 if (quotient.isNegative() && !remainder.isNullValue()) { 982 quotient -= 1; 983 } 984 985 Type indexTy = IndexType::get(getContext()); 986 return IntegerAttr::get(indexTy, quotient); 987 } 988 989 LogicalResult mlir::shape::DivOp::inferReturnTypes( 990 MLIRContext *context, Optional<Location> location, ValueRange operands, 991 DictionaryAttr attributes, RegionRange regions, 992 SmallVectorImpl<Type> &inferredReturnTypes) { 993 if (operands[0].getType().isa<SizeType>() || 994 operands[1].getType().isa<SizeType>()) 995 inferredReturnTypes.assign({SizeType::get(context)}); 996 else 997 inferredReturnTypes.assign({IndexType::get(context)}); 998 return success(); 999 } 1000 1001 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1002 // SizeType is compatible with IndexType. 1003 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1004 } 1005 1006 //===----------------------------------------------------------------------===// 1007 // ShapeEqOp 1008 //===----------------------------------------------------------------------===// 1009 1010 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1011 bool allSame = true; 1012 if (!operands.empty() && !operands[0]) 1013 return {}; 1014 for (Attribute operand : operands.drop_front(1)) { 1015 if (!operand) 1016 return {}; 1017 allSame = allSame && operand == operands[0]; 1018 } 1019 return BoolAttr::get(getContext(), allSame); 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // IndexToSizeOp 1024 //===----------------------------------------------------------------------===// 1025 1026 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1027 // Constant values of both types, `shape.size` and `index`, are represented as 1028 // `IntegerAttr`s which makes constant folding simple. 
1029 if (Attribute arg = operands[0]) 1030 return arg; 1031 return {}; 1032 } 1033 1034 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1035 MLIRContext *context) { 1036 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1037 } 1038 1039 //===----------------------------------------------------------------------===// 1040 // FromExtentsOp 1041 //===----------------------------------------------------------------------===// 1042 1043 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1044 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1045 return nullptr; 1046 SmallVector<int64_t, 6> extents; 1047 for (auto attr : operands) 1048 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1049 Builder builder(getContext()); 1050 return builder.getIndexTensorAttr(extents); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // FunctionLibraryOp 1055 //===----------------------------------------------------------------------===// 1056 1057 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1058 StringRef name) { 1059 result.attributes.push_back(builder.getNamedAttr( 1060 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1061 } 1062 1063 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1064 auto attr = mapping() 1065 .get(op->getName().getIdentifier()) 1066 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1067 if (!attr) 1068 return nullptr; 1069 return lookupSymbol<FuncOp>(attr); 1070 } 1071 1072 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1073 OperationState &result) { 1074 // Parse the op name. 
1075 StringAttr nameAttr; 1076 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1077 result.attributes)) 1078 return failure(); 1079 1080 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1081 return failure(); 1082 1083 auto *bodyRegion = result.addRegion(); 1084 if (parser.parseRegion(*bodyRegion)) 1085 return failure(); 1086 1087 if (parser.parseKeyword("mapping")) 1088 return failure(); 1089 1090 DictionaryAttr mappingAttr; 1091 if (parser.parseAttribute(mappingAttr, 1092 parser.getBuilder().getType<NoneType>(), "mapping", 1093 result.attributes)) 1094 return failure(); 1095 return success(); 1096 } 1097 1098 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1099 p << ' '; 1100 p.printSymbolName(op.getName()); 1101 p.printOptionalAttrDictWithKeyword( 1102 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1103 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1104 /*printBlockTerminators=*/false); 1105 p << " mapping "; 1106 p.printAttributeWithoutType(op.mappingAttr()); 1107 } 1108 1109 //===----------------------------------------------------------------------===// 1110 // GetExtentOp 1111 //===----------------------------------------------------------------------===// 1112 1113 Optional<int64_t> GetExtentOp::getConstantDim() { 1114 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1115 return constSizeOp.value().getLimitedValue(); 1116 if (auto constantOp = dim().getDefiningOp<ConstantOp>()) 1117 return constantOp.value().cast<IntegerAttr>().getInt(); 1118 return llvm::None; 1119 } 1120 1121 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1122 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1123 if (!elements) 1124 return nullptr; 1125 Optional<int64_t> dim = getConstantDim(); 1126 if (!dim.hasValue()) 1127 return nullptr; 1128 if (dim.getValue() >= elements.getNumElements()) 1129 return nullptr; 1130 return 
elements.getValue({(uint64_t)dim.getValue()}); 1131 } 1132 1133 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1134 int64_t dim) { 1135 auto loc = result.location; 1136 auto dimAttr = builder.getIndexAttr(dim); 1137 if (shape.getType().isa<ShapeType>()) { 1138 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1139 build(builder, result, builder.getType<SizeType>(), shape, dim); 1140 } else { 1141 Value dim = 1142 builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr); 1143 build(builder, result, builder.getIndexType(), shape, dim); 1144 } 1145 } 1146 1147 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1148 MLIRContext *context, Optional<Location> location, ValueRange operands, 1149 DictionaryAttr attributes, RegionRange regions, 1150 SmallVectorImpl<Type> &inferredReturnTypes) { 1151 inferredReturnTypes.assign({IndexType::get(context)}); 1152 return success(); 1153 } 1154 1155 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1156 TypeRange r) { 1157 // SizeType is compatible with IndexType. 1158 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1159 } 1160 1161 //===----------------------------------------------------------------------===// 1162 // IsBroadcastableOp 1163 //===----------------------------------------------------------------------===// 1164 1165 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1166 MLIRContext *context) { 1167 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1168 } 1169 1170 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1171 // Can always broadcast fewer than two shapes. 
1172 if (operands.size() < 2) { 1173 return BoolAttr::get(getContext(), true); 1174 } 1175 1176 return nullptr; 1177 } 1178 1179 //===----------------------------------------------------------------------===// 1180 // JoinOp 1181 //===----------------------------------------------------------------------===// 1182 1183 LogicalResult mlir::shape::JoinOp::inferReturnTypes( 1184 MLIRContext *context, Optional<Location> location, ValueRange operands, 1185 DictionaryAttr attributes, RegionRange regions, 1186 SmallVectorImpl<Type> &inferredReturnTypes) { 1187 inferredReturnTypes.assign({operands[0].getType()}); 1188 return success(); 1189 } 1190 1191 bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1192 if (l.size() != 1 || r.size() != 1) 1193 return false; 1194 if (l == r) 1195 return true; 1196 1197 Type lhs = l.front(); 1198 Type rhs = r.front(); 1199 1200 if (lhs != rhs) 1201 return false; 1202 1203 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1204 return true; 1205 1206 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1207 return true; 1208 return false; 1209 } 1210 1211 //===----------------------------------------------------------------------===// 1212 // RankOp 1213 //===----------------------------------------------------------------------===// 1214 1215 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1216 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1217 if (!shape) 1218 return {}; 1219 int64_t rank = shape.getNumElements(); 1220 Builder builder(getContext()); 1221 return builder.getIndexAttr(rank); 1222 } 1223 1224 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1225 /// Constant folding fails in cases where only the rank is constant, not the 1226 /// shape itself. 1227 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 
1228 /// 1229 /// Example: 1230 /// 1231 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1232 /// %rank = shape.rank %shape 1233 /// 1234 /// becomes 1235 /// 1236 /// %rank = shape.const_size 3 1237 1238 namespace { 1239 struct RankShapeOfCanonicalizationPattern 1240 : public OpRewritePattern<shape::RankOp> { 1241 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1242 1243 LogicalResult matchAndRewrite(shape::RankOp op, 1244 PatternRewriter &rewriter) const override { 1245 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 1246 if (!shapeOfOp) 1247 return failure(); 1248 auto rankedTensorType = 1249 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1250 if (!rankedTensorType) 1251 return failure(); 1252 int64_t rank = rankedTensorType.getRank(); 1253 if (op.getType().isa<IndexType>()) { 1254 rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank); 1255 } else if (op.getType().isa<shape::SizeType>()) { 1256 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1257 } else { 1258 return failure(); 1259 } 1260 return success(); 1261 } 1262 }; 1263 } // namespace 1264 1265 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1266 MLIRContext *context) { 1267 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1268 } 1269 1270 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1271 MLIRContext *context, Optional<Location> location, ValueRange operands, 1272 DictionaryAttr attributes, RegionRange regions, 1273 SmallVectorImpl<Type> &inferredReturnTypes) { 1274 if (operands[0].getType().isa<ShapeType>()) 1275 inferredReturnTypes.assign({SizeType::get(context)}); 1276 else 1277 inferredReturnTypes.assign({IndexType::get(context)}); 1278 return success(); 1279 } 1280 1281 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1282 // SizeType is compatible with IndexType. 
1283 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1284 } 1285 1286 //===----------------------------------------------------------------------===// 1287 // NumElementsOp 1288 //===----------------------------------------------------------------------===// 1289 1290 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1291 1292 // Fold only when argument constant. 1293 Attribute shape = operands[0]; 1294 if (!shape) 1295 return {}; 1296 1297 APInt product(64, 1); 1298 for (auto value : shape.cast<DenseIntElementsAttr>()) 1299 product *= value; 1300 Builder builder(getContext()); 1301 return builder.getIndexAttr(product.getLimitedValue()); 1302 } 1303 1304 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1305 MLIRContext *context, Optional<Location> location, ValueRange operands, 1306 DictionaryAttr attributes, RegionRange regions, 1307 SmallVectorImpl<Type> &inferredReturnTypes) { 1308 if (operands[0].getType().isa<ShapeType>()) 1309 inferredReturnTypes.assign({SizeType::get(context)}); 1310 else 1311 inferredReturnTypes.assign({IndexType::get(context)}); 1312 return success(); 1313 } 1314 1315 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1316 TypeRange r) { 1317 // SizeType is compatible with IndexType. 1318 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1319 } 1320 1321 //===----------------------------------------------------------------------===// 1322 // MaxOp 1323 //===----------------------------------------------------------------------===// 1324 1325 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1326 // If operands are equal, just propagate one. 
1327 if (lhs() == rhs()) 1328 return lhs(); 1329 return nullptr; 1330 } 1331 1332 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1333 MLIRContext *context, Optional<Location> location, ValueRange operands, 1334 DictionaryAttr attributes, RegionRange regions, 1335 SmallVectorImpl<Type> &inferredReturnTypes) { 1336 if (operands[0].getType() == operands[1].getType()) 1337 inferredReturnTypes.assign({operands[0].getType()}); 1338 else 1339 inferredReturnTypes.assign({SizeType::get(context)}); 1340 return success(); 1341 } 1342 1343 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1344 if (l.size() != 1 || r.size() != 1) 1345 return false; 1346 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1347 return true; 1348 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1349 return true; 1350 return false; 1351 } 1352 1353 //===----------------------------------------------------------------------===// 1354 // MinOp 1355 //===----------------------------------------------------------------------===// 1356 1357 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1358 // If operands are equal, just propagate one. 
1359 if (lhs() == rhs()) 1360 return lhs(); 1361 return nullptr; 1362 } 1363 1364 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1365 MLIRContext *context, Optional<Location> location, ValueRange operands, 1366 DictionaryAttr attributes, RegionRange regions, 1367 SmallVectorImpl<Type> &inferredReturnTypes) { 1368 if (operands[0].getType() == operands[1].getType()) 1369 inferredReturnTypes.assign({operands[0].getType()}); 1370 else 1371 inferredReturnTypes.assign({SizeType::get(context)}); 1372 return success(); 1373 } 1374 1375 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1376 if (l.size() != 1 || r.size() != 1) 1377 return false; 1378 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1379 return true; 1380 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1381 return true; 1382 return false; 1383 } 1384 1385 //===----------------------------------------------------------------------===// 1386 // MulOp 1387 //===----------------------------------------------------------------------===// 1388 1389 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1390 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1391 if (!lhs) 1392 return nullptr; 1393 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1394 if (!rhs) 1395 return nullptr; 1396 APInt folded = lhs.getValue() * rhs.getValue(); 1397 Type indexTy = IndexType::get(getContext()); 1398 return IntegerAttr::get(indexTy, folded); 1399 } 1400 1401 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1402 MLIRContext *context, Optional<Location> location, ValueRange operands, 1403 DictionaryAttr attributes, RegionRange regions, 1404 SmallVectorImpl<Type> &inferredReturnTypes) { 1405 if (operands[0].getType().isa<SizeType>() || 1406 operands[1].getType().isa<SizeType>()) 1407 inferredReturnTypes.assign({SizeType::get(context)}); 1408 else 1409 inferredReturnTypes.assign({IndexType::get(context)}); 1410 return success(); 1411 } 1412 1413 bool 
mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1414 // SizeType is compatible with IndexType. 1415 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1416 } 1417 //===----------------------------------------------------------------------===// 1418 // ShapeOfOp 1419 //===----------------------------------------------------------------------===// 1420 1421 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1422 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1423 if (!type || !type.hasStaticShape()) 1424 return nullptr; 1425 Builder builder(getContext()); 1426 return builder.getIndexTensorAttr(type.getShape()); 1427 } 1428 1429 namespace { 1430 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1431 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1432 1433 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1434 PatternRewriter &rewriter) const override { 1435 if (!op.arg().getType().isa<ShapedType>()) 1436 return failure(); 1437 if (op.getType().isa<ShapedType>()) 1438 return failure(); 1439 1440 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 1441 return success(); 1442 } 1443 }; 1444 1445 // Canonicalize 1446 // ``` 1447 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1448 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1449 // ``` 1450 // to 1451 // ``` 1452 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1453 // ``` 1454 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1455 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1456 1457 LogicalResult matchAndRewrite(tensor::CastOp op, 1458 PatternRewriter &rewriter) const override { 1459 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1460 if (!ty || ty.getRank() != 1) 1461 return failure(); 1462 1463 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1464 if (!shapeOfOp) 1465 return failure(); 1466 1467 // Argument type 
must be ranked and must not conflict. 1468 auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1469 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1470 return failure(); 1471 1472 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg()); 1473 return success(); 1474 } 1475 }; 1476 } // namespace 1477 1478 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1479 MLIRContext *context) { 1480 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context); 1481 } 1482 1483 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1484 MLIRContext *context, Optional<Location> location, ValueRange operands, 1485 DictionaryAttr attributes, RegionRange regions, 1486 SmallVectorImpl<Type> &inferredReturnTypes) { 1487 if (operands[0].getType().isa<ValueShapeType>()) 1488 inferredReturnTypes.assign({ShapeType::get(context)}); 1489 else { 1490 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1491 int64_t rank = 1492 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1493 Type indexTy = IndexType::get(context); 1494 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1495 inferredReturnTypes.assign({extentTensorTy}); 1496 } 1497 return success(); 1498 } 1499 1500 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1501 if (l.size() != 1 || r.size() != 1) 1502 return false; 1503 if (l == r) 1504 return true; 1505 1506 Type lhs = l.front(); 1507 Type rhs = r.front(); 1508 1509 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1510 return false; 1511 1512 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1513 // Shape type is compatible with all other valid return types. 
1514 return true; 1515 1516 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1517 return true; 1518 return false; 1519 } 1520 1521 //===----------------------------------------------------------------------===// 1522 // SizeToIndexOp 1523 //===----------------------------------------------------------------------===// 1524 1525 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1526 // Constant values of both types, `shape.size` and `index`, are represented as 1527 // `IntegerAttr`s which makes constant folding simple. 1528 if (Attribute arg = operands[0]) 1529 return arg; 1530 return impl::foldCastOp(*this); 1531 } 1532 1533 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1534 MLIRContext *context) { 1535 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1536 } 1537 1538 //===----------------------------------------------------------------------===// 1539 // YieldOp 1540 //===----------------------------------------------------------------------===// 1541 1542 static LogicalResult verify(shape::YieldOp op) { 1543 auto *parentOp = op->getParentOp(); 1544 auto results = parentOp->getResults(); 1545 auto operands = op.getOperands(); 1546 1547 if (parentOp->getNumResults() != op.getNumOperands()) 1548 return op.emitOpError() << "number of operands does not match number of " 1549 "results of its parent"; 1550 for (auto e : llvm::zip(results, operands)) 1551 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1552 return op.emitOpError() 1553 << "types mismatch between yield op and its parent"; 1554 1555 return success(); 1556 } 1557 1558 //===----------------------------------------------------------------------===// 1559 // SplitAtOp 1560 //===----------------------------------------------------------------------===// 1561 1562 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1563 SmallVectorImpl<OpFoldResult> &results) { 1564 if (!operands[0] || !operands[1]) 1565 return failure(); 1566 auto 
shapeVec = llvm::to_vector<6>( 1567 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1568 auto shape = llvm::makeArrayRef(shapeVec); 1569 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1570 // Verify that the split point is in the correct range. 1571 // TODO: Constant fold to an "error". 1572 int64_t rank = shape.size(); 1573 if (!(-rank <= splitPoint && splitPoint <= rank)) 1574 return failure(); 1575 if (splitPoint < 0) 1576 splitPoint += shape.size(); 1577 Builder builder(operands[0].getContext()); 1578 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1579 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1580 return success(); 1581 } 1582 1583 //===----------------------------------------------------------------------===// 1584 // ToExtentTensorOp 1585 //===----------------------------------------------------------------------===// 1586 1587 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1588 if (!operands[0]) 1589 return impl::foldCastOp(*this); 1590 Builder builder(getContext()); 1591 auto shape = llvm::to_vector<6>( 1592 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1593 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1594 builder.getIndexType()); 1595 return DenseIntElementsAttr::get(type, shape); 1596 } 1597 1598 //===----------------------------------------------------------------------===// 1599 // ReduceOp 1600 //===----------------------------------------------------------------------===// 1601 1602 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1603 ValueRange initVals) { 1604 result.addOperands(shape); 1605 result.addOperands(initVals); 1606 1607 Region *bodyRegion = result.addRegion(); 1608 bodyRegion->push_back(new Block); 1609 Block &bodyBlock = bodyRegion->front(); 1610 bodyBlock.addArgument(builder.getIndexType()); 1611 1612 Type elementType; 1613 if (auto tensorType = 
shape.getType().dyn_cast<TensorType>()) 1614 elementType = tensorType.getElementType(); 1615 else 1616 elementType = SizeType::get(builder.getContext()); 1617 bodyBlock.addArgument(elementType); 1618 1619 for (Type initValType : initVals.getTypes()) { 1620 bodyBlock.addArgument(initValType); 1621 result.addTypes(initValType); 1622 } 1623 } 1624 1625 static LogicalResult verify(ReduceOp op) { 1626 // Verify block arg types. 1627 Block &block = op.region().front(); 1628 1629 // The block takes index, extent, and aggregated values as arguments. 1630 auto blockArgsCount = op.initVals().size() + 2; 1631 if (block.getNumArguments() != blockArgsCount) 1632 return op.emitOpError() << "ReduceOp body is expected to have " 1633 << blockArgsCount << " arguments"; 1634 1635 // The first block argument is the index and must always be of type `index`. 1636 if (!block.getArgument(0).getType().isa<IndexType>()) 1637 return op.emitOpError( 1638 "argument 0 of ReduceOp body is expected to be of IndexType"); 1639 1640 // The second block argument is the extent and must be of type `size` or 1641 // `index`, depending on whether the reduce operation is applied to a shape or 1642 // to an extent tensor. 
1643 Type extentTy = block.getArgument(1).getType(); 1644 if (op.shape().getType().isa<ShapeType>()) { 1645 if (!extentTy.isa<SizeType>()) 1646 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1647 "SizeType if the ReduceOp operates on a ShapeType"); 1648 } else { 1649 if (!extentTy.isa<IndexType>()) 1650 return op.emitOpError( 1651 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1652 "ReduceOp operates on an extent tensor"); 1653 } 1654 1655 for (auto type : llvm::enumerate(op.initVals())) 1656 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1657 return op.emitOpError() 1658 << "type mismatch between argument " << type.index() + 2 1659 << " of ReduceOp body and initial value " << type.index(); 1660 return success(); 1661 } 1662 1663 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1664 // Parse operands. 1665 SmallVector<OpAsmParser::OperandType, 3> operands; 1666 Type shapeOrExtentTensorType; 1667 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1668 OpAsmParser::Delimiter::Paren) || 1669 parser.parseColonType(shapeOrExtentTensorType) || 1670 parser.parseOptionalArrowTypeList(result.types)) 1671 return failure(); 1672 1673 // Resolve operands. 1674 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1675 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1676 result.operands) || 1677 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1678 result.operands)) 1679 return failure(); 1680 1681 // Parse the body. 1682 Region *body = result.addRegion(); 1683 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1684 return failure(); 1685 1686 // Parse attributes. 
1687 if (parser.parseOptionalAttrDict(result.attributes)) 1688 return failure(); 1689 1690 return success(); 1691 } 1692 1693 static void print(OpAsmPrinter &p, ReduceOp op) { 1694 p << '(' << op.shape() << ", " << op.initVals() 1695 << ") : " << op.shape().getType(); 1696 p.printOptionalArrowTypeList(op.getResultTypes()); 1697 p.printRegion(op.region()); 1698 p.printOptionalAttrDict(op->getAttrs()); 1699 } 1700 1701 #define GET_OP_CLASSES 1702 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1703