1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Traits.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/IR/TypeUtilities.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::shape; 31 32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 33 34 namespace { 35 #include "ShapeCanonicalization.inc" 36 } // namespace 37 38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 39 return RankedTensorType::get({rank}, IndexType::get(ctx)); 40 } 41 42 bool shape::isExtentTensorType(Type type) { 43 auto ranked = type.dyn_cast<RankedTensorType>(); 44 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 45 } 46 47 LogicalResult shape::getShapeVec(Value input, 48 SmallVectorImpl<int64_t> &shapeValues) { 49 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 50 auto type = inputOp.getArg().getType().dyn_cast<ShapedType>(); 51 if (!type.hasRank()) 52 return failure(); 53 shapeValues = llvm::to_vector<6>(type.getShape()); 54 return success(); 55 } 56 if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 57 shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); 58 return success(); 59 } 60 if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 61 shapeValues = llvm::to_vector<6>( 62 inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>()); 63 return success(); 64 } 65 return failure(); 66 } 67 68 static bool isErrorPropagationPossible(TypeRange operandTypes) { 69 return llvm::any_of(operandTypes, [](Type ty) { 70 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 71 }); 72 } 73 74 static LogicalResult verifySizeOrIndexOp(Operation *op) { 75 assert(op != nullptr && op->getNumResults() == 1); 76 Type resultTy = op->getResultTypes().front(); 77 if (isErrorPropagationPossible(op->getOperandTypes())) { 78 if (!resultTy.isa<SizeType>()) 79 return op->emitOpError() 80 << "if at least one of the operands can hold error values then " 81 "the result must be of type `size` to propagate them"; 82 } 83 return success(); 84 } 85 86 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 87 assert(op != nullptr && op->getNumResults() == 1); 88 Type resultTy = op->getResultTypes().front(); 89 if (isErrorPropagationPossible(op->getOperandTypes())) { 90 if (!resultTy.isa<ShapeType>()) 91 return op->emitOpError() 92 << "if at least one of the operands can hold error values then " 93 "the result must be of type `shape` to propagate them"; 94 } 95 return success(); 96 } 97 98 template <typename... Ty> 99 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 100 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 101 } 102 103 template <typename... Ty, typename... ranges> 104 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 105 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // InlinerInterface 110 //===----------------------------------------------------------------------===// 111 112 namespace { 113 /// This class defines the interface for inlining shape dialect ops. 114 struct ShapeInlinerInterface : public DialectInlinerInterface { 115 using DialectInlinerInterface::DialectInlinerInterface; 116 117 // Returns true if the given region 'src' can be inlined into the region 118 // 'dest' that is attached to an operation registered to the current dialect. 119 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 120 BlockAndValueMapping &) const final { 121 return true; 122 } 123 124 // Returns true if the given operation 'op', that is registered to this 125 // dialect, can be inlined into the region 'dest' that is attached to an 126 // operation registered to the current dialect. 127 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 128 BlockAndValueMapping &) const final { 129 return true; 130 } 131 }; 132 } // namespace 133 134 void ShapeDialect::initialize() { 135 addOperations< 136 #define GET_OP_LIST 137 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 138 >(); 139 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 140 addInterfaces<ShapeInlinerInterface>(); 141 // Allow unknown operations during prototyping and testing. As the dialect is 142 // still evolving it makes it simple to start with an unregistered ops and 143 // try different variants before actually defining the op. 144 allowUnknownOperations(); 145 } 146 147 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 148 Attribute value, Type type, 149 Location loc) { 150 if (type.isa<ShapeType>() || isExtentTensorType(type)) 151 return builder.create<ConstShapeOp>(loc, type, 152 value.cast<DenseIntElementsAttr>()); 153 if (type.isa<SizeType>()) 154 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 155 if (type.isa<WitnessType>()) 156 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 157 if (arith::ConstantOp::isBuildableWith(value, type)) 158 return builder.create<arith::ConstantOp>(loc, type, value); 159 return nullptr; 160 } 161 162 /// Parse a type registered to this dialect. 163 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 164 StringRef keyword; 165 if (parser.parseKeyword(&keyword)) 166 return Type(); 167 168 if (keyword == "shape") 169 return ShapeType::get(getContext()); 170 if (keyword == "size") 171 return SizeType::get(getContext()); 172 if (keyword == "value_shape") 173 return ValueShapeType::get(getContext()); 174 if (keyword == "witness") 175 return WitnessType::get(getContext()); 176 177 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 178 return Type(); 179 } 180 181 /// Print a type registered to this dialect. 182 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 183 TypeSwitch<Type>(type) 184 .Case<ShapeType>([&](Type) { os << "shape"; }) 185 .Case<SizeType>([&](Type) { os << "size"; }) 186 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 187 .Case<WitnessType>([&](Type) { os << "witness"; }) 188 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 189 } 190 191 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 192 NamedAttribute attribute) { 193 // Verify shape.lib attribute. 194 if (attribute.getName() == "shape.lib") { 195 if (!op->hasTrait<OpTrait::SymbolTable>()) 196 return op->emitError( 197 "shape.lib attribute may only be on op implementing SymbolTable"); 198 199 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 200 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 201 if (!symbol) 202 return op->emitError("shape function library ") 203 << symbolRef << " not found"; 204 return isa<shape::FunctionLibraryOp>(symbol) 205 ? success() 206 : op->emitError() 207 << symbolRef << " required to be shape function library"; 208 } 209 210 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 211 // Verify all entries are function libraries and mappings in libraries 212 // refer to unique ops. 213 DenseSet<StringAttr> key; 214 for (auto it : arr) { 215 if (!it.isa<SymbolRefAttr>()) 216 return op->emitError( 217 "only SymbolRefAttr allowed in shape.lib attribute array"); 218 219 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 220 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 221 if (!shapeFnLib) 222 return op->emitError() 223 << it << " does not refer to FunctionLibraryOp"; 224 for (auto mapping : shapeFnLib.getMapping()) { 225 if (!key.insert(mapping.getName()).second) { 226 return op->emitError("only one op to shape mapping allowed, found " 227 "multiple for `") 228 << mapping.getName() << "`"; 229 } 230 } 231 } 232 return success(); 233 } 234 235 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 236 "allowed as shape.lib attribute"); 237 } 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // AnyOp 243 //===----------------------------------------------------------------------===// 244 245 // TODO: Canonicalization should be implemented for shapes that can be 246 // determined through mixtures of the known dimensions of the inputs. 247 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 248 // Only the last operand is checked because AnyOp is commutative. 249 if (operands.back()) 250 return operands.back(); 251 252 return nullptr; 253 } 254 255 //===----------------------------------------------------------------------===// 256 // AssumingOp 257 //===----------------------------------------------------------------------===// 258 259 static ParseResult parseAssumingOp(OpAsmParser &parser, 260 OperationState &result) { 261 result.regions.reserve(1); 262 Region *doRegion = result.addRegion(); 263 264 auto &builder = parser.getBuilder(); 265 OpAsmParser::OperandType cond; 266 if (parser.parseOperand(cond) || 267 parser.resolveOperand(cond, builder.getType<WitnessType>(), 268 result.operands)) 269 return failure(); 270 271 // Parse optional results type list. 272 if (parser.parseOptionalArrowTypeList(result.types)) 273 return failure(); 274 275 // Parse the region and add a terminator if elided. 276 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 277 return failure(); 278 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 279 280 // Parse the optional attribute list. 281 if (parser.parseOptionalAttrDict(result.attributes)) 282 return failure(); 283 return success(); 284 } 285 286 static void print(OpAsmPrinter &p, AssumingOp op) { 287 bool yieldsResults = !op.getResults().empty(); 288 289 p << " " << op.getWitness(); 290 if (yieldsResults) 291 p << " -> (" << op.getResultTypes() << ")"; 292 p << ' '; 293 p.printRegion(op.getDoRegion(), 294 /*printEntryBlockArgs=*/false, 295 /*printBlockTerminators=*/yieldsResults); 296 p.printOptionalAttrDict(op->getAttrs()); 297 } 298 299 namespace { 300 // Removes AssumingOp with a passing witness and inlines the region. 301 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 302 using OpRewritePattern<AssumingOp>::OpRewritePattern; 303 304 LogicalResult matchAndRewrite(AssumingOp op, 305 PatternRewriter &rewriter) const override { 306 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 307 if (!witness || !witness.getPassingAttr()) 308 return failure(); 309 310 AssumingOp::inlineRegionIntoParent(op, rewriter); 311 return success(); 312 } 313 }; 314 315 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 316 using OpRewritePattern<AssumingOp>::OpRewritePattern; 317 318 LogicalResult matchAndRewrite(AssumingOp op, 319 PatternRewriter &rewriter) const override { 320 Block *body = op.getBody(); 321 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 322 323 // Find used values. 324 SmallVector<Value, 4> newYieldOperands; 325 Value opResult, yieldOperand; 326 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 327 std::tie(opResult, yieldOperand) = it; 328 if (!opResult.getUses().empty()) { 329 newYieldOperands.push_back(yieldOperand); 330 } 331 } 332 333 // Rewrite only if redundant results exist. 334 if (newYieldOperands.size() == yieldOp->getNumOperands()) 335 return failure(); 336 337 // Replace yield op in the old assuming op's body and move the entire region 338 // to the new assuming op. 339 rewriter.setInsertionPointToEnd(body); 340 auto newYieldOp = 341 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 342 rewriter.setInsertionPoint(op); 343 auto newOp = rewriter.create<AssumingOp>( 344 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 345 newOp.getDoRegion().takeBody(op.getDoRegion()); 346 347 // Use the new results to replace the previously used ones. 348 SmallVector<Value, 4> replacementValues; 349 auto src = newOp.getResults().begin(); 350 for (auto it : op.getResults()) { 351 if (it.getUses().empty()) 352 replacementValues.push_back(nullptr); 353 else 354 replacementValues.push_back(*src++); 355 } 356 rewriter.replaceOp(op, replacementValues); 357 return success(); 358 } 359 }; 360 } // namespace 361 362 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 363 MLIRContext *context) { 364 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 365 } 366 367 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 368 void AssumingOp::getSuccessorRegions( 369 Optional<unsigned> index, ArrayRef<Attribute> operands, 370 SmallVectorImpl<RegionSuccessor> ®ions) { 371 // AssumingOp has unconditional control flow into the region and back to the 372 // parent, so return the correct RegionSuccessor purely based on the index 373 // being None or 0. 374 if (index.hasValue()) { 375 regions.push_back(RegionSuccessor(getResults())); 376 return; 377 } 378 379 regions.push_back(RegionSuccessor(&getDoRegion())); 380 } 381 382 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 383 PatternRewriter &rewriter) { 384 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 385 auto *assumingBlock = op.getBody(); 386 auto initPosition = rewriter.getInsertionPoint(); 387 auto *blockAfterAssuming = 388 rewriter.splitBlock(blockBeforeAssuming, initPosition); 389 390 // Remove the AssumingOp and AssumingYieldOp. 391 auto &yieldOp = assumingBlock->back(); 392 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 393 rewriter.replaceOp(op, yieldOp.getOperands()); 394 rewriter.eraseOp(&yieldOp); 395 396 // Merge blocks together as there was no branching behavior from the 397 // AssumingOp. 398 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 399 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 400 } 401 402 void AssumingOp::build( 403 OpBuilder &builder, OperationState &result, Value witness, 404 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 405 406 result.addOperands(witness); 407 Region *bodyRegion = result.addRegion(); 408 bodyRegion->push_back(new Block); 409 Block &bodyBlock = bodyRegion->front(); 410 411 // Build body. 412 OpBuilder::InsertionGuard guard(builder); 413 builder.setInsertionPointToStart(&bodyBlock); 414 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 415 builder.create<AssumingYieldOp>(result.location, yieldValues); 416 417 SmallVector<Type, 2> assumingTypes; 418 for (Value v : yieldValues) 419 assumingTypes.push_back(v.getType()); 420 result.addTypes(assumingTypes); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // AddOp 425 //===----------------------------------------------------------------------===// 426 427 LogicalResult mlir::shape::AddOp::inferReturnTypes( 428 MLIRContext *context, Optional<Location> location, ValueRange operands, 429 DictionaryAttr attributes, RegionRange regions, 430 SmallVectorImpl<Type> &inferredReturnTypes) { 431 if (operands[0].getType().isa<SizeType>() || 432 operands[1].getType().isa<SizeType>()) 433 inferredReturnTypes.assign({SizeType::get(context)}); 434 else 435 inferredReturnTypes.assign({IndexType::get(context)}); 436 return success(); 437 } 438 439 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 440 // SizeType is compatible with IndexType. 441 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 442 } 443 444 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 445 // add(x, 0) -> x 446 if (matchPattern(getRhs(), m_Zero())) 447 return getLhs(); 448 449 return constFoldBinaryOp<IntegerAttr>( 450 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 451 } 452 453 //===----------------------------------------------------------------------===// 454 // AssumingAllOp 455 //===----------------------------------------------------------------------===// 456 457 namespace { 458 struct AssumingAllToCstrEqCanonicalization 459 : public OpRewritePattern<AssumingAllOp> { 460 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 461 462 LogicalResult matchAndRewrite(AssumingAllOp op, 463 PatternRewriter &rewriter) const override { 464 SmallVector<Value, 8> shapes; 465 for (Value w : op.getInputs()) { 466 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 467 if (!cstrEqOp) 468 return failure(); 469 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 470 return llvm::is_contained(shapes, s); 471 }); 472 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 473 return failure(); 474 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 475 } 476 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 477 return success(); 478 } 479 }; 480 481 template <typename OpTy> 482 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 483 using OpRewritePattern<OpTy>::OpRewritePattern; 484 485 LogicalResult matchAndRewrite(OpTy op, 486 PatternRewriter &rewriter) const override { 487 // Find unique operands. 488 SmallVector<Value, 2> unique; 489 for (Value v : op.getOperands()) { 490 if (!llvm::is_contained(unique, v)) 491 unique.push_back(v); 492 } 493 494 // Reduce op to equivalent with unique operands. 495 if (unique.size() < op.getNumOperands()) { 496 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 497 op->getAttrs()); 498 return success(); 499 } 500 501 return failure(); 502 } 503 }; 504 } // namespace 505 506 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 507 MLIRContext *context) { 508 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 509 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 510 } 511 512 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 513 // Iterate in reverse to first handle all constant operands. They are 514 // guaranteed to be the tail of the inputs because this is commutative. 515 for (int idx = operands.size() - 1; idx >= 0; idx--) { 516 Attribute a = operands[idx]; 517 // Cannot fold if any inputs are not constant; 518 if (!a) 519 return nullptr; 520 521 // We do not need to keep statically known values after handling them in 522 // this method. 523 getOperation()->eraseOperand(idx); 524 525 // Always false if any input is statically known false 526 if (!a.cast<BoolAttr>().getValue()) 527 return a; 528 } 529 // If this is reached, all inputs were statically known passing. 530 return BoolAttr::get(getContext(), true); 531 } 532 533 static LogicalResult verify(AssumingAllOp op) { 534 // Ensure that AssumingAllOp contains at least one operand 535 if (op.getNumOperands() == 0) 536 return op.emitOpError("no operands specified"); 537 538 return success(); 539 } 540 541 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 542 ValueRange inputs) { 543 build(b, state, b.getType<WitnessType>(), inputs); 544 } 545 546 //===----------------------------------------------------------------------===// 547 // BroadcastOp 548 //===----------------------------------------------------------------------===// 549 550 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 551 if (getShapes().size() == 1) { 552 // Otherwise, we need a cast which would be a canonicalization, not folding. 553 if (getShapes().front().getType() != getType()) 554 return nullptr; 555 return getShapes().front(); 556 } 557 558 // TODO: Support folding with more than 2 input shapes 559 if (getShapes().size() > 2) 560 return nullptr; 561 562 if (!operands[0] || !operands[1]) 563 return nullptr; 564 auto lhsShape = llvm::to_vector<6>( 565 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 566 auto rhsShape = llvm::to_vector<6>( 567 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 568 SmallVector<int64_t, 6> resultShape; 569 570 // If the shapes are not compatible, we can't fold it. 571 // TODO: Fold to an "error". 572 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 573 return nullptr; 574 575 Builder builder(getContext()); 576 return builder.getIndexTensorAttr(resultShape); 577 } 578 579 static LogicalResult verify(BroadcastOp op) { 580 return verifyShapeOrExtentTensorOp(op); 581 } 582 583 namespace { 584 template <typename OpTy> 585 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 586 using OpRewritePattern<OpTy>::OpRewritePattern; 587 588 LogicalResult matchAndRewrite(OpTy op, 589 PatternRewriter &rewriter) const override { 590 auto isPotentiallyNonEmptyShape = [](Value shape) { 591 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 592 if (extentTensorTy.getDimSize(0) == 0) 593 return false; 594 } 595 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 596 if (constShape.getShape().empty()) 597 return false; 598 } 599 return true; 600 }; 601 auto newOperands = llvm::to_vector<8>( 602 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 603 604 // Reduce op to equivalent without empty shape operands. 605 if (newOperands.size() < op.getNumOperands()) { 606 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 607 op->getAttrs()); 608 return success(); 609 } 610 611 return failure(); 612 } 613 }; 614 615 struct BroadcastForwardSingleOperandPattern 616 : public OpRewritePattern<BroadcastOp> { 617 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 618 619 LogicalResult matchAndRewrite(BroadcastOp op, 620 PatternRewriter &rewriter) const override { 621 if (op.getNumOperands() != 1) 622 return failure(); 623 Value replacement = op.getShapes().front(); 624 625 // Insert cast if needed. 626 if (replacement.getType() != op.getType()) { 627 auto loc = op.getLoc(); 628 if (op.getType().isa<ShapeType>()) { 629 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 630 } else { 631 assert(!op.getType().isa<ShapeType>() && 632 !replacement.getType().isa<ShapeType>() && 633 "expect extent tensor cast"); 634 replacement = 635 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 636 } 637 } 638 639 rewriter.replaceOp(op, replacement); 640 return success(); 641 } 642 }; 643 644 struct BroadcastFoldConstantOperandsPattern 645 : public OpRewritePattern<BroadcastOp> { 646 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 647 648 LogicalResult matchAndRewrite(BroadcastOp op, 649 PatternRewriter &rewriter) const override { 650 SmallVector<int64_t, 8> foldedConstantShape; 651 SmallVector<Value, 8> newShapeOperands; 652 for (Value shape : op.getShapes()) { 653 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 654 SmallVector<int64_t, 8> newFoldedConstantShape; 655 if (OpTrait::util::getBroadcastedShape( 656 foldedConstantShape, 657 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 658 newFoldedConstantShape)) { 659 foldedConstantShape = newFoldedConstantShape; 660 continue; 661 } 662 } 663 newShapeOperands.push_back(shape); 664 } 665 666 // Need at least two constant operands to fold anything. 667 if (op.getNumOperands() - newShapeOperands.size() < 2) 668 return failure(); 669 670 auto foldedConstantOperandsTy = RankedTensorType::get( 671 {static_cast<int64_t>(foldedConstantShape.size())}, 672 rewriter.getIndexType()); 673 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 674 op.getLoc(), foldedConstantOperandsTy, 675 rewriter.getIndexTensorAttr(foldedConstantShape))); 676 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 677 newShapeOperands); 678 return success(); 679 } 680 }; 681 682 template <typename OpTy> 683 struct CanonicalizeCastExtentTensorOperandsPattern 684 : public OpRewritePattern<OpTy> { 685 using OpRewritePattern<OpTy>::OpRewritePattern; 686 687 LogicalResult matchAndRewrite(OpTy op, 688 PatternRewriter &rewriter) const override { 689 // Canonicalize operands. 690 bool anyChange = false; 691 auto canonicalizeOperand = [&](Value operand) { 692 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 693 // Only eliminate the cast if it holds no shape information. 694 bool isInformationLoosingCast = 695 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 696 if (isInformationLoosingCast) { 697 anyChange = true; 698 return castOp.source(); 699 } 700 } 701 return operand; 702 }; 703 auto newOperands = llvm::to_vector<8>( 704 llvm::map_range(op.getOperands(), canonicalizeOperand)); 705 706 // Rewrite op if any change required. 707 if (!anyChange) 708 return failure(); 709 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 710 return success(); 711 } 712 }; 713 714 struct BroadcastConcretizeResultTypePattern 715 : public OpRewritePattern<BroadcastOp> { 716 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 717 718 LogicalResult matchAndRewrite(BroadcastOp op, 719 PatternRewriter &rewriter) const override { 720 // Only concretize dynamic extent tensor result types. 721 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 722 if (!resultTy || !resultTy.isDynamicDim(0)) 723 return failure(); 724 725 // Infer resulting shape rank if possible. 726 int64_t maxRank = 0; 727 for (Value shape : op.getShapes()) { 728 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 729 // Cannot infer resulting shape rank if any operand is dynamically 730 // ranked. 731 if (extentTensorTy.isDynamicDim(0)) 732 return failure(); 733 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 734 } 735 } 736 737 auto newOp = rewriter.create<BroadcastOp>( 738 op.getLoc(), getExtentTensorType(getContext(), maxRank), 739 op.getShapes()); 740 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 741 return success(); 742 } 743 }; 744 } // namespace 745 746 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 747 MLIRContext *context) { 748 patterns.add<BroadcastConcretizeResultTypePattern, 749 BroadcastFoldConstantOperandsPattern, 750 BroadcastForwardSingleOperandPattern, 751 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 752 RemoveDuplicateOperandsPattern<BroadcastOp>, 753 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 754 } 755 756 //===----------------------------------------------------------------------===// 757 // ConcatOp 758 //===----------------------------------------------------------------------===// 759 760 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 761 if (!operands[0] || !operands[1]) 762 return nullptr; 763 auto lhsShape = llvm::to_vector<6>( 764 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 765 auto rhsShape = llvm::to_vector<6>( 766 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 767 SmallVector<int64_t, 6> resultShape; 768 resultShape.append(lhsShape.begin(), lhsShape.end()); 769 resultShape.append(rhsShape.begin(), rhsShape.end()); 770 Builder builder(getContext()); 771 return builder.getIndexTensorAttr(resultShape); 772 } 773 774 //===----------------------------------------------------------------------===// 775 // ConstShapeOp 776 //===----------------------------------------------------------------------===// 777 778 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 779 p << " "; 780 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 781 p << "["; 782 interleaveComma(op.getShape().getValues<int64_t>(), p, 783 [&](int64_t i) { p << i; }); 784 p << "] : "; 785 p.printType(op.getType()); 786 } 787 788 static ParseResult parseConstShapeOp(OpAsmParser &parser, 789 OperationState &result) { 790 if (parser.parseOptionalAttrDict(result.attributes)) 791 return failure(); 792 // We piggy-back on ArrayAttr parsing, though we don't internally store the 793 // shape as an ArrayAttr. 794 // TODO: Implement custom parser and maybe make syntax a bit more concise. 795 Attribute extentsRaw; 796 NamedAttrList dummy; 797 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 798 return failure(); 799 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 800 if (!extentsArray) 801 return failure(); 802 SmallVector<int64_t, 6> ints; 803 for (Attribute extent : extentsArray) { 804 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 805 if (!attr) 806 return failure(); 807 ints.push_back(attr.getInt()); 808 } 809 Builder &builder = parser.getBuilder(); 810 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 811 Type resultTy; 812 if (parser.parseColonType(resultTy)) 813 return failure(); 814 result.types.push_back(resultTy); 815 return success(); 816 } 817 818 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 819 820 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 821 MLIRContext *context) { 822 patterns.add<TensorCastConstShape>(context); 823 } 824 825 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 826 MLIRContext *context, Optional<Location> location, ValueRange operands, 827 DictionaryAttr attributes, RegionRange regions, 828 SmallVectorImpl<Type> &inferredReturnTypes) { 829 Builder b(context); 830 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 831 if (!shape) 832 return emitOptionalError(location, "missing shape attribute"); 833 inferredReturnTypes.assign({RankedTensorType::get( 834 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 835 return success(); 836 } 837 838 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 839 TypeRange r) { 840 if (l.size() != 1 || r.size() != 1) 841 return false; 842 843 Type lhs = l.front(); 844 Type rhs = r.front(); 845 846 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 847 // Shape type is compatible with all other valid return types. 848 return true; 849 return lhs == rhs; 850 } 851 852 //===----------------------------------------------------------------------===// 853 // CstrBroadcastableOp 854 //===----------------------------------------------------------------------===// 855 856 void CstrBroadcastableOp::getCanonicalizationPatterns( 857 RewritePatternSet &patterns, MLIRContext *context) { 858 // Canonicalization patterns have overlap with the considerations during 859 // folding in case additional shape information is inferred at some point that 860 // does not result in folding. 861 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 862 CstrBroadcastableEqOps, 863 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 864 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 865 } 866 867 // Return true if there is exactly one attribute not representing a scalar 868 // broadcast. 869 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 870 bool nonScalarSeen = false; 871 for (Attribute a : attributes) { 872 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 873 if (nonScalarSeen) 874 return false; 875 nonScalarSeen = true; 876 } 877 } 878 return true; 879 } 880 881 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 882 // No broadcasting is needed if all operands but one are scalar. 883 if (hasAtMostSingleNonScalar(operands)) 884 return BoolAttr::get(getContext(), true); 885 886 if ([&] { 887 SmallVector<SmallVector<int64_t, 6>, 6> extents; 888 for (const auto &operand : operands) { 889 if (!operand) 890 return false; 891 extents.push_back(llvm::to_vector<6>( 892 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 893 } 894 return OpTrait::util::staticallyKnownBroadcastable(extents); 895 }()) 896 return BoolAttr::get(getContext(), true); 897 898 // Lastly, see if folding can be completed based on what constraints are known 899 // on the input shapes. 900 if ([&] { 901 SmallVector<SmallVector<int64_t, 6>, 6> extents; 902 for (auto shapeValue : getShapes()) { 903 extents.emplace_back(); 904 if (failed(getShapeVec(shapeValue, extents.back()))) 905 return false; 906 } 907 return OpTrait::util::staticallyKnownBroadcastable(extents); 908 }()) 909 return BoolAttr::get(getContext(), true); 910 911 // Because a failing witness result here represents an eventual assertion 912 // failure, we do not replace it with a constant witness. 913 return nullptr; 914 } 915 916 static LogicalResult verify(CstrBroadcastableOp op) { 917 // Ensure that AssumingAllOp contains at least one operand 918 if (op.getNumOperands() < 2) 919 return op.emitOpError("required at least 2 input shapes"); 920 return success(); 921 } 922 923 //===----------------------------------------------------------------------===// 924 // CstrEqOp 925 //===----------------------------------------------------------------------===// 926 927 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 928 MLIRContext *context) { 929 // If inputs are equal, return passing witness 930 patterns.add<CstrEqEqOps>(context); 931 } 932 933 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 934 if (llvm::all_of(operands, 935 [&](Attribute a) { return a && a == operands[0]; })) 936 return BoolAttr::get(getContext(), true); 937 938 // Because a failing witness result here represents an eventual assertion 939 // failure, we do not try to replace it with a constant witness. Similarly, we 940 // cannot if there are any non-const inputs. 941 return nullptr; 942 } 943 944 //===----------------------------------------------------------------------===// 945 // ConstSizeOp 946 //===----------------------------------------------------------------------===// 947 948 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 949 int64_t value) { 950 build(builder, result, builder.getIndexAttr(value)); 951 } 952 953 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 954 955 void ConstSizeOp::getAsmResultNames( 956 llvm::function_ref<void(Value, StringRef)> setNameFn) { 957 SmallString<4> buffer; 958 llvm::raw_svector_ostream os(buffer); 959 os << "c" << getValue(); 960 setNameFn(getResult(), os.str()); 961 } 962 963 //===----------------------------------------------------------------------===// 964 // ConstWitnessOp 965 //===----------------------------------------------------------------------===// 966 967 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 968 return getPassingAttr(); 969 } 970 971 //===----------------------------------------------------------------------===// 972 // CstrRequireOp 973 //===----------------------------------------------------------------------===// 974 975 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 976 return operands[0]; 977 } 978 979 //===----------------------------------------------------------------------===// 980 // DivOp 981 //===----------------------------------------------------------------------===// 982 983 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 984 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 985 if (!lhs) 986 return nullptr; 987 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 988 if (!rhs) 989 return nullptr; 990 991 // Division in APInt does not follow floor(lhs, rhs) when the result is 992 // negative. Rather, APInt rounds toward zero. 993 APInt quotient, remainder; 994 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 995 if (quotient.isNegative() && !remainder.isNullValue()) { 996 quotient -= 1; 997 } 998 999 Type indexTy = IndexType::get(getContext()); 1000 return IntegerAttr::get(indexTy, quotient); 1001 } 1002 1003 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1004 MLIRContext *context, Optional<Location> location, ValueRange operands, 1005 DictionaryAttr attributes, RegionRange regions, 1006 SmallVectorImpl<Type> &inferredReturnTypes) { 1007 if (operands[0].getType().isa<SizeType>() || 1008 operands[1].getType().isa<SizeType>()) 1009 inferredReturnTypes.assign({SizeType::get(context)}); 1010 else 1011 inferredReturnTypes.assign({IndexType::get(context)}); 1012 return success(); 1013 } 1014 1015 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1016 // SizeType is compatible with IndexType. 1017 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1018 } 1019 1020 //===----------------------------------------------------------------------===// 1021 // ShapeEqOp 1022 //===----------------------------------------------------------------------===// 1023 1024 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1025 bool allSame = true; 1026 if (!operands.empty() && !operands[0]) 1027 return {}; 1028 for (Attribute operand : operands.drop_front(1)) { 1029 if (!operand) 1030 return {}; 1031 allSame = allSame && operand == operands[0]; 1032 } 1033 return BoolAttr::get(getContext(), allSame); 1034 } 1035 1036 //===----------------------------------------------------------------------===// 1037 // IndexToSizeOp 1038 //===----------------------------------------------------------------------===// 1039 1040 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1041 // Constant values of both types, `shape.size` and `index`, are represented as 1042 // `IntegerAttr`s which makes constant folding simple. 1043 if (Attribute arg = operands[0]) 1044 return arg; 1045 return {}; 1046 } 1047 1048 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1049 MLIRContext *context) { 1050 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // FromExtentsOp 1055 //===----------------------------------------------------------------------===// 1056 1057 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1058 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1059 return nullptr; 1060 SmallVector<int64_t, 6> extents; 1061 for (auto attr : operands) 1062 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1063 Builder builder(getContext()); 1064 return builder.getIndexTensorAttr(extents); 1065 } 1066 1067 //===----------------------------------------------------------------------===// 1068 // FunctionLibraryOp 1069 //===----------------------------------------------------------------------===// 1070 1071 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1072 StringRef name) { 1073 result.attributes.push_back(builder.getNamedAttr( 1074 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1075 } 1076 1077 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1078 auto attr = getMapping() 1079 .get(op->getName().getIdentifier()) 1080 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1081 if (!attr) 1082 return nullptr; 1083 return lookupSymbol<FuncOp>(attr); 1084 } 1085 1086 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1087 OperationState &result) { 1088 // Parse the op name. 1089 StringAttr nameAttr; 1090 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1091 result.attributes)) 1092 return failure(); 1093 1094 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1095 return failure(); 1096 1097 auto *bodyRegion = result.addRegion(); 1098 if (parser.parseRegion(*bodyRegion)) 1099 return failure(); 1100 1101 if (parser.parseKeyword("mapping")) 1102 return failure(); 1103 1104 DictionaryAttr mappingAttr; 1105 if (parser.parseAttribute(mappingAttr, 1106 parser.getBuilder().getType<NoneType>(), "mapping", 1107 result.attributes)) 1108 return failure(); 1109 return success(); 1110 } 1111 1112 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1113 p << ' '; 1114 p.printSymbolName(op.getName()); 1115 p.printOptionalAttrDictWithKeyword( 1116 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1117 p << ' '; 1118 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1119 /*printBlockTerminators=*/false); 1120 p << " mapping "; 1121 p.printAttributeWithoutType(op.getMappingAttr()); 1122 } 1123 1124 //===----------------------------------------------------------------------===// 1125 // GetExtentOp 1126 //===----------------------------------------------------------------------===// 1127 1128 Optional<int64_t> GetExtentOp::getConstantDim() { 1129 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1130 return constSizeOp.getValue().getLimitedValue(); 1131 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1132 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1133 return llvm::None; 1134 } 1135 1136 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1137 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1138 if (!elements) 1139 return nullptr; 1140 Optional<int64_t> dim = getConstantDim(); 1141 if (!dim.hasValue()) 1142 return nullptr; 1143 if (dim.getValue() >= elements.getNumElements()) 1144 return nullptr; 1145 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1146 } 1147 1148 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1149 int64_t dim) { 1150 auto loc = result.location; 1151 auto dimAttr = builder.getIndexAttr(dim); 1152 if (shape.getType().isa<ShapeType>()) { 1153 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1154 build(builder, result, builder.getType<SizeType>(), shape, dim); 1155 } else { 1156 Value dim = 1157 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1158 build(builder, result, builder.getIndexType(), shape, dim); 1159 } 1160 } 1161 1162 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1163 MLIRContext *context, Optional<Location> location, ValueRange operands, 1164 DictionaryAttr attributes, RegionRange regions, 1165 SmallVectorImpl<Type> &inferredReturnTypes) { 1166 inferredReturnTypes.assign({IndexType::get(context)}); 1167 return success(); 1168 } 1169 1170 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1171 TypeRange r) { 1172 // SizeType is compatible with IndexType. 1173 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1174 } 1175 1176 //===----------------------------------------------------------------------===// 1177 // IsBroadcastableOp 1178 //===----------------------------------------------------------------------===// 1179 1180 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1181 MLIRContext *context) { 1182 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1183 } 1184 1185 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1186 // Can always broadcast fewer than two shapes. 1187 if (operands.size() < 2) { 1188 return BoolAttr::get(getContext(), true); 1189 } 1190 1191 return nullptr; 1192 } 1193 1194 //===----------------------------------------------------------------------===// 1195 // MeetOp 1196 //===----------------------------------------------------------------------===// 1197 1198 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1199 MLIRContext *context, Optional<Location> location, ValueRange operands, 1200 DictionaryAttr attributes, RegionRange regions, 1201 SmallVectorImpl<Type> &inferredReturnTypes) { 1202 inferredReturnTypes.assign({operands[0].getType()}); 1203 return success(); 1204 } 1205 1206 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1207 if (l.size() != 1 || r.size() != 1) 1208 return false; 1209 if (l == r) 1210 return true; 1211 1212 Type lhs = l.front(); 1213 Type rhs = r.front(); 1214 1215 if (lhs != rhs) 1216 return false; 1217 1218 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1219 return true; 1220 1221 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1222 return true; 1223 return false; 1224 } 1225 1226 //===----------------------------------------------------------------------===// 1227 // RankOp 1228 //===----------------------------------------------------------------------===// 1229 1230 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1231 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1232 if (!shape) 1233 return {}; 1234 int64_t rank = shape.getNumElements(); 1235 Builder builder(getContext()); 1236 return builder.getIndexAttr(rank); 1237 } 1238 1239 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1240 /// Constant folding fails in cases where only the rank is constant, not the 1241 /// shape itself. 1242 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1243 /// 1244 /// Example: 1245 /// 1246 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1247 /// %rank = shape.rank %shape 1248 /// 1249 /// becomes 1250 /// 1251 /// %rank = shape.const_size 3 1252 1253 namespace { 1254 struct RankShapeOfCanonicalizationPattern 1255 : public OpRewritePattern<shape::RankOp> { 1256 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1257 1258 LogicalResult matchAndRewrite(shape::RankOp op, 1259 PatternRewriter &rewriter) const override { 1260 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1261 if (!shapeOfOp) 1262 return failure(); 1263 auto rankedTensorType = 1264 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1265 if (!rankedTensorType) 1266 return failure(); 1267 int64_t rank = rankedTensorType.getRank(); 1268 if (op.getType().isa<IndexType>()) { 1269 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1270 rank); 1271 } else if (op.getType().isa<shape::SizeType>()) { 1272 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1273 } else { 1274 return failure(); 1275 } 1276 return success(); 1277 } 1278 }; 1279 } // namespace 1280 1281 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1282 MLIRContext *context) { 1283 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1284 } 1285 1286 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1287 MLIRContext *context, Optional<Location> location, ValueRange operands, 1288 DictionaryAttr attributes, RegionRange regions, 1289 SmallVectorImpl<Type> &inferredReturnTypes) { 1290 if (operands[0].getType().isa<ShapeType>()) 1291 inferredReturnTypes.assign({SizeType::get(context)}); 1292 else 1293 inferredReturnTypes.assign({IndexType::get(context)}); 1294 return success(); 1295 } 1296 1297 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1298 // SizeType is compatible with IndexType. 1299 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1300 } 1301 1302 //===----------------------------------------------------------------------===// 1303 // NumElementsOp 1304 //===----------------------------------------------------------------------===// 1305 1306 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1307 1308 // Fold only when argument constant. 1309 Attribute shape = operands[0]; 1310 if (!shape) 1311 return {}; 1312 1313 APInt product(64, 1); 1314 for (auto value : shape.cast<DenseIntElementsAttr>()) 1315 product *= value; 1316 Builder builder(getContext()); 1317 return builder.getIndexAttr(product.getLimitedValue()); 1318 } 1319 1320 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1321 MLIRContext *context, Optional<Location> location, ValueRange operands, 1322 DictionaryAttr attributes, RegionRange regions, 1323 SmallVectorImpl<Type> &inferredReturnTypes) { 1324 if (operands[0].getType().isa<ShapeType>()) 1325 inferredReturnTypes.assign({SizeType::get(context)}); 1326 else 1327 inferredReturnTypes.assign({IndexType::get(context)}); 1328 return success(); 1329 } 1330 1331 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1332 TypeRange r) { 1333 // SizeType is compatible with IndexType. 1334 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1335 } 1336 1337 //===----------------------------------------------------------------------===// 1338 // MaxOp 1339 //===----------------------------------------------------------------------===// 1340 1341 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1342 // If operands are equal, just propagate one. 1343 if (getLhs() == getRhs()) 1344 return getLhs(); 1345 return nullptr; 1346 } 1347 1348 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1349 MLIRContext *context, Optional<Location> location, ValueRange operands, 1350 DictionaryAttr attributes, RegionRange regions, 1351 SmallVectorImpl<Type> &inferredReturnTypes) { 1352 if (operands[0].getType() == operands[1].getType()) 1353 inferredReturnTypes.assign({operands[0].getType()}); 1354 else 1355 inferredReturnTypes.assign({SizeType::get(context)}); 1356 return success(); 1357 } 1358 1359 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1360 if (l.size() != 1 || r.size() != 1) 1361 return false; 1362 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1363 return true; 1364 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1365 return true; 1366 return false; 1367 } 1368 1369 //===----------------------------------------------------------------------===// 1370 // MinOp 1371 //===----------------------------------------------------------------------===// 1372 1373 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1374 // If operands are equal, just propagate one. 1375 if (getLhs() == getRhs()) 1376 return getLhs(); 1377 return nullptr; 1378 } 1379 1380 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1381 MLIRContext *context, Optional<Location> location, ValueRange operands, 1382 DictionaryAttr attributes, RegionRange regions, 1383 SmallVectorImpl<Type> &inferredReturnTypes) { 1384 if (operands[0].getType() == operands[1].getType()) 1385 inferredReturnTypes.assign({operands[0].getType()}); 1386 else 1387 inferredReturnTypes.assign({SizeType::get(context)}); 1388 return success(); 1389 } 1390 1391 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1392 if (l.size() != 1 || r.size() != 1) 1393 return false; 1394 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1395 return true; 1396 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1397 return true; 1398 return false; 1399 } 1400 1401 //===----------------------------------------------------------------------===// 1402 // MulOp 1403 //===----------------------------------------------------------------------===// 1404 1405 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1406 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1407 if (!lhs) 1408 return nullptr; 1409 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1410 if (!rhs) 1411 return nullptr; 1412 APInt folded = lhs.getValue() * rhs.getValue(); 1413 Type indexTy = IndexType::get(getContext()); 1414 return IntegerAttr::get(indexTy, folded); 1415 } 1416 1417 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1418 MLIRContext *context, Optional<Location> location, ValueRange operands, 1419 DictionaryAttr attributes, RegionRange regions, 1420 SmallVectorImpl<Type> &inferredReturnTypes) { 1421 if (operands[0].getType().isa<SizeType>() || 1422 operands[1].getType().isa<SizeType>()) 1423 inferredReturnTypes.assign({SizeType::get(context)}); 1424 else 1425 inferredReturnTypes.assign({IndexType::get(context)}); 1426 return success(); 1427 } 1428 1429 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1430 // SizeType is compatible with IndexType. 1431 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1432 } 1433 //===----------------------------------------------------------------------===// 1434 // ShapeOfOp 1435 //===----------------------------------------------------------------------===// 1436 1437 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1438 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1439 if (!type || !type.hasStaticShape()) 1440 return nullptr; 1441 Builder builder(getContext()); 1442 return builder.getIndexTensorAttr(type.getShape()); 1443 } 1444 1445 namespace { 1446 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1447 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1448 1449 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1450 PatternRewriter &rewriter) const override { 1451 if (!op.getArg().getType().isa<ShapedType>()) 1452 return failure(); 1453 if (op.getType().isa<ShapedType>()) 1454 return failure(); 1455 1456 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1457 op.getArg()); 1458 return success(); 1459 } 1460 }; 1461 1462 // Canonicalize 1463 // ``` 1464 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1465 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1466 // ``` 1467 // to 1468 // ``` 1469 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1470 // ``` 1471 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1472 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1473 1474 LogicalResult matchAndRewrite(tensor::CastOp op, 1475 PatternRewriter &rewriter) const override { 1476 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1477 if (!ty || ty.getRank() != 1) 1478 return failure(); 1479 1480 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1481 if (!shapeOfOp) 1482 return failure(); 1483 1484 // Argument type must be ranked and must not conflict. 1485 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1486 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1487 return failure(); 1488 1489 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1490 return success(); 1491 } 1492 }; 1493 } // namespace 1494 1495 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1496 MLIRContext *context) { 1497 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1498 ExtractFromShapeOfExtentTensor>(context); 1499 } 1500 1501 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1502 MLIRContext *context, Optional<Location> location, ValueRange operands, 1503 DictionaryAttr attributes, RegionRange regions, 1504 SmallVectorImpl<Type> &inferredReturnTypes) { 1505 if (operands[0].getType().isa<ValueShapeType>()) 1506 inferredReturnTypes.assign({ShapeType::get(context)}); 1507 else { 1508 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1509 int64_t rank = 1510 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1511 Type indexTy = IndexType::get(context); 1512 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1513 inferredReturnTypes.assign({extentTensorTy}); 1514 } 1515 return success(); 1516 } 1517 1518 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1519 if (l.size() != 1 || r.size() != 1) 1520 return false; 1521 if (l == r) 1522 return true; 1523 1524 Type lhs = l.front(); 1525 Type rhs = r.front(); 1526 1527 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1528 return false; 1529 1530 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1531 // Shape type is compatible with all other valid return types. 1532 return true; 1533 1534 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1535 return true; 1536 return false; 1537 } 1538 1539 //===----------------------------------------------------------------------===// 1540 // SizeToIndexOp 1541 //===----------------------------------------------------------------------===// 1542 1543 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1544 // Constant values of both types, `shape.size` and `index`, are represented as 1545 // `IntegerAttr`s which makes constant folding simple. 1546 if (Attribute arg = operands[0]) 1547 return arg; 1548 return impl::foldCastOp(*this); 1549 } 1550 1551 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1552 MLIRContext *context) { 1553 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1554 } 1555 1556 //===----------------------------------------------------------------------===// 1557 // YieldOp 1558 //===----------------------------------------------------------------------===// 1559 1560 static LogicalResult verify(shape::YieldOp op) { 1561 auto *parentOp = op->getParentOp(); 1562 auto results = parentOp->getResults(); 1563 auto operands = op.getOperands(); 1564 1565 if (parentOp->getNumResults() != op.getNumOperands()) 1566 return op.emitOpError() << "number of operands does not match number of " 1567 "results of its parent"; 1568 for (auto e : llvm::zip(results, operands)) 1569 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1570 return op.emitOpError() 1571 << "types mismatch between yield op and its parent"; 1572 1573 return success(); 1574 } 1575 1576 //===----------------------------------------------------------------------===// 1577 // SplitAtOp 1578 //===----------------------------------------------------------------------===// 1579 1580 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1581 SmallVectorImpl<OpFoldResult> &results) { 1582 if (!operands[0] || !operands[1]) 1583 return failure(); 1584 auto shapeVec = llvm::to_vector<6>( 1585 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1586 auto shape = llvm::makeArrayRef(shapeVec); 1587 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1588 // Verify that the split point is in the correct range. 1589 // TODO: Constant fold to an "error". 1590 int64_t rank = shape.size(); 1591 if (!(-rank <= splitPoint && splitPoint <= rank)) 1592 return failure(); 1593 if (splitPoint < 0) 1594 splitPoint += shape.size(); 1595 Builder builder(operands[0].getContext()); 1596 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1597 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1598 return success(); 1599 } 1600 1601 //===----------------------------------------------------------------------===// 1602 // ToExtentTensorOp 1603 //===----------------------------------------------------------------------===// 1604 1605 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1606 if (!operands[0]) 1607 return impl::foldCastOp(*this); 1608 Builder builder(getContext()); 1609 auto shape = llvm::to_vector<6>( 1610 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1611 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1612 builder.getIndexType()); 1613 return DenseIntElementsAttr::get(type, shape); 1614 } 1615 1616 //===----------------------------------------------------------------------===// 1617 // ReduceOp 1618 //===----------------------------------------------------------------------===// 1619 1620 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1621 ValueRange initVals) { 1622 result.addOperands(shape); 1623 result.addOperands(initVals); 1624 1625 Region *bodyRegion = result.addRegion(); 1626 bodyRegion->push_back(new Block); 1627 Block &bodyBlock = bodyRegion->front(); 1628 bodyBlock.addArgument(builder.getIndexType(), result.location); 1629 1630 Type elementType; 1631 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1632 elementType = tensorType.getElementType(); 1633 else 1634 elementType = SizeType::get(builder.getContext()); 1635 bodyBlock.addArgument(elementType, shape.getLoc()); 1636 1637 for (Value initVal : initVals) { 1638 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1639 result.addTypes(initVal.getType()); 1640 } 1641 } 1642 1643 static LogicalResult verify(ReduceOp op) { 1644 // Verify block arg types. 1645 Block &block = op.getRegion().front(); 1646 1647 // The block takes index, extent, and aggregated values as arguments. 1648 auto blockArgsCount = op.getInitVals().size() + 2; 1649 if (block.getNumArguments() != blockArgsCount) 1650 return op.emitOpError() << "ReduceOp body is expected to have " 1651 << blockArgsCount << " arguments"; 1652 1653 // The first block argument is the index and must always be of type `index`. 1654 if (!block.getArgument(0).getType().isa<IndexType>()) 1655 return op.emitOpError( 1656 "argument 0 of ReduceOp body is expected to be of IndexType"); 1657 1658 // The second block argument is the extent and must be of type `size` or 1659 // `index`, depending on whether the reduce operation is applied to a shape or 1660 // to an extent tensor. 1661 Type extentTy = block.getArgument(1).getType(); 1662 if (op.getShape().getType().isa<ShapeType>()) { 1663 if (!extentTy.isa<SizeType>()) 1664 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1665 "SizeType if the ReduceOp operates on a ShapeType"); 1666 } else { 1667 if (!extentTy.isa<IndexType>()) 1668 return op.emitOpError( 1669 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1670 "ReduceOp operates on an extent tensor"); 1671 } 1672 1673 for (const auto &type : llvm::enumerate(op.getInitVals())) 1674 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1675 return op.emitOpError() 1676 << "type mismatch between argument " << type.index() + 2 1677 << " of ReduceOp body and initial value " << type.index(); 1678 return success(); 1679 } 1680 1681 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1682 // Parse operands. 1683 SmallVector<OpAsmParser::OperandType, 3> operands; 1684 Type shapeOrExtentTensorType; 1685 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1686 OpAsmParser::Delimiter::Paren) || 1687 parser.parseColonType(shapeOrExtentTensorType) || 1688 parser.parseOptionalArrowTypeList(result.types)) 1689 return failure(); 1690 1691 // Resolve operands. 1692 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1693 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1694 result.operands) || 1695 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1696 result.operands)) 1697 return failure(); 1698 1699 // Parse the body. 1700 Region *body = result.addRegion(); 1701 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1702 return failure(); 1703 1704 // Parse attributes. 1705 if (parser.parseOptionalAttrDict(result.attributes)) 1706 return failure(); 1707 1708 return success(); 1709 } 1710 1711 static void print(OpAsmPrinter &p, ReduceOp op) { 1712 p << '(' << op.getShape() << ", " << op.getInitVals() 1713 << ") : " << op.getShape().getType(); 1714 p.printOptionalArrowTypeList(op.getResultTypes()); 1715 p << ' '; 1716 p.printRegion(op.getRegion()); 1717 p.printOptionalAttrDict(op->getAttrs()); 1718 } 1719 1720 #define GET_OP_CLASSES 1721 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1722