1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Dialect/Traits.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/SmallString.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/raw_ostream.h" 26 27 using namespace mlir; 28 using namespace mlir::shape; 29 30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 31 32 namespace { 33 #include "ShapeCanonicalization.inc" 34 } 35 36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 37 return RankedTensorType::get({rank}, IndexType::get(ctx)); 38 } 39 40 bool shape::isExtentTensorType(Type type) { 41 auto ranked = type.dyn_cast<RankedTensorType>(); 42 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 43 } 44 45 LogicalResult shape::getShapeVec(Value input, 46 SmallVectorImpl<int64_t> &shapeValues) { 47 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 48 auto type = inputOp.arg().getType().dyn_cast<ShapedType>(); 49 if (!type.hasRank()) 50 return failure(); 51 shapeValues = llvm::to_vector<6>(type.getShape()); 52 return success(); 53 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 54 shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>()); 55 return success(); 56 } else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 57 shapeValues = llvm::to_vector<6>( 58 inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>()); 59 return success(); 60 } else { 61 return failure(); 62 } 63 } 64 65 static bool isErrorPropagationPossible(TypeRange operandTypes) { 66 return llvm::any_of(operandTypes, [](Type ty) { 67 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 68 }); 69 } 70 71 static LogicalResult verifySizeOrIndexOp(Operation *op) { 72 assert(op != nullptr && op->getNumResults() == 1); 73 Type resultTy = op->getResultTypes().front(); 74 if (isErrorPropagationPossible(op->getOperandTypes())) { 75 if (!resultTy.isa<SizeType>()) 76 return op->emitOpError() 77 << "if at least one of the operands can hold error values then " 78 "the result must be of type `size` to propagate them"; 79 } 80 return success(); 81 } 82 83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 84 assert(op != nullptr && op->getNumResults() == 1); 85 Type resultTy = op->getResultTypes().front(); 86 if (isErrorPropagationPossible(op->getOperandTypes())) { 87 if (!resultTy.isa<ShapeType>()) 88 return op->emitOpError() 89 << "if at least one of the operands can hold error values then " 90 "the result must be of type `shape` to propagate them"; 91 } 92 return success(); 93 } 94 95 template <typename... Ty> 96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 97 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 98 } 99 100 template <typename... Ty, typename... ranges> 101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // InlinerInterface 107 //===----------------------------------------------------------------------===// 108 109 namespace { 110 /// This class defines the interface for inlining shape dialect ops. 111 struct ShapeInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 // Returns true if the given region 'src' can be inlined into the region 115 // 'dest' that is attached to an operation registered to the current dialect. 116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 // Returns true if the given operation 'op', that is registered to this 122 // dialect, can be inlined into the region 'dest' that is attached to an 123 // operation registered to the current dialect. 124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 125 BlockAndValueMapping &) const final { 126 return true; 127 } 128 }; 129 } // namespace 130 131 void ShapeDialect::initialize() { 132 addOperations< 133 #define GET_OP_LIST 134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 135 >(); 136 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 137 addInterfaces<ShapeInlinerInterface>(); 138 // Allow unknown operations during prototyping and testing. As the dialect is 139 // still evolving it makes it simple to start with an unregistered ops and 140 // try different variants before actually defining the op. 141 allowUnknownOperations(); 142 } 143 144 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 145 Attribute value, Type type, 146 Location loc) { 147 if (type.isa<ShapeType>() || isExtentTensorType(type)) 148 return builder.create<ConstShapeOp>(loc, type, 149 value.cast<DenseIntElementsAttr>()); 150 if (type.isa<SizeType>()) 151 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 152 if (type.isa<WitnessType>()) 153 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 154 if (arith::ConstantOp::isBuildableWith(value, type)) 155 return builder.create<arith::ConstantOp>(loc, type, value); 156 return nullptr; 157 } 158 159 /// Parse a type registered to this dialect. 160 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 161 StringRef keyword; 162 if (parser.parseKeyword(&keyword)) 163 return Type(); 164 165 if (keyword == "shape") 166 return ShapeType::get(getContext()); 167 if (keyword == "size") 168 return SizeType::get(getContext()); 169 if (keyword == "value_shape") 170 return ValueShapeType::get(getContext()); 171 if (keyword == "witness") 172 return WitnessType::get(getContext()); 173 174 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 175 return Type(); 176 } 177 178 /// Print a type registered to this dialect. 179 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 180 TypeSwitch<Type>(type) 181 .Case<ShapeType>([&](Type) { os << "shape"; }) 182 .Case<SizeType>([&](Type) { os << "size"; }) 183 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 184 .Case<WitnessType>([&](Type) { os << "witness"; }) 185 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 186 } 187 188 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 189 NamedAttribute attribute) { 190 // Verify shape.lib attribute. 191 if (attribute.first == "shape.lib") { 192 if (!op->hasTrait<OpTrait::SymbolTable>()) 193 return op->emitError( 194 "shape.lib attribute may only be on op implementing SymbolTable"); 195 196 if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) { 197 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 198 if (!symbol) 199 return op->emitError("shape function library ") 200 << symbolRef << " not found"; 201 return isa<shape::FunctionLibraryOp>(symbol) 202 ? success() 203 : op->emitError() 204 << symbolRef << " required to be shape function library"; 205 } 206 207 if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) { 208 // Verify all entries are function libraries and mappings in libraries 209 // refer to unique ops. 210 DenseSet<Identifier> key; 211 for (auto it : arr) { 212 if (!it.isa<SymbolRefAttr>()) 213 return op->emitError( 214 "only SymbolRefAttr allowed in shape.lib attribute array"); 215 216 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 217 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 218 if (!shapeFnLib) 219 return op->emitError() 220 << it << " does not refer to FunctionLibraryOp"; 221 for (auto mapping : shapeFnLib.mapping()) { 222 if (!key.insert(mapping.first).second) { 223 return op->emitError("only one op to shape mapping allowed, found " 224 "multiple for `") 225 << mapping.first << "`"; 226 } 227 } 228 } 229 return success(); 230 } 231 232 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 233 "allowed as shape.lib attribute"); 234 } 235 return success(); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // AnyOp 240 //===----------------------------------------------------------------------===// 241 242 // TODO: Canonicalization should be implemented for shapes that can be 243 // determined through mixtures of the known dimensions of the inputs. 244 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 245 // Only the last operand is checked because AnyOp is commutative. 246 if (operands.back()) 247 return operands.back(); 248 249 return nullptr; 250 } 251 252 //===----------------------------------------------------------------------===// 253 // AssumingOp 254 //===----------------------------------------------------------------------===// 255 256 static ParseResult parseAssumingOp(OpAsmParser &parser, 257 OperationState &result) { 258 result.regions.reserve(1); 259 Region *doRegion = result.addRegion(); 260 261 auto &builder = parser.getBuilder(); 262 OpAsmParser::OperandType cond; 263 if (parser.parseOperand(cond) || 264 parser.resolveOperand(cond, builder.getType<WitnessType>(), 265 result.operands)) 266 return failure(); 267 268 // Parse optional results type list. 269 if (parser.parseOptionalArrowTypeList(result.types)) 270 return failure(); 271 272 // Parse the region and add a terminator if elided. 273 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 274 return failure(); 275 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 276 277 // Parse the optional attribute list. 278 if (parser.parseOptionalAttrDict(result.attributes)) 279 return failure(); 280 return success(); 281 } 282 283 static void print(OpAsmPrinter &p, AssumingOp op) { 284 bool yieldsResults = !op.results().empty(); 285 286 p << " " << op.witness(); 287 if (yieldsResults) { 288 p << " -> (" << op.getResultTypes() << ")"; 289 } 290 p.printRegion(op.doRegion(), 291 /*printEntryBlockArgs=*/false, 292 /*printBlockTerminators=*/yieldsResults); 293 p.printOptionalAttrDict(op->getAttrs()); 294 } 295 296 namespace { 297 // Removes AssumingOp with a passing witness and inlines the region. 298 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 299 using OpRewritePattern<AssumingOp>::OpRewritePattern; 300 301 LogicalResult matchAndRewrite(AssumingOp op, 302 PatternRewriter &rewriter) const override { 303 auto witness = op.witness().getDefiningOp<ConstWitnessOp>(); 304 if (!witness || !witness.passingAttr()) 305 return failure(); 306 307 AssumingOp::inlineRegionIntoParent(op, rewriter); 308 return success(); 309 } 310 }; 311 312 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 313 using OpRewritePattern<AssumingOp>::OpRewritePattern; 314 315 LogicalResult matchAndRewrite(AssumingOp op, 316 PatternRewriter &rewriter) const override { 317 Block *body = op.getBody(); 318 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 319 320 // Find used values. 321 SmallVector<Value, 4> newYieldOperands; 322 Value opResult, yieldOperand; 323 for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) { 324 std::tie(opResult, yieldOperand) = it; 325 if (!opResult.getUses().empty()) { 326 newYieldOperands.push_back(yieldOperand); 327 } 328 } 329 330 // Rewrite only if redundant results exist. 331 if (newYieldOperands.size() == yieldOp->getNumOperands()) 332 return failure(); 333 334 // Replace yield op in the old assuming op's body and move the entire region 335 // to the new assuming op. 336 rewriter.setInsertionPointToEnd(body); 337 auto newYieldOp = 338 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 339 rewriter.setInsertionPoint(op); 340 auto newOp = rewriter.create<AssumingOp>( 341 op.getLoc(), newYieldOp->getOperandTypes(), op.witness()); 342 newOp.doRegion().takeBody(op.doRegion()); 343 344 // Use the new results to replace the previously used ones. 345 SmallVector<Value, 4> replacementValues; 346 auto src = newOp.getResults().begin(); 347 for (auto it : op.getResults()) { 348 if (it.getUses().empty()) 349 replacementValues.push_back(nullptr); 350 else 351 replacementValues.push_back(*src++); 352 } 353 rewriter.replaceOp(op, replacementValues); 354 return success(); 355 } 356 }; 357 } // namespace 358 359 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 360 MLIRContext *context) { 361 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 362 } 363 364 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 365 void AssumingOp::getSuccessorRegions( 366 Optional<unsigned> index, ArrayRef<Attribute> operands, 367 SmallVectorImpl<RegionSuccessor> ®ions) { 368 // AssumingOp has unconditional control flow into the region and back to the 369 // parent, so return the correct RegionSuccessor purely based on the index 370 // being None or 0. 371 if (index.hasValue()) { 372 regions.push_back(RegionSuccessor(getResults())); 373 return; 374 } 375 376 regions.push_back(RegionSuccessor(&doRegion())); 377 } 378 379 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 380 PatternRewriter &rewriter) { 381 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 382 auto *assumingBlock = op.getBody(); 383 auto initPosition = rewriter.getInsertionPoint(); 384 auto *blockAfterAssuming = 385 rewriter.splitBlock(blockBeforeAssuming, initPosition); 386 387 // Remove the AssumingOp and AssumingYieldOp. 388 auto &yieldOp = assumingBlock->back(); 389 rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming); 390 rewriter.replaceOp(op, yieldOp.getOperands()); 391 rewriter.eraseOp(&yieldOp); 392 393 // Merge blocks together as there was no branching behavior from the 394 // AssumingOp. 395 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 396 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 397 } 398 399 void AssumingOp::build( 400 OpBuilder &builder, OperationState &result, Value witness, 401 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 402 403 result.addOperands(witness); 404 Region *bodyRegion = result.addRegion(); 405 bodyRegion->push_back(new Block); 406 Block &bodyBlock = bodyRegion->front(); 407 408 // Build body. 409 OpBuilder::InsertionGuard guard(builder); 410 builder.setInsertionPointToStart(&bodyBlock); 411 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 412 builder.create<AssumingYieldOp>(result.location, yieldValues); 413 414 SmallVector<Type, 2> assumingTypes; 415 for (Value v : yieldValues) 416 assumingTypes.push_back(v.getType()); 417 result.addTypes(assumingTypes); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // AddOp 422 //===----------------------------------------------------------------------===// 423 424 LogicalResult mlir::shape::AddOp::inferReturnTypes( 425 MLIRContext *context, Optional<Location> location, ValueRange operands, 426 DictionaryAttr attributes, RegionRange regions, 427 SmallVectorImpl<Type> &inferredReturnTypes) { 428 if (operands[0].getType().isa<SizeType>() || 429 operands[1].getType().isa<SizeType>()) 430 inferredReturnTypes.assign({SizeType::get(context)}); 431 else 432 inferredReturnTypes.assign({IndexType::get(context)}); 433 return success(); 434 } 435 436 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 437 // SizeType is compatible with IndexType. 438 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 439 } 440 441 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 442 // add(x, 0) -> x 443 if (matchPattern(rhs(), m_Zero())) 444 return lhs(); 445 446 return constFoldBinaryOp<IntegerAttr>(operands, 447 [](APInt a, APInt b) { return a + b; }); 448 } 449 450 //===----------------------------------------------------------------------===// 451 // AssumingAllOp 452 //===----------------------------------------------------------------------===// 453 454 namespace { 455 struct AssumingAllToCstrEqCanonicalization 456 : public OpRewritePattern<AssumingAllOp> { 457 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 458 459 LogicalResult matchAndRewrite(AssumingAllOp op, 460 PatternRewriter &rewriter) const override { 461 SmallVector<Value, 8> shapes; 462 for (Value w : op.inputs()) { 463 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 464 if (!cstrEqOp) 465 return failure(); 466 bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) { 467 return llvm::is_contained(shapes, s); 468 }); 469 if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes) 470 return failure(); 471 shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end()); 472 } 473 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 474 return success(); 475 } 476 }; 477 478 template <typename OpTy> 479 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 480 using OpRewritePattern<OpTy>::OpRewritePattern; 481 482 LogicalResult matchAndRewrite(OpTy op, 483 PatternRewriter &rewriter) const override { 484 // Find unique operands. 485 SmallVector<Value, 2> unique; 486 for (Value v : op.getOperands()) { 487 if (!llvm::is_contained(unique, v)) 488 unique.push_back(v); 489 } 490 491 // Reduce op to equivalent with unique operands. 492 if (unique.size() < op.getNumOperands()) { 493 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 494 op->getAttrs()); 495 return success(); 496 } 497 498 return failure(); 499 } 500 }; 501 } // namespace 502 503 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 504 MLIRContext *context) { 505 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 506 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 507 } 508 509 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 510 // Iterate in reverse to first handle all constant operands. They are 511 // guaranteed to be the tail of the inputs because this is commutative. 512 for (int idx = operands.size() - 1; idx >= 0; idx--) { 513 Attribute a = operands[idx]; 514 // Cannot fold if any inputs are not constant; 515 if (!a) 516 return nullptr; 517 518 // We do not need to keep statically known values after handling them in 519 // this method. 520 getOperation()->eraseOperand(idx); 521 522 // Always false if any input is statically known false 523 if (!a.cast<BoolAttr>().getValue()) 524 return a; 525 } 526 // If this is reached, all inputs were statically known passing. 527 return BoolAttr::get(getContext(), true); 528 } 529 530 static LogicalResult verify(AssumingAllOp op) { 531 // Ensure that AssumingAllOp contains at least one operand 532 if (op.getNumOperands() == 0) 533 return op.emitOpError("no operands specified"); 534 535 return success(); 536 } 537 538 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 539 ValueRange inputs) { 540 build(b, state, b.getType<WitnessType>(), inputs); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // BroadcastOp 545 //===----------------------------------------------------------------------===// 546 547 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 548 if (shapes().size() == 1) { 549 // Otherwise, we need a cast which would be a canonicalization, not folding. 550 if (shapes().front().getType() != getType()) 551 return nullptr; 552 return shapes().front(); 553 } 554 555 // TODO: Support folding with more than 2 input shapes 556 if (shapes().size() > 2) 557 return nullptr; 558 559 if (!operands[0] || !operands[1]) 560 return nullptr; 561 auto lhsShape = llvm::to_vector<6>( 562 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 563 auto rhsShape = llvm::to_vector<6>( 564 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 565 SmallVector<int64_t, 6> resultShape; 566 567 // If the shapes are not compatible, we can't fold it. 568 // TODO: Fold to an "error". 569 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 570 return nullptr; 571 572 Builder builder(getContext()); 573 return builder.getIndexTensorAttr(resultShape); 574 } 575 576 static LogicalResult verify(BroadcastOp op) { 577 return verifyShapeOrExtentTensorOp(op); 578 } 579 580 namespace { 581 template <typename OpTy> 582 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 583 using OpRewritePattern<OpTy>::OpRewritePattern; 584 585 LogicalResult matchAndRewrite(OpTy op, 586 PatternRewriter &rewriter) const override { 587 auto isPotentiallyNonEmptyShape = [](Value shape) { 588 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 589 if (extentTensorTy.getDimSize(0) == 0) 590 return false; 591 } 592 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 593 if (constShape.shape().empty()) 594 return false; 595 } 596 return true; 597 }; 598 auto newOperands = llvm::to_vector<8>( 599 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 600 601 // Reduce op to equivalent without empty shape operands. 602 if (newOperands.size() < op.getNumOperands()) { 603 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 604 op->getAttrs()); 605 return success(); 606 } 607 608 return failure(); 609 } 610 }; 611 612 struct BroadcastForwardSingleOperandPattern 613 : public OpRewritePattern<BroadcastOp> { 614 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 615 616 LogicalResult matchAndRewrite(BroadcastOp op, 617 PatternRewriter &rewriter) const override { 618 if (op.getNumOperands() != 1) 619 return failure(); 620 Value replacement = op.shapes().front(); 621 622 // Insert cast if needed. 623 if (replacement.getType() != op.getType()) { 624 auto loc = op.getLoc(); 625 if (op.getType().isa<ShapeType>()) { 626 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 627 } else { 628 assert(!op.getType().isa<ShapeType>() && 629 !replacement.getType().isa<ShapeType>() && 630 "expect extent tensor cast"); 631 replacement = 632 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 633 } 634 } 635 636 rewriter.replaceOp(op, replacement); 637 return success(); 638 } 639 }; 640 641 struct BroadcastFoldConstantOperandsPattern 642 : public OpRewritePattern<BroadcastOp> { 643 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 644 645 LogicalResult matchAndRewrite(BroadcastOp op, 646 PatternRewriter &rewriter) const override { 647 SmallVector<int64_t, 8> foldedConstantShape; 648 SmallVector<Value, 8> newShapeOperands; 649 for (Value shape : op.shapes()) { 650 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 651 SmallVector<int64_t, 8> newFoldedConstantShape; 652 if (OpTrait::util::getBroadcastedShape( 653 foldedConstantShape, 654 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()), 655 newFoldedConstantShape)) { 656 foldedConstantShape = newFoldedConstantShape; 657 continue; 658 } 659 } 660 newShapeOperands.push_back(shape); 661 } 662 663 // Need at least two constant operands to fold anything. 664 if (op.getNumOperands() - newShapeOperands.size() < 2) 665 return failure(); 666 667 auto foldedConstantOperandsTy = RankedTensorType::get( 668 {static_cast<int64_t>(foldedConstantShape.size())}, 669 rewriter.getIndexType()); 670 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 671 op.getLoc(), foldedConstantOperandsTy, 672 rewriter.getIndexTensorAttr(foldedConstantShape))); 673 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 674 newShapeOperands); 675 return success(); 676 } 677 }; 678 679 template <typename OpTy> 680 struct CanonicalizeCastExtentTensorOperandsPattern 681 : public OpRewritePattern<OpTy> { 682 using OpRewritePattern<OpTy>::OpRewritePattern; 683 684 LogicalResult matchAndRewrite(OpTy op, 685 PatternRewriter &rewriter) const override { 686 // Canonicalize operands. 687 bool anyChange = false; 688 auto canonicalizeOperand = [&](Value operand) { 689 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 690 // Only eliminate the cast if it holds no shape information. 691 bool isInformationLoosingCast = 692 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 693 if (isInformationLoosingCast) { 694 anyChange = true; 695 return castOp.source(); 696 } 697 } 698 return operand; 699 }; 700 auto newOperands = llvm::to_vector<8>( 701 llvm::map_range(op.getOperands(), canonicalizeOperand)); 702 703 // Rewrite op if any change required. 704 if (!anyChange) 705 return failure(); 706 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 707 return success(); 708 } 709 }; 710 711 struct BroadcastConcretizeResultTypePattern 712 : public OpRewritePattern<BroadcastOp> { 713 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 714 715 LogicalResult matchAndRewrite(BroadcastOp op, 716 PatternRewriter &rewriter) const override { 717 // Only concretize dynamic extent tensor result types. 718 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 719 if (!resultTy || !resultTy.isDynamicDim(0)) 720 return failure(); 721 722 // Infer resulting shape rank if possible. 723 int64_t maxRank = 0; 724 for (Value shape : op.shapes()) { 725 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 726 // Cannot infer resulting shape rank if any operand is dynamically 727 // ranked. 728 if (extentTensorTy.isDynamicDim(0)) 729 return failure(); 730 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 731 } 732 } 733 734 auto newOp = rewriter.create<BroadcastOp>( 735 op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); 736 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 737 return success(); 738 } 739 }; 740 } // namespace 741 742 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 743 MLIRContext *context) { 744 patterns.add<BroadcastConcretizeResultTypePattern, 745 BroadcastFoldConstantOperandsPattern, 746 BroadcastForwardSingleOperandPattern, 747 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 748 RemoveDuplicateOperandsPattern<BroadcastOp>, 749 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 750 } 751 752 //===----------------------------------------------------------------------===// 753 // ConcatOp 754 //===----------------------------------------------------------------------===// 755 756 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 757 if (!operands[0] || !operands[1]) 758 return nullptr; 759 auto lhsShape = llvm::to_vector<6>( 760 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 761 auto rhsShape = llvm::to_vector<6>( 762 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 763 SmallVector<int64_t, 6> resultShape; 764 resultShape.append(lhsShape.begin(), lhsShape.end()); 765 resultShape.append(rhsShape.begin(), rhsShape.end()); 766 Builder builder(getContext()); 767 return builder.getIndexTensorAttr(resultShape); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // ConstShapeOp 772 //===----------------------------------------------------------------------===// 773 774 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 775 p << " "; 776 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 777 p << "["; 778 interleaveComma(op.shape().getValues<int64_t>(), p, 779 [&](int64_t i) { p << i; }); 780 p << "] : "; 781 p.printType(op.getType()); 782 } 783 784 static ParseResult parseConstShapeOp(OpAsmParser &parser, 785 OperationState &result) { 786 if (parser.parseOptionalAttrDict(result.attributes)) 787 return failure(); 788 // We piggy-back on ArrayAttr parsing, though we don't internally store the 789 // shape as an ArrayAttr. 790 // TODO: Implement custom parser and maybe make syntax a bit more concise. 791 Attribute extentsRaw; 792 NamedAttrList dummy; 793 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 794 return failure(); 795 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 796 if (!extentsArray) 797 return failure(); 798 SmallVector<int64_t, 6> ints; 799 for (Attribute extent : extentsArray) { 800 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 801 if (!attr) 802 return failure(); 803 ints.push_back(attr.getInt()); 804 } 805 Builder &builder = parser.getBuilder(); 806 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 807 Type resultTy; 808 if (parser.parseColonType(resultTy)) 809 return failure(); 810 result.types.push_back(resultTy); 811 return success(); 812 } 813 814 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); } 815 816 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 817 MLIRContext *context) { 818 patterns.add<TensorCastConstShape>(context); 819 } 820 821 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 822 MLIRContext *context, Optional<Location> location, ValueRange operands, 823 DictionaryAttr attributes, RegionRange regions, 824 SmallVectorImpl<Type> &inferredReturnTypes) { 825 Builder b(context); 826 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 827 if (!shape) 828 return emitOptionalError(location, "missing shape attribute"); 829 inferredReturnTypes.assign({RankedTensorType::get( 830 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 831 return success(); 832 } 833 834 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 835 TypeRange r) { 836 if (l.size() != 1 || r.size() != 1) 837 return false; 838 839 Type lhs = l.front(); 840 Type rhs = r.front(); 841 842 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 843 // Shape type is compatible with all other valid return types. 844 return true; 845 return lhs == rhs; 846 } 847 848 //===----------------------------------------------------------------------===// 849 // CstrBroadcastableOp 850 //===----------------------------------------------------------------------===// 851 852 void CstrBroadcastableOp::getCanonicalizationPatterns( 853 RewritePatternSet &patterns, MLIRContext *context) { 854 // Canonicalization patterns have overlap with the considerations during 855 // folding in case additional shape information is inferred at some point that 856 // does not result in folding. 857 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 858 CstrBroadcastableEqOps, 859 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 860 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 861 } 862 863 // Return true if there is exactly one attribute not representing a scalar 864 // broadcast. 865 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 866 bool nonScalarSeen = false; 867 for (Attribute a : attributes) { 868 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 869 if (nonScalarSeen) 870 return false; 871 nonScalarSeen = true; 872 } 873 } 874 return true; 875 } 876 877 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 878 // No broadcasting is needed if all operands but one are scalar. 879 if (hasAtMostSingleNonScalar(operands)) 880 return BoolAttr::get(getContext(), true); 881 882 if ([&] { 883 SmallVector<SmallVector<int64_t, 6>, 6> extents; 884 for (const auto &operand : operands) { 885 if (!operand) 886 return false; 887 extents.push_back(llvm::to_vector<6>( 888 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 889 } 890 return OpTrait::util::staticallyKnownBroadcastable(extents); 891 }()) 892 return BoolAttr::get(getContext(), true); 893 894 // Lastly, see if folding can be completed based on what constraints are known 895 // on the input shapes. 896 if ([&] { 897 SmallVector<SmallVector<int64_t, 6>, 6> extents; 898 for (auto shapeValue : shapes()) { 899 extents.emplace_back(); 900 if (failed(getShapeVec(shapeValue, extents.back()))) 901 return false; 902 } 903 return OpTrait::util::staticallyKnownBroadcastable(extents); 904 }()) 905 return BoolAttr::get(getContext(), true); 906 907 // Because a failing witness result here represents an eventual assertion 908 // failure, we do not replace it with a constant witness. 909 return nullptr; 910 } 911 912 static LogicalResult verify(CstrBroadcastableOp op) { 913 // Ensure that AssumingAllOp contains at least one operand 914 if (op.getNumOperands() < 2) 915 return op.emitOpError("required at least 2 input shapes"); 916 return success(); 917 } 918 919 //===----------------------------------------------------------------------===// 920 // CstrEqOp 921 //===----------------------------------------------------------------------===// 922 923 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 924 MLIRContext *context) { 925 // If inputs are equal, return passing witness 926 patterns.add<CstrEqEqOps>(context); 927 } 928 929 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 930 if (llvm::all_of(operands, 931 [&](Attribute a) { return a && a == operands[0]; })) 932 return BoolAttr::get(getContext(), true); 933 934 // Because a failing witness result here represents an eventual assertion 935 // failure, we do not try to replace it with a constant witness. Similarly, we 936 // cannot if there are any non-const inputs. 937 return nullptr; 938 } 939 940 //===----------------------------------------------------------------------===// 941 // ConstSizeOp 942 //===----------------------------------------------------------------------===// 943 944 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 945 int64_t value) { 946 build(builder, result, builder.getIndexAttr(value)); 947 } 948 949 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); } 950 951 void ConstSizeOp::getAsmResultNames( 952 llvm::function_ref<void(Value, StringRef)> setNameFn) { 953 SmallString<4> buffer; 954 llvm::raw_svector_ostream os(buffer); 955 os << "c" << value(); 956 setNameFn(getResult(), os.str()); 957 } 958 959 //===----------------------------------------------------------------------===// 960 // ConstWitnessOp 961 //===----------------------------------------------------------------------===// 962 963 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); } 964 965 //===----------------------------------------------------------------------===// 966 // CstrRequireOp 967 //===----------------------------------------------------------------------===// 968 969 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 970 return operands[0]; 971 } 972 973 //===----------------------------------------------------------------------===// 974 // DivOp 975 //===----------------------------------------------------------------------===// 976 977 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 978 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 979 if (!lhs) 980 return nullptr; 981 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 982 if (!rhs) 983 return nullptr; 984 985 // Division in APInt does not follow floor(lhs, rhs) when the result is 986 // negative. Rather, APInt rounds toward zero. 987 APInt quotient, remainder; 988 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 989 if (quotient.isNegative() && !remainder.isNullValue()) { 990 quotient -= 1; 991 } 992 993 Type indexTy = IndexType::get(getContext()); 994 return IntegerAttr::get(indexTy, quotient); 995 } 996 997 LogicalResult mlir::shape::DivOp::inferReturnTypes( 998 MLIRContext *context, Optional<Location> location, ValueRange operands, 999 DictionaryAttr attributes, RegionRange regions, 1000 SmallVectorImpl<Type> &inferredReturnTypes) { 1001 if (operands[0].getType().isa<SizeType>() || 1002 operands[1].getType().isa<SizeType>()) 1003 inferredReturnTypes.assign({SizeType::get(context)}); 1004 else 1005 inferredReturnTypes.assign({IndexType::get(context)}); 1006 return success(); 1007 } 1008 1009 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1010 // SizeType is compatible with IndexType. 1011 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1012 } 1013 1014 //===----------------------------------------------------------------------===// 1015 // ShapeEqOp 1016 //===----------------------------------------------------------------------===// 1017 1018 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1019 bool allSame = true; 1020 if (!operands.empty() && !operands[0]) 1021 return {}; 1022 for (Attribute operand : operands.drop_front(1)) { 1023 if (!operand) 1024 return {}; 1025 allSame = allSame && operand == operands[0]; 1026 } 1027 return BoolAttr::get(getContext(), allSame); 1028 } 1029 1030 //===----------------------------------------------------------------------===// 1031 // IndexToSizeOp 1032 //===----------------------------------------------------------------------===// 1033 1034 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1035 // Constant values of both types, `shape.size` and `index`, are represented as 1036 // `IntegerAttr`s which makes constant folding simple. 1037 if (Attribute arg = operands[0]) 1038 return arg; 1039 return {}; 1040 } 1041 1042 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1043 MLIRContext *context) { 1044 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1045 } 1046 1047 //===----------------------------------------------------------------------===// 1048 // FromExtentsOp 1049 //===----------------------------------------------------------------------===// 1050 1051 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1052 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1053 return nullptr; 1054 SmallVector<int64_t, 6> extents; 1055 for (auto attr : operands) 1056 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1057 Builder builder(getContext()); 1058 return builder.getIndexTensorAttr(extents); 1059 } 1060 1061 //===----------------------------------------------------------------------===// 1062 // FunctionLibraryOp 1063 //===----------------------------------------------------------------------===// 1064 1065 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1066 StringRef name) { 1067 result.attributes.push_back(builder.getNamedAttr( 1068 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1069 } 1070 1071 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1072 auto attr = mapping() 1073 .get(op->getName().getIdentifier()) 1074 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1075 if (!attr) 1076 return nullptr; 1077 return lookupSymbol<FuncOp>(attr); 1078 } 1079 1080 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1081 OperationState &result) { 1082 // Parse the op name. 1083 StringAttr nameAttr; 1084 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1085 result.attributes)) 1086 return failure(); 1087 1088 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1089 return failure(); 1090 1091 auto *bodyRegion = result.addRegion(); 1092 if (parser.parseRegion(*bodyRegion)) 1093 return failure(); 1094 1095 if (parser.parseKeyword("mapping")) 1096 return failure(); 1097 1098 DictionaryAttr mappingAttr; 1099 if (parser.parseAttribute(mappingAttr, 1100 parser.getBuilder().getType<NoneType>(), "mapping", 1101 result.attributes)) 1102 return failure(); 1103 return success(); 1104 } 1105 1106 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1107 p << ' '; 1108 p.printSymbolName(op.getName()); 1109 p.printOptionalAttrDictWithKeyword( 1110 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1111 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1112 /*printBlockTerminators=*/false); 1113 p << " mapping "; 1114 p.printAttributeWithoutType(op.mappingAttr()); 1115 } 1116 1117 //===----------------------------------------------------------------------===// 1118 // GetExtentOp 1119 //===----------------------------------------------------------------------===// 1120 1121 Optional<int64_t> GetExtentOp::getConstantDim() { 1122 if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) 1123 return constSizeOp.value().getLimitedValue(); 1124 if (auto constantOp = dim().getDefiningOp<arith::ConstantOp>()) 1125 return constantOp.value().cast<IntegerAttr>().getInt(); 1126 return llvm::None; 1127 } 1128 1129 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1130 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1131 if (!elements) 1132 return nullptr; 1133 Optional<int64_t> dim = getConstantDim(); 1134 if (!dim.hasValue()) 1135 return nullptr; 1136 if (dim.getValue() >= elements.getNumElements()) 1137 return nullptr; 1138 return elements.getValue({(uint64_t)dim.getValue()}); 1139 } 1140 1141 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1142 int64_t dim) { 1143 auto loc = result.location; 1144 auto dimAttr = builder.getIndexAttr(dim); 1145 if (shape.getType().isa<ShapeType>()) { 1146 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1147 build(builder, result, builder.getType<SizeType>(), shape, dim); 1148 } else { 1149 Value dim = 1150 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1151 build(builder, result, builder.getIndexType(), shape, dim); 1152 } 1153 } 1154 1155 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1156 MLIRContext *context, Optional<Location> location, ValueRange operands, 1157 DictionaryAttr attributes, RegionRange regions, 1158 SmallVectorImpl<Type> &inferredReturnTypes) { 1159 inferredReturnTypes.assign({IndexType::get(context)}); 1160 return success(); 1161 } 1162 1163 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1164 TypeRange r) { 1165 // SizeType is compatible with IndexType. 1166 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1167 } 1168 1169 //===----------------------------------------------------------------------===// 1170 // IsBroadcastableOp 1171 //===----------------------------------------------------------------------===// 1172 1173 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1174 MLIRContext *context) { 1175 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1176 } 1177 1178 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1179 // Can always broadcast fewer than two shapes. 1180 if (operands.size() < 2) { 1181 return BoolAttr::get(getContext(), true); 1182 } 1183 1184 return nullptr; 1185 } 1186 1187 //===----------------------------------------------------------------------===// 1188 // MeetOp 1189 //===----------------------------------------------------------------------===// 1190 1191 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1192 MLIRContext *context, Optional<Location> location, ValueRange operands, 1193 DictionaryAttr attributes, RegionRange regions, 1194 SmallVectorImpl<Type> &inferredReturnTypes) { 1195 inferredReturnTypes.assign({operands[0].getType()}); 1196 return success(); 1197 } 1198 1199 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1200 if (l.size() != 1 || r.size() != 1) 1201 return false; 1202 if (l == r) 1203 return true; 1204 1205 Type lhs = l.front(); 1206 Type rhs = r.front(); 1207 1208 if (lhs != rhs) 1209 return false; 1210 1211 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1212 return true; 1213 1214 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1215 return true; 1216 return false; 1217 } 1218 1219 //===----------------------------------------------------------------------===// 1220 // RankOp 1221 //===----------------------------------------------------------------------===// 1222 1223 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1224 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1225 if (!shape) 1226 return {}; 1227 int64_t rank = shape.getNumElements(); 1228 Builder builder(getContext()); 1229 return builder.getIndexAttr(rank); 1230 } 1231 1232 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1233 /// Constant folding fails in cases where only the rank is constant, not the 1234 /// shape itself. 1235 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1236 /// 1237 /// Example: 1238 /// 1239 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1240 /// %rank = shape.rank %shape 1241 /// 1242 /// becomes 1243 /// 1244 /// %rank = shape.const_size 3 1245 1246 namespace { 1247 struct RankShapeOfCanonicalizationPattern 1248 : public OpRewritePattern<shape::RankOp> { 1249 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1250 1251 LogicalResult matchAndRewrite(shape::RankOp op, 1252 PatternRewriter &rewriter) const override { 1253 auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>(); 1254 if (!shapeOfOp) 1255 return failure(); 1256 auto rankedTensorType = 1257 shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1258 if (!rankedTensorType) 1259 return failure(); 1260 int64_t rank = rankedTensorType.getRank(); 1261 if (op.getType().isa<IndexType>()) { 1262 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1263 rank); 1264 } else if (op.getType().isa<shape::SizeType>()) { 1265 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1266 } else { 1267 return failure(); 1268 } 1269 return success(); 1270 } 1271 }; 1272 } // namespace 1273 1274 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1275 MLIRContext *context) { 1276 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1277 } 1278 1279 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1280 MLIRContext *context, Optional<Location> location, ValueRange operands, 1281 DictionaryAttr attributes, RegionRange regions, 1282 SmallVectorImpl<Type> &inferredReturnTypes) { 1283 if (operands[0].getType().isa<ShapeType>()) 1284 inferredReturnTypes.assign({SizeType::get(context)}); 1285 else 1286 inferredReturnTypes.assign({IndexType::get(context)}); 1287 return success(); 1288 } 1289 1290 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1291 // SizeType is compatible with IndexType. 1292 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1293 } 1294 1295 //===----------------------------------------------------------------------===// 1296 // NumElementsOp 1297 //===----------------------------------------------------------------------===// 1298 1299 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1300 1301 // Fold only when argument constant. 1302 Attribute shape = operands[0]; 1303 if (!shape) 1304 return {}; 1305 1306 APInt product(64, 1); 1307 for (auto value : shape.cast<DenseIntElementsAttr>()) 1308 product *= value; 1309 Builder builder(getContext()); 1310 return builder.getIndexAttr(product.getLimitedValue()); 1311 } 1312 1313 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1314 MLIRContext *context, Optional<Location> location, ValueRange operands, 1315 DictionaryAttr attributes, RegionRange regions, 1316 SmallVectorImpl<Type> &inferredReturnTypes) { 1317 if (operands[0].getType().isa<ShapeType>()) 1318 inferredReturnTypes.assign({SizeType::get(context)}); 1319 else 1320 inferredReturnTypes.assign({IndexType::get(context)}); 1321 return success(); 1322 } 1323 1324 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1325 TypeRange r) { 1326 // SizeType is compatible with IndexType. 1327 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1328 } 1329 1330 //===----------------------------------------------------------------------===// 1331 // MaxOp 1332 //===----------------------------------------------------------------------===// 1333 1334 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1335 // If operands are equal, just propagate one. 1336 if (lhs() == rhs()) 1337 return lhs(); 1338 return nullptr; 1339 } 1340 1341 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1342 MLIRContext *context, Optional<Location> location, ValueRange operands, 1343 DictionaryAttr attributes, RegionRange regions, 1344 SmallVectorImpl<Type> &inferredReturnTypes) { 1345 if (operands[0].getType() == operands[1].getType()) 1346 inferredReturnTypes.assign({operands[0].getType()}); 1347 else 1348 inferredReturnTypes.assign({SizeType::get(context)}); 1349 return success(); 1350 } 1351 1352 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1353 if (l.size() != 1 || r.size() != 1) 1354 return false; 1355 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1356 return true; 1357 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1358 return true; 1359 return false; 1360 } 1361 1362 //===----------------------------------------------------------------------===// 1363 // MinOp 1364 //===----------------------------------------------------------------------===// 1365 1366 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1367 // If operands are equal, just propagate one. 1368 if (lhs() == rhs()) 1369 return lhs(); 1370 return nullptr; 1371 } 1372 1373 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1374 MLIRContext *context, Optional<Location> location, ValueRange operands, 1375 DictionaryAttr attributes, RegionRange regions, 1376 SmallVectorImpl<Type> &inferredReturnTypes) { 1377 if (operands[0].getType() == operands[1].getType()) 1378 inferredReturnTypes.assign({operands[0].getType()}); 1379 else 1380 inferredReturnTypes.assign({SizeType::get(context)}); 1381 return success(); 1382 } 1383 1384 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1385 if (l.size() != 1 || r.size() != 1) 1386 return false; 1387 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1388 return true; 1389 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1390 return true; 1391 return false; 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // MulOp 1396 //===----------------------------------------------------------------------===// 1397 1398 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1399 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1400 if (!lhs) 1401 return nullptr; 1402 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1403 if (!rhs) 1404 return nullptr; 1405 APInt folded = lhs.getValue() * rhs.getValue(); 1406 Type indexTy = IndexType::get(getContext()); 1407 return IntegerAttr::get(indexTy, folded); 1408 } 1409 1410 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1411 MLIRContext *context, Optional<Location> location, ValueRange operands, 1412 DictionaryAttr attributes, RegionRange regions, 1413 SmallVectorImpl<Type> &inferredReturnTypes) { 1414 if (operands[0].getType().isa<SizeType>() || 1415 operands[1].getType().isa<SizeType>()) 1416 inferredReturnTypes.assign({SizeType::get(context)}); 1417 else 1418 inferredReturnTypes.assign({IndexType::get(context)}); 1419 return success(); 1420 } 1421 1422 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1423 // SizeType is compatible with IndexType. 1424 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1425 } 1426 //===----------------------------------------------------------------------===// 1427 // ShapeOfOp 1428 //===----------------------------------------------------------------------===// 1429 1430 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1431 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1432 if (!type || !type.hasStaticShape()) 1433 return nullptr; 1434 Builder builder(getContext()); 1435 return builder.getIndexTensorAttr(type.getShape()); 1436 } 1437 1438 namespace { 1439 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1440 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1441 1442 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1443 PatternRewriter &rewriter) const override { 1444 if (!op.arg().getType().isa<ShapedType>()) 1445 return failure(); 1446 if (op.getType().isa<ShapedType>()) 1447 return failure(); 1448 1449 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg()); 1450 return success(); 1451 } 1452 }; 1453 1454 // Canonicalize 1455 // ``` 1456 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1457 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1458 // ``` 1459 // to 1460 // ``` 1461 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1462 // ``` 1463 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1464 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1465 1466 LogicalResult matchAndRewrite(tensor::CastOp op, 1467 PatternRewriter &rewriter) const override { 1468 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1469 if (!ty || ty.getRank() != 1) 1470 return failure(); 1471 1472 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1473 if (!shapeOfOp) 1474 return failure(); 1475 1476 // Argument type must be ranked and must not conflict. 1477 auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>(); 1478 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1479 return failure(); 1480 1481 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg()); 1482 return success(); 1483 } 1484 }; 1485 } // namespace 1486 1487 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1488 MLIRContext *context) { 1489 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1490 ExtractFromShapeOfExtentTensor>(context); 1491 } 1492 1493 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1494 MLIRContext *context, Optional<Location> location, ValueRange operands, 1495 DictionaryAttr attributes, RegionRange regions, 1496 SmallVectorImpl<Type> &inferredReturnTypes) { 1497 if (operands[0].getType().isa<ValueShapeType>()) 1498 inferredReturnTypes.assign({ShapeType::get(context)}); 1499 else { 1500 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1501 int64_t rank = 1502 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1503 Type indexTy = IndexType::get(context); 1504 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1505 inferredReturnTypes.assign({extentTensorTy}); 1506 } 1507 return success(); 1508 } 1509 1510 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1511 if (l.size() != 1 || r.size() != 1) 1512 return false; 1513 if (l == r) 1514 return true; 1515 1516 Type lhs = l.front(); 1517 Type rhs = r.front(); 1518 1519 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1520 return false; 1521 1522 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1523 // Shape type is compatible with all other valid return types. 1524 return true; 1525 1526 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1527 return true; 1528 return false; 1529 } 1530 1531 //===----------------------------------------------------------------------===// 1532 // SizeToIndexOp 1533 //===----------------------------------------------------------------------===// 1534 1535 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1536 // Constant values of both types, `shape.size` and `index`, are represented as 1537 // `IntegerAttr`s which makes constant folding simple. 1538 if (Attribute arg = operands[0]) 1539 return arg; 1540 return impl::foldCastOp(*this); 1541 } 1542 1543 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1544 MLIRContext *context) { 1545 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1546 } 1547 1548 //===----------------------------------------------------------------------===// 1549 // YieldOp 1550 //===----------------------------------------------------------------------===// 1551 1552 static LogicalResult verify(shape::YieldOp op) { 1553 auto *parentOp = op->getParentOp(); 1554 auto results = parentOp->getResults(); 1555 auto operands = op.getOperands(); 1556 1557 if (parentOp->getNumResults() != op.getNumOperands()) 1558 return op.emitOpError() << "number of operands does not match number of " 1559 "results of its parent"; 1560 for (auto e : llvm::zip(results, operands)) 1561 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1562 return op.emitOpError() 1563 << "types mismatch between yield op and its parent"; 1564 1565 return success(); 1566 } 1567 1568 //===----------------------------------------------------------------------===// 1569 // SplitAtOp 1570 //===----------------------------------------------------------------------===// 1571 1572 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1573 SmallVectorImpl<OpFoldResult> &results) { 1574 if (!operands[0] || !operands[1]) 1575 return failure(); 1576 auto shapeVec = llvm::to_vector<6>( 1577 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1578 auto shape = llvm::makeArrayRef(shapeVec); 1579 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1580 // Verify that the split point is in the correct range. 1581 // TODO: Constant fold to an "error". 1582 int64_t rank = shape.size(); 1583 if (!(-rank <= splitPoint && splitPoint <= rank)) 1584 return failure(); 1585 if (splitPoint < 0) 1586 splitPoint += shape.size(); 1587 Builder builder(operands[0].getContext()); 1588 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1589 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1590 return success(); 1591 } 1592 1593 //===----------------------------------------------------------------------===// 1594 // ToExtentTensorOp 1595 //===----------------------------------------------------------------------===// 1596 1597 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1598 if (!operands[0]) 1599 return impl::foldCastOp(*this); 1600 Builder builder(getContext()); 1601 auto shape = llvm::to_vector<6>( 1602 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1603 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1604 builder.getIndexType()); 1605 return DenseIntElementsAttr::get(type, shape); 1606 } 1607 1608 //===----------------------------------------------------------------------===// 1609 // ReduceOp 1610 //===----------------------------------------------------------------------===// 1611 1612 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1613 ValueRange initVals) { 1614 result.addOperands(shape); 1615 result.addOperands(initVals); 1616 1617 Region *bodyRegion = result.addRegion(); 1618 bodyRegion->push_back(new Block); 1619 Block &bodyBlock = bodyRegion->front(); 1620 bodyBlock.addArgument(builder.getIndexType()); 1621 1622 Type elementType; 1623 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1624 elementType = tensorType.getElementType(); 1625 else 1626 elementType = SizeType::get(builder.getContext()); 1627 bodyBlock.addArgument(elementType); 1628 1629 for (Type initValType : initVals.getTypes()) { 1630 bodyBlock.addArgument(initValType); 1631 result.addTypes(initValType); 1632 } 1633 } 1634 1635 static LogicalResult verify(ReduceOp op) { 1636 // Verify block arg types. 1637 Block &block = op.region().front(); 1638 1639 // The block takes index, extent, and aggregated values as arguments. 1640 auto blockArgsCount = op.initVals().size() + 2; 1641 if (block.getNumArguments() != blockArgsCount) 1642 return op.emitOpError() << "ReduceOp body is expected to have " 1643 << blockArgsCount << " arguments"; 1644 1645 // The first block argument is the index and must always be of type `index`. 1646 if (!block.getArgument(0).getType().isa<IndexType>()) 1647 return op.emitOpError( 1648 "argument 0 of ReduceOp body is expected to be of IndexType"); 1649 1650 // The second block argument is the extent and must be of type `size` or 1651 // `index`, depending on whether the reduce operation is applied to a shape or 1652 // to an extent tensor. 1653 Type extentTy = block.getArgument(1).getType(); 1654 if (op.shape().getType().isa<ShapeType>()) { 1655 if (!extentTy.isa<SizeType>()) 1656 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1657 "SizeType if the ReduceOp operates on a ShapeType"); 1658 } else { 1659 if (!extentTy.isa<IndexType>()) 1660 return op.emitOpError( 1661 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1662 "ReduceOp operates on an extent tensor"); 1663 } 1664 1665 for (auto type : llvm::enumerate(op.initVals())) 1666 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1667 return op.emitOpError() 1668 << "type mismatch between argument " << type.index() + 2 1669 << " of ReduceOp body and initial value " << type.index(); 1670 return success(); 1671 } 1672 1673 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1674 // Parse operands. 1675 SmallVector<OpAsmParser::OperandType, 3> operands; 1676 Type shapeOrExtentTensorType; 1677 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1678 OpAsmParser::Delimiter::Paren) || 1679 parser.parseColonType(shapeOrExtentTensorType) || 1680 parser.parseOptionalArrowTypeList(result.types)) 1681 return failure(); 1682 1683 // Resolve operands. 1684 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1685 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1686 result.operands) || 1687 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1688 result.operands)) 1689 return failure(); 1690 1691 // Parse the body. 1692 Region *body = result.addRegion(); 1693 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1694 return failure(); 1695 1696 // Parse attributes. 1697 if (parser.parseOptionalAttrDict(result.attributes)) 1698 return failure(); 1699 1700 return success(); 1701 } 1702 1703 static void print(OpAsmPrinter &p, ReduceOp op) { 1704 p << '(' << op.shape() << ", " << op.initVals() 1705 << ") : " << op.shape().getType(); 1706 p.printOptionalArrowTypeList(op.getResultTypes()); 1707 p.printRegion(op.region()); 1708 p.printOptionalAttrDict(op->getAttrs()); 1709 } 1710 1711 #define GET_OP_CLASSES 1712 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1713