1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/SmallString.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 using namespace mlir; 29 using namespace mlir::shape; 30 31 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 32 33 namespace { 34 #include "ShapeCanonicalization.inc" 35 } // namespace 36 37 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 38 return RankedTensorType::get({rank}, IndexType::get(ctx)); 39 } 40 41 bool shape::isExtentTensorType(Type type) { 42 auto ranked = type.dyn_cast<RankedTensorType>(); 43 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 44 } 45 46 LogicalResult shape::getShapeVec(Value input, 47 SmallVectorImpl<int64_t> &shapeValues) { 48 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 49 auto type = inputOp.getArg().getType().dyn_cast<ShapedType>(); 50 if (!type.hasRank()) 51 return failure(); 52 shapeValues = llvm::to_vector<6>(type.getShape()); 53 return success(); 54 } 55 if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 56 shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); 57 return success(); 58 } 59 if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 60 shapeValues = llvm::to_vector<6>( 61 inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>()); 62 return success(); 63 } 64 return failure(); 65 } 66 67 static bool isErrorPropagationPossible(TypeRange operandTypes) { 68 return llvm::any_of(operandTypes, [](Type ty) { 69 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 70 }); 71 } 72 73 static LogicalResult verifySizeOrIndexOp(Operation *op) { 74 assert(op != nullptr && op->getNumResults() == 1); 75 Type resultTy = op->getResultTypes().front(); 76 if (isErrorPropagationPossible(op->getOperandTypes())) { 77 if (!resultTy.isa<SizeType>()) 78 return op->emitOpError() 79 << "if at least one of the operands can hold error values then " 80 "the result must be of type `size` to propagate them"; 81 } 82 return success(); 83 } 84 85 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 86 assert(op != nullptr && op->getNumResults() == 1); 87 Type resultTy = op->getResultTypes().front(); 88 if (isErrorPropagationPossible(op->getOperandTypes())) { 89 if (!resultTy.isa<ShapeType>()) 90 return op->emitOpError() 91 << "if at least one of the operands can hold error values then " 92 "the result must be of type `shape` to propagate them"; 93 } 94 return success(); 95 } 96 97 template <typename... Ty> 98 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 99 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 100 } 101 102 template <typename... Ty, typename... ranges> 103 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 104 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 105 } 106 107 //===----------------------------------------------------------------------===// 108 // InlinerInterface 109 //===----------------------------------------------------------------------===// 110 111 namespace { 112 /// This class defines the interface for inlining shape dialect ops. 113 struct ShapeInlinerInterface : public DialectInlinerInterface { 114 using DialectInlinerInterface::DialectInlinerInterface; 115 116 // Returns true if the given region 'src' can be inlined into the region 117 // 'dest' that is attached to an operation registered to the current dialect. 118 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 119 BlockAndValueMapping &) const final { 120 return true; 121 } 122 123 // Returns true if the given operation 'op', that is registered to this 124 // dialect, can be inlined into the region 'dest' that is attached to an 125 // operation registered to the current dialect. 126 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 127 BlockAndValueMapping &) const final { 128 return true; 129 } 130 }; 131 } // namespace 132 133 void ShapeDialect::initialize() { 134 addOperations< 135 #define GET_OP_LIST 136 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 137 >(); 138 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 139 addInterfaces<ShapeInlinerInterface>(); 140 // Allow unknown operations during prototyping and testing. As the dialect is 141 // still evolving it makes it simple to start with an unregistered ops and 142 // try different variants before actually defining the op. 143 allowUnknownOperations(); 144 } 145 146 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 147 Attribute value, Type type, 148 Location loc) { 149 if (type.isa<ShapeType>() || isExtentTensorType(type)) 150 return builder.create<ConstShapeOp>(loc, type, 151 value.cast<DenseIntElementsAttr>()); 152 if (type.isa<SizeType>()) 153 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 154 if (type.isa<WitnessType>()) 155 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 156 if (arith::ConstantOp::isBuildableWith(value, type)) 157 return builder.create<arith::ConstantOp>(loc, type, value); 158 return nullptr; 159 } 160 161 /// Parse a type registered to this dialect. 162 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 163 StringRef keyword; 164 if (parser.parseKeyword(&keyword)) 165 return Type(); 166 167 if (keyword == "shape") 168 return ShapeType::get(getContext()); 169 if (keyword == "size") 170 return SizeType::get(getContext()); 171 if (keyword == "value_shape") 172 return ValueShapeType::get(getContext()); 173 if (keyword == "witness") 174 return WitnessType::get(getContext()); 175 176 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 177 return Type(); 178 } 179 180 /// Print a type registered to this dialect. 181 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 182 TypeSwitch<Type>(type) 183 .Case<ShapeType>([&](Type) { os << "shape"; }) 184 .Case<SizeType>([&](Type) { os << "size"; }) 185 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 186 .Case<WitnessType>([&](Type) { os << "witness"; }) 187 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 188 } 189 190 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 191 NamedAttribute attribute) { 192 // Verify shape.lib attribute. 193 if (attribute.getName() == "shape.lib") { 194 if (!op->hasTrait<OpTrait::SymbolTable>()) 195 return op->emitError( 196 "shape.lib attribute may only be on op implementing SymbolTable"); 197 198 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 199 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 200 if (!symbol) 201 return op->emitError("shape function library ") 202 << symbolRef << " not found"; 203 return isa<shape::FunctionLibraryOp>(symbol) 204 ? success() 205 : op->emitError() 206 << symbolRef << " required to be shape function library"; 207 } 208 209 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 210 // Verify all entries are function libraries and mappings in libraries 211 // refer to unique ops. 212 DenseSet<StringAttr> key; 213 for (auto it : arr) { 214 if (!it.isa<SymbolRefAttr>()) 215 return op->emitError( 216 "only SymbolRefAttr allowed in shape.lib attribute array"); 217 218 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 219 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 220 if (!shapeFnLib) 221 return op->emitError() 222 << it << " does not refer to FunctionLibraryOp"; 223 for (auto mapping : shapeFnLib.getMapping()) { 224 if (!key.insert(mapping.getName()).second) { 225 return op->emitError("only one op to shape mapping allowed, found " 226 "multiple for `") 227 << mapping.getName() << "`"; 228 } 229 } 230 } 231 return success(); 232 } 233 234 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 235 "allowed as shape.lib attribute"); 236 } 237 return success(); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // AnyOp 242 //===----------------------------------------------------------------------===// 243 244 // TODO: Canonicalization should be implemented for shapes that can be 245 // determined through mixtures of the known dimensions of the inputs. 246 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 247 // Only the last operand is checked because AnyOp is commutative. 248 if (operands.back()) 249 return operands.back(); 250 251 return nullptr; 252 } 253 254 //===----------------------------------------------------------------------===// 255 // AssumingOp 256 //===----------------------------------------------------------------------===// 257 258 static ParseResult parseAssumingOp(OpAsmParser &parser, 259 OperationState &result) { 260 result.regions.reserve(1); 261 Region *doRegion = result.addRegion(); 262 263 auto &builder = parser.getBuilder(); 264 OpAsmParser::OperandType cond; 265 if (parser.parseOperand(cond) || 266 parser.resolveOperand(cond, builder.getType<WitnessType>(), 267 result.operands)) 268 return failure(); 269 270 // Parse optional results type list. 271 if (parser.parseOptionalArrowTypeList(result.types)) 272 return failure(); 273 274 // Parse the region and add a terminator if elided. 275 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 276 return failure(); 277 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 278 279 // Parse the optional attribute list. 280 if (parser.parseOptionalAttrDict(result.attributes)) 281 return failure(); 282 return success(); 283 } 284 285 static void print(OpAsmPrinter &p, AssumingOp op) { 286 bool yieldsResults = !op.getResults().empty(); 287 288 p << " " << op.getWitness(); 289 if (yieldsResults) 290 p << " -> (" << op.getResultTypes() << ")"; 291 p << ' '; 292 p.printRegion(op.getDoRegion(), 293 /*printEntryBlockArgs=*/false, 294 /*printBlockTerminators=*/yieldsResults); 295 p.printOptionalAttrDict(op->getAttrs()); 296 } 297 298 namespace { 299 // Removes AssumingOp with a passing witness and inlines the region. 300 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 301 using OpRewritePattern<AssumingOp>::OpRewritePattern; 302 303 LogicalResult matchAndRewrite(AssumingOp op, 304 PatternRewriter &rewriter) const override { 305 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 306 if (!witness || !witness.getPassingAttr()) 307 return failure(); 308 309 AssumingOp::inlineRegionIntoParent(op, rewriter); 310 return success(); 311 } 312 }; 313 314 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 315 using OpRewritePattern<AssumingOp>::OpRewritePattern; 316 317 LogicalResult matchAndRewrite(AssumingOp op, 318 PatternRewriter &rewriter) const override { 319 Block *body = op.getBody(); 320 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 321 322 // Find used values. 323 SmallVector<Value, 4> newYieldOperands; 324 Value opResult, yieldOperand; 325 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 326 std::tie(opResult, yieldOperand) = it; 327 if (!opResult.getUses().empty()) { 328 newYieldOperands.push_back(yieldOperand); 329 } 330 } 331 332 // Rewrite only if redundant results exist. 333 if (newYieldOperands.size() == yieldOp->getNumOperands()) 334 return failure(); 335 336 // Replace yield op in the old assuming op's body and move the entire region 337 // to the new assuming op. 338 rewriter.setInsertionPointToEnd(body); 339 auto newYieldOp = 340 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 341 rewriter.setInsertionPoint(op); 342 auto newOp = rewriter.create<AssumingOp>( 343 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 344 newOp.getDoRegion().takeBody(op.getDoRegion()); 345 346 // Use the new results to replace the previously used ones. 347 SmallVector<Value, 4> replacementValues; 348 auto src = newOp.getResults().begin(); 349 for (auto it : op.getResults()) { 350 if (it.getUses().empty()) 351 replacementValues.push_back(nullptr); 352 else 353 replacementValues.push_back(*src++); 354 } 355 rewriter.replaceOp(op, replacementValues); 356 return success(); 357 } 358 }; 359 } // namespace 360 361 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 362 MLIRContext *context) { 363 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 364 } 365 366 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 367 void AssumingOp::getSuccessorRegions( 368 Optional<unsigned> index, ArrayRef<Attribute> operands, 369 SmallVectorImpl<RegionSuccessor> ®ions) { 370 // AssumingOp has unconditional control flow into the region and back to the 371 // parent, so return the correct RegionSuccessor purely based on the index 372 // being None or 0. 373 if (index.hasValue()) { 374 regions.push_back(RegionSuccessor(getResults())); 375 return; 376 } 377 378 regions.push_back(RegionSuccessor(&getDoRegion())); 379 } 380 381 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 382 PatternRewriter &rewriter) { 383 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 384 auto *assumingBlock = op.getBody(); 385 auto initPosition = rewriter.getInsertionPoint(); 386 auto *blockAfterAssuming = 387 rewriter.splitBlock(blockBeforeAssuming, initPosition); 388 389 // Remove the AssumingOp and AssumingYieldOp. 390 auto &yieldOp = assumingBlock->back(); 391 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 392 rewriter.replaceOp(op, yieldOp.getOperands()); 393 rewriter.eraseOp(&yieldOp); 394 395 // Merge blocks together as there was no branching behavior from the 396 // AssumingOp. 397 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 398 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 399 } 400 401 void AssumingOp::build( 402 OpBuilder &builder, OperationState &result, Value witness, 403 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 404 405 result.addOperands(witness); 406 Region *bodyRegion = result.addRegion(); 407 bodyRegion->push_back(new Block); 408 Block &bodyBlock = bodyRegion->front(); 409 410 // Build body. 411 OpBuilder::InsertionGuard guard(builder); 412 builder.setInsertionPointToStart(&bodyBlock); 413 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 414 builder.create<AssumingYieldOp>(result.location, yieldValues); 415 416 SmallVector<Type, 2> assumingTypes; 417 for (Value v : yieldValues) 418 assumingTypes.push_back(v.getType()); 419 result.addTypes(assumingTypes); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // AddOp 424 //===----------------------------------------------------------------------===// 425 426 LogicalResult mlir::shape::AddOp::inferReturnTypes( 427 MLIRContext *context, Optional<Location> location, ValueRange operands, 428 DictionaryAttr attributes, RegionRange regions, 429 SmallVectorImpl<Type> &inferredReturnTypes) { 430 if (operands[0].getType().isa<SizeType>() || 431 operands[1].getType().isa<SizeType>()) 432 inferredReturnTypes.assign({SizeType::get(context)}); 433 else 434 inferredReturnTypes.assign({IndexType::get(context)}); 435 return success(); 436 } 437 438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 439 // SizeType is compatible with IndexType. 440 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 441 } 442 443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 444 // add(x, 0) -> x 445 if (matchPattern(getRhs(), m_Zero())) 446 return getLhs(); 447 448 return constFoldBinaryOp<IntegerAttr>( 449 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 450 } 451 452 //===----------------------------------------------------------------------===// 453 // AssumingAllOp 454 //===----------------------------------------------------------------------===// 455 456 namespace { 457 struct AssumingAllToCstrEqCanonicalization 458 : public OpRewritePattern<AssumingAllOp> { 459 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 460 461 LogicalResult matchAndRewrite(AssumingAllOp op, 462 PatternRewriter &rewriter) const override { 463 SmallVector<Value, 8> shapes; 464 for (Value w : op.getInputs()) { 465 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 466 if (!cstrEqOp) 467 return failure(); 468 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 469 return llvm::is_contained(shapes, s); 470 }); 471 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 472 return failure(); 473 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 474 } 475 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 476 return success(); 477 } 478 }; 479 480 template <typename OpTy> 481 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 482 using OpRewritePattern<OpTy>::OpRewritePattern; 483 484 LogicalResult matchAndRewrite(OpTy op, 485 PatternRewriter &rewriter) const override { 486 // Find unique operands. 487 SmallVector<Value, 2> unique; 488 for (Value v : op.getOperands()) { 489 if (!llvm::is_contained(unique, v)) 490 unique.push_back(v); 491 } 492 493 // Reduce op to equivalent with unique operands. 494 if (unique.size() < op.getNumOperands()) { 495 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 496 op->getAttrs()); 497 return success(); 498 } 499 500 return failure(); 501 } 502 }; 503 } // namespace 504 505 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 506 MLIRContext *context) { 507 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 508 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 509 } 510 511 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 512 // Iterate in reverse to first handle all constant operands. They are 513 // guaranteed to be the tail of the inputs because this is commutative. 514 for (int idx = operands.size() - 1; idx >= 0; idx--) { 515 Attribute a = operands[idx]; 516 // Cannot fold if any inputs are not constant; 517 if (!a) 518 return nullptr; 519 520 // We do not need to keep statically known values after handling them in 521 // this method. 522 getOperation()->eraseOperand(idx); 523 524 // Always false if any input is statically known false 525 if (!a.cast<BoolAttr>().getValue()) 526 return a; 527 } 528 // If this is reached, all inputs were statically known passing. 529 return BoolAttr::get(getContext(), true); 530 } 531 532 static LogicalResult verify(AssumingAllOp op) { 533 // Ensure that AssumingAllOp contains at least one operand 534 if (op.getNumOperands() == 0) 535 return op.emitOpError("no operands specified"); 536 537 return success(); 538 } 539 540 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 541 ValueRange inputs) { 542 build(b, state, b.getType<WitnessType>(), inputs); 543 } 544 545 //===----------------------------------------------------------------------===// 546 // BroadcastOp 547 //===----------------------------------------------------------------------===// 548 549 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 550 if (getShapes().size() == 1) { 551 // Otherwise, we need a cast which would be a canonicalization, not folding. 552 if (getShapes().front().getType() != getType()) 553 return nullptr; 554 return getShapes().front(); 555 } 556 557 // TODO: Support folding with more than 2 input shapes 558 if (getShapes().size() > 2) 559 return nullptr; 560 561 if (!operands[0] || !operands[1]) 562 return nullptr; 563 auto lhsShape = llvm::to_vector<6>( 564 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 565 auto rhsShape = llvm::to_vector<6>( 566 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 567 SmallVector<int64_t, 6> resultShape; 568 569 // If the shapes are not compatible, we can't fold it. 570 // TODO: Fold to an "error". 571 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 572 return nullptr; 573 574 Builder builder(getContext()); 575 return builder.getIndexTensorAttr(resultShape); 576 } 577 578 static LogicalResult verify(BroadcastOp op) { 579 return verifyShapeOrExtentTensorOp(op); 580 } 581 582 namespace { 583 template <typename OpTy> 584 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 585 using OpRewritePattern<OpTy>::OpRewritePattern; 586 587 LogicalResult matchAndRewrite(OpTy op, 588 PatternRewriter &rewriter) const override { 589 auto isPotentiallyNonEmptyShape = [](Value shape) { 590 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 591 if (extentTensorTy.getDimSize(0) == 0) 592 return false; 593 } 594 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 595 if (constShape.getShape().empty()) 596 return false; 597 } 598 return true; 599 }; 600 auto newOperands = llvm::to_vector<8>( 601 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 602 603 // Reduce op to equivalent without empty shape operands. 604 if (newOperands.size() < op.getNumOperands()) { 605 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 606 op->getAttrs()); 607 return success(); 608 } 609 610 return failure(); 611 } 612 }; 613 614 struct BroadcastForwardSingleOperandPattern 615 : public OpRewritePattern<BroadcastOp> { 616 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 617 618 LogicalResult matchAndRewrite(BroadcastOp op, 619 PatternRewriter &rewriter) const override { 620 if (op.getNumOperands() != 1) 621 return failure(); 622 Value replacement = op.getShapes().front(); 623 624 // Insert cast if needed. 625 if (replacement.getType() != op.getType()) { 626 auto loc = op.getLoc(); 627 if (op.getType().isa<ShapeType>()) { 628 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 629 } else { 630 assert(!op.getType().isa<ShapeType>() && 631 !replacement.getType().isa<ShapeType>() && 632 "expect extent tensor cast"); 633 replacement = 634 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 635 } 636 } 637 638 rewriter.replaceOp(op, replacement); 639 return success(); 640 } 641 }; 642 643 struct BroadcastFoldConstantOperandsPattern 644 : public OpRewritePattern<BroadcastOp> { 645 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 646 647 LogicalResult matchAndRewrite(BroadcastOp op, 648 PatternRewriter &rewriter) const override { 649 SmallVector<int64_t, 8> foldedConstantShape; 650 SmallVector<Value, 8> newShapeOperands; 651 for (Value shape : op.getShapes()) { 652 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 653 SmallVector<int64_t, 8> newFoldedConstantShape; 654 if (OpTrait::util::getBroadcastedShape( 655 foldedConstantShape, 656 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 657 newFoldedConstantShape)) { 658 foldedConstantShape = newFoldedConstantShape; 659 continue; 660 } 661 } 662 newShapeOperands.push_back(shape); 663 } 664 665 // Need at least two constant operands to fold anything. 666 if (op.getNumOperands() - newShapeOperands.size() < 2) 667 return failure(); 668 669 auto foldedConstantOperandsTy = RankedTensorType::get( 670 {static_cast<int64_t>(foldedConstantShape.size())}, 671 rewriter.getIndexType()); 672 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 673 op.getLoc(), foldedConstantOperandsTy, 674 rewriter.getIndexTensorAttr(foldedConstantShape))); 675 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 676 newShapeOperands); 677 return success(); 678 } 679 }; 680 681 template <typename OpTy> 682 struct CanonicalizeCastExtentTensorOperandsPattern 683 : public OpRewritePattern<OpTy> { 684 using OpRewritePattern<OpTy>::OpRewritePattern; 685 686 LogicalResult matchAndRewrite(OpTy op, 687 PatternRewriter &rewriter) const override { 688 // Canonicalize operands. 689 bool anyChange = false; 690 auto canonicalizeOperand = [&](Value operand) { 691 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 692 // Only eliminate the cast if it holds no shape information. 693 bool isInformationLoosingCast = 694 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 695 if (isInformationLoosingCast) { 696 anyChange = true; 697 return castOp.source(); 698 } 699 } 700 return operand; 701 }; 702 auto newOperands = llvm::to_vector<8>( 703 llvm::map_range(op.getOperands(), canonicalizeOperand)); 704 705 // Rewrite op if any change required. 706 if (!anyChange) 707 return failure(); 708 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 709 return success(); 710 } 711 }; 712 713 struct BroadcastConcretizeResultTypePattern 714 : public OpRewritePattern<BroadcastOp> { 715 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 716 717 LogicalResult matchAndRewrite(BroadcastOp op, 718 PatternRewriter &rewriter) const override { 719 // Only concretize dynamic extent tensor result types. 720 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 721 if (!resultTy || !resultTy.isDynamicDim(0)) 722 return failure(); 723 724 // Infer resulting shape rank if possible. 725 int64_t maxRank = 0; 726 for (Value shape : op.getShapes()) { 727 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 728 // Cannot infer resulting shape rank if any operand is dynamically 729 // ranked. 730 if (extentTensorTy.isDynamicDim(0)) 731 return failure(); 732 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 733 } 734 } 735 736 auto newOp = rewriter.create<BroadcastOp>( 737 op.getLoc(), getExtentTensorType(getContext(), maxRank), 738 op.getShapes()); 739 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 740 return success(); 741 } 742 }; 743 } // namespace 744 745 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 746 MLIRContext *context) { 747 patterns.add<BroadcastConcretizeResultTypePattern, 748 BroadcastFoldConstantOperandsPattern, 749 BroadcastForwardSingleOperandPattern, 750 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 751 RemoveDuplicateOperandsPattern<BroadcastOp>, 752 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 753 } 754 755 //===----------------------------------------------------------------------===// 756 // ConcatOp 757 //===----------------------------------------------------------------------===// 758 759 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 760 if (!operands[0] || !operands[1]) 761 return nullptr; 762 auto lhsShape = llvm::to_vector<6>( 763 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 764 auto rhsShape = llvm::to_vector<6>( 765 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 766 SmallVector<int64_t, 6> resultShape; 767 resultShape.append(lhsShape.begin(), lhsShape.end()); 768 resultShape.append(rhsShape.begin(), rhsShape.end()); 769 Builder builder(getContext()); 770 return builder.getIndexTensorAttr(resultShape); 771 } 772 773 //===----------------------------------------------------------------------===// 774 // ConstShapeOp 775 //===----------------------------------------------------------------------===// 776 777 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 778 p << " "; 779 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 780 p << "["; 781 interleaveComma(op.getShape().getValues<int64_t>(), p, 782 [&](int64_t i) { p << i; }); 783 p << "] : "; 784 p.printType(op.getType()); 785 } 786 787 static ParseResult parseConstShapeOp(OpAsmParser &parser, 788 OperationState &result) { 789 if (parser.parseOptionalAttrDict(result.attributes)) 790 return failure(); 791 // We piggy-back on ArrayAttr parsing, though we don't internally store the 792 // shape as an ArrayAttr. 793 // TODO: Implement custom parser and maybe make syntax a bit more concise. 794 Attribute extentsRaw; 795 NamedAttrList dummy; 796 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 797 return failure(); 798 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 799 if (!extentsArray) 800 return failure(); 801 SmallVector<int64_t, 6> ints; 802 for (Attribute extent : extentsArray) { 803 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 804 if (!attr) 805 return failure(); 806 ints.push_back(attr.getInt()); 807 } 808 Builder &builder = parser.getBuilder(); 809 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 810 Type resultTy; 811 if (parser.parseColonType(resultTy)) 812 return failure(); 813 result.types.push_back(resultTy); 814 return success(); 815 } 816 817 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 818 819 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 820 MLIRContext *context) { 821 patterns.add<TensorCastConstShape>(context); 822 } 823 824 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 825 MLIRContext *context, Optional<Location> location, ValueRange operands, 826 DictionaryAttr attributes, RegionRange regions, 827 SmallVectorImpl<Type> &inferredReturnTypes) { 828 Builder b(context); 829 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 830 if (!shape) 831 return emitOptionalError(location, "missing shape attribute"); 832 inferredReturnTypes.assign({RankedTensorType::get( 833 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 834 return success(); 835 } 836 837 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 838 TypeRange r) { 839 if (l.size() != 1 || r.size() != 1) 840 return false; 841 842 Type lhs = l.front(); 843 Type rhs = r.front(); 844 845 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 846 // Shape type is compatible with all other valid return types. 847 return true; 848 return lhs == rhs; 849 } 850 851 //===----------------------------------------------------------------------===// 852 // CstrBroadcastableOp 853 //===----------------------------------------------------------------------===// 854 855 void CstrBroadcastableOp::getCanonicalizationPatterns( 856 RewritePatternSet &patterns, MLIRContext *context) { 857 // Canonicalization patterns have overlap with the considerations during 858 // folding in case additional shape information is inferred at some point that 859 // does not result in folding. 860 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 861 CstrBroadcastableEqOps, 862 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 863 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 864 } 865 866 // Return true if there is exactly one attribute not representing a scalar 867 // broadcast. 868 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 869 bool nonScalarSeen = false; 870 for (Attribute a : attributes) { 871 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 872 if (nonScalarSeen) 873 return false; 874 nonScalarSeen = true; 875 } 876 } 877 return true; 878 } 879 880 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 881 // No broadcasting is needed if all operands but one are scalar. 882 if (hasAtMostSingleNonScalar(operands)) 883 return BoolAttr::get(getContext(), true); 884 885 if ([&] { 886 SmallVector<SmallVector<int64_t, 6>, 6> extents; 887 for (const auto &operand : operands) { 888 if (!operand) 889 return false; 890 extents.push_back(llvm::to_vector<6>( 891 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 892 } 893 return OpTrait::util::staticallyKnownBroadcastable(extents); 894 }()) 895 return BoolAttr::get(getContext(), true); 896 897 // Lastly, see if folding can be completed based on what constraints are known 898 // on the input shapes. 899 if ([&] { 900 SmallVector<SmallVector<int64_t, 6>, 6> extents; 901 for (auto shapeValue : getShapes()) { 902 extents.emplace_back(); 903 if (failed(getShapeVec(shapeValue, extents.back()))) 904 return false; 905 } 906 return OpTrait::util::staticallyKnownBroadcastable(extents); 907 }()) 908 return BoolAttr::get(getContext(), true); 909 910 // Because a failing witness result here represents an eventual assertion 911 // failure, we do not replace it with a constant witness. 912 return nullptr; 913 } 914 915 static LogicalResult verify(CstrBroadcastableOp op) { 916 // Ensure that AssumingAllOp contains at least one operand 917 if (op.getNumOperands() < 2) 918 return op.emitOpError("required at least 2 input shapes"); 919 return success(); 920 } 921 922 //===----------------------------------------------------------------------===// 923 // CstrEqOp 924 //===----------------------------------------------------------------------===// 925 926 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 927 MLIRContext *context) { 928 // If inputs are equal, return passing witness 929 patterns.add<CstrEqEqOps>(context); 930 } 931 932 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 933 if (llvm::all_of(operands, 934 [&](Attribute a) { return a && a == operands[0]; })) 935 return BoolAttr::get(getContext(), true); 936 937 // Because a failing witness result here represents an eventual assertion 938 // failure, we do not try to replace it with a constant witness. Similarly, we 939 // cannot if there are any non-const inputs. 940 return nullptr; 941 } 942 943 //===----------------------------------------------------------------------===// 944 // ConstSizeOp 945 //===----------------------------------------------------------------------===// 946 947 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 948 int64_t value) { 949 build(builder, result, builder.getIndexAttr(value)); 950 } 951 952 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 953 954 void ConstSizeOp::getAsmResultNames( 955 llvm::function_ref<void(Value, StringRef)> setNameFn) { 956 SmallString<4> buffer; 957 llvm::raw_svector_ostream os(buffer); 958 os << "c" << getValue(); 959 setNameFn(getResult(), os.str()); 960 } 961 962 //===----------------------------------------------------------------------===// 963 // ConstWitnessOp 964 //===----------------------------------------------------------------------===// 965 966 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 967 return getPassingAttr(); 968 } 969 970 //===----------------------------------------------------------------------===// 971 // CstrRequireOp 972 //===----------------------------------------------------------------------===// 973 974 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 975 return operands[0]; 976 } 977 978 //===----------------------------------------------------------------------===// 979 // DivOp 980 //===----------------------------------------------------------------------===// 981 982 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 983 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 984 if (!lhs) 985 return nullptr; 986 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 987 if (!rhs) 988 return nullptr; 989 990 // Division in APInt does not follow floor(lhs, rhs) when the result is 991 // negative. Rather, APInt rounds toward zero. 992 APInt quotient, remainder; 993 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 994 if (quotient.isNegative() && !remainder.isNullValue()) { 995 quotient -= 1; 996 } 997 998 Type indexTy = IndexType::get(getContext()); 999 return IntegerAttr::get(indexTy, quotient); 1000 } 1001 1002 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1003 MLIRContext *context, Optional<Location> location, ValueRange operands, 1004 DictionaryAttr attributes, RegionRange regions, 1005 SmallVectorImpl<Type> &inferredReturnTypes) { 1006 if (operands[0].getType().isa<SizeType>() || 1007 operands[1].getType().isa<SizeType>()) 1008 inferredReturnTypes.assign({SizeType::get(context)}); 1009 else 1010 inferredReturnTypes.assign({IndexType::get(context)}); 1011 return success(); 1012 } 1013 1014 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1015 // SizeType is compatible with IndexType. 1016 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1017 } 1018 1019 //===----------------------------------------------------------------------===// 1020 // ShapeEqOp 1021 //===----------------------------------------------------------------------===// 1022 1023 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1024 bool allSame = true; 1025 if (!operands.empty() && !operands[0]) 1026 return {}; 1027 for (Attribute operand : operands.drop_front(1)) { 1028 if (!operand) 1029 return {}; 1030 allSame = allSame && operand == operands[0]; 1031 } 1032 return BoolAttr::get(getContext(), allSame); 1033 } 1034 1035 //===----------------------------------------------------------------------===// 1036 // IndexToSizeOp 1037 //===----------------------------------------------------------------------===// 1038 1039 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1040 // Constant values of both types, `shape.size` and `index`, are represented as 1041 // `IntegerAttr`s which makes constant folding simple. 1042 if (Attribute arg = operands[0]) 1043 return arg; 1044 return {}; 1045 } 1046 1047 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1048 MLIRContext *context) { 1049 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // FromExtentsOp 1054 //===----------------------------------------------------------------------===// 1055 1056 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1057 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1058 return nullptr; 1059 SmallVector<int64_t, 6> extents; 1060 for (auto attr : operands) 1061 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1062 Builder builder(getContext()); 1063 return builder.getIndexTensorAttr(extents); 1064 } 1065 1066 //===----------------------------------------------------------------------===// 1067 // FunctionLibraryOp 1068 //===----------------------------------------------------------------------===// 1069 1070 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1071 StringRef name) { 1072 result.attributes.push_back(builder.getNamedAttr( 1073 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1074 } 1075 1076 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1077 auto attr = getMapping() 1078 .get(op->getName().getIdentifier()) 1079 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1080 if (!attr) 1081 return nullptr; 1082 return lookupSymbol<FuncOp>(attr); 1083 } 1084 1085 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1086 OperationState &result) { 1087 // Parse the op name. 1088 StringAttr nameAttr; 1089 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1090 result.attributes)) 1091 return failure(); 1092 1093 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1094 return failure(); 1095 1096 auto *bodyRegion = result.addRegion(); 1097 if (parser.parseRegion(*bodyRegion)) 1098 return failure(); 1099 1100 if (parser.parseKeyword("mapping")) 1101 return failure(); 1102 1103 DictionaryAttr mappingAttr; 1104 if (parser.parseAttribute(mappingAttr, 1105 parser.getBuilder().getType<NoneType>(), "mapping", 1106 result.attributes)) 1107 return failure(); 1108 return success(); 1109 } 1110 1111 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1112 p << ' '; 1113 p.printSymbolName(op.getName()); 1114 p.printOptionalAttrDictWithKeyword( 1115 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1116 p << ' '; 1117 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1118 /*printBlockTerminators=*/false); 1119 p << " mapping "; 1120 p.printAttributeWithoutType(op.getMappingAttr()); 1121 } 1122 1123 //===----------------------------------------------------------------------===// 1124 // GetExtentOp 1125 //===----------------------------------------------------------------------===// 1126 1127 Optional<int64_t> GetExtentOp::getConstantDim() { 1128 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1129 return constSizeOp.getValue().getLimitedValue(); 1130 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1131 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1132 return llvm::None; 1133 } 1134 1135 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1136 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1137 if (!elements) 1138 return nullptr; 1139 Optional<int64_t> dim = getConstantDim(); 1140 if (!dim.hasValue()) 1141 return nullptr; 1142 if (dim.getValue() >= elements.getNumElements()) 1143 return nullptr; 1144 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1145 } 1146 1147 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1148 int64_t dim) { 1149 auto loc = result.location; 1150 auto dimAttr = builder.getIndexAttr(dim); 1151 if (shape.getType().isa<ShapeType>()) { 1152 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1153 build(builder, result, builder.getType<SizeType>(), shape, dim); 1154 } else { 1155 Value dim = 1156 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1157 build(builder, result, builder.getIndexType(), shape, dim); 1158 } 1159 } 1160 1161 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1162 MLIRContext *context, Optional<Location> location, ValueRange operands, 1163 DictionaryAttr attributes, RegionRange regions, 1164 SmallVectorImpl<Type> &inferredReturnTypes) { 1165 inferredReturnTypes.assign({IndexType::get(context)}); 1166 return success(); 1167 } 1168 1169 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1170 TypeRange r) { 1171 // SizeType is compatible with IndexType. 1172 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1173 } 1174 1175 //===----------------------------------------------------------------------===// 1176 // IsBroadcastableOp 1177 //===----------------------------------------------------------------------===// 1178 1179 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1180 MLIRContext *context) { 1181 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1182 } 1183 1184 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1185 // Can always broadcast fewer than two shapes. 1186 if (operands.size() < 2) { 1187 return BoolAttr::get(getContext(), true); 1188 } 1189 1190 return nullptr; 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // MeetOp 1195 //===----------------------------------------------------------------------===// 1196 1197 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1198 MLIRContext *context, Optional<Location> location, ValueRange operands, 1199 DictionaryAttr attributes, RegionRange regions, 1200 SmallVectorImpl<Type> &inferredReturnTypes) { 1201 inferredReturnTypes.assign({operands[0].getType()}); 1202 return success(); 1203 } 1204 1205 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1206 if (l.size() != 1 || r.size() != 1) 1207 return false; 1208 if (l == r) 1209 return true; 1210 1211 Type lhs = l.front(); 1212 Type rhs = r.front(); 1213 1214 if (lhs != rhs) 1215 return false; 1216 1217 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1218 return true; 1219 1220 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1221 return true; 1222 return false; 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // RankOp 1227 //===----------------------------------------------------------------------===// 1228 1229 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1230 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1231 if (!shape) 1232 return {}; 1233 int64_t rank = shape.getNumElements(); 1234 Builder builder(getContext()); 1235 return builder.getIndexAttr(rank); 1236 } 1237 1238 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1239 /// Constant folding fails in cases where only the rank is constant, not the 1240 /// shape itself. 1241 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1242 /// 1243 /// Example: 1244 /// 1245 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1246 /// %rank = shape.rank %shape 1247 /// 1248 /// becomes 1249 /// 1250 /// %rank = shape.const_size 3 1251 1252 namespace { 1253 struct RankShapeOfCanonicalizationPattern 1254 : public OpRewritePattern<shape::RankOp> { 1255 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1256 1257 LogicalResult matchAndRewrite(shape::RankOp op, 1258 PatternRewriter &rewriter) const override { 1259 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1260 if (!shapeOfOp) 1261 return failure(); 1262 auto rankedTensorType = 1263 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1264 if (!rankedTensorType) 1265 return failure(); 1266 int64_t rank = rankedTensorType.getRank(); 1267 if (op.getType().isa<IndexType>()) { 1268 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1269 rank); 1270 } else if (op.getType().isa<shape::SizeType>()) { 1271 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1272 } else { 1273 return failure(); 1274 } 1275 return success(); 1276 } 1277 }; 1278 } // namespace 1279 1280 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1281 MLIRContext *context) { 1282 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1283 } 1284 1285 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1286 MLIRContext *context, Optional<Location> location, ValueRange operands, 1287 DictionaryAttr attributes, RegionRange regions, 1288 SmallVectorImpl<Type> &inferredReturnTypes) { 1289 if (operands[0].getType().isa<ShapeType>()) 1290 inferredReturnTypes.assign({SizeType::get(context)}); 1291 else 1292 inferredReturnTypes.assign({IndexType::get(context)}); 1293 return success(); 1294 } 1295 1296 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1297 // SizeType is compatible with IndexType. 1298 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1299 } 1300 1301 //===----------------------------------------------------------------------===// 1302 // NumElementsOp 1303 //===----------------------------------------------------------------------===// 1304 1305 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1306 1307 // Fold only when argument constant. 1308 Attribute shape = operands[0]; 1309 if (!shape) 1310 return {}; 1311 1312 APInt product(64, 1); 1313 for (auto value : shape.cast<DenseIntElementsAttr>()) 1314 product *= value; 1315 Builder builder(getContext()); 1316 return builder.getIndexAttr(product.getLimitedValue()); 1317 } 1318 1319 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1320 MLIRContext *context, Optional<Location> location, ValueRange operands, 1321 DictionaryAttr attributes, RegionRange regions, 1322 SmallVectorImpl<Type> &inferredReturnTypes) { 1323 if (operands[0].getType().isa<ShapeType>()) 1324 inferredReturnTypes.assign({SizeType::get(context)}); 1325 else 1326 inferredReturnTypes.assign({IndexType::get(context)}); 1327 return success(); 1328 } 1329 1330 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1331 TypeRange r) { 1332 // SizeType is compatible with IndexType. 1333 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1334 } 1335 1336 //===----------------------------------------------------------------------===// 1337 // MaxOp 1338 //===----------------------------------------------------------------------===// 1339 1340 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1341 // If operands are equal, just propagate one. 1342 if (getLhs() == getRhs()) 1343 return getLhs(); 1344 return nullptr; 1345 } 1346 1347 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1348 MLIRContext *context, Optional<Location> location, ValueRange operands, 1349 DictionaryAttr attributes, RegionRange regions, 1350 SmallVectorImpl<Type> &inferredReturnTypes) { 1351 if (operands[0].getType() == operands[1].getType()) 1352 inferredReturnTypes.assign({operands[0].getType()}); 1353 else 1354 inferredReturnTypes.assign({SizeType::get(context)}); 1355 return success(); 1356 } 1357 1358 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1359 if (l.size() != 1 || r.size() != 1) 1360 return false; 1361 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1362 return true; 1363 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1364 return true; 1365 return false; 1366 } 1367 1368 //===----------------------------------------------------------------------===// 1369 // MinOp 1370 //===----------------------------------------------------------------------===// 1371 1372 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1373 // If operands are equal, just propagate one. 1374 if (getLhs() == getRhs()) 1375 return getLhs(); 1376 return nullptr; 1377 } 1378 1379 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1380 MLIRContext *context, Optional<Location> location, ValueRange operands, 1381 DictionaryAttr attributes, RegionRange regions, 1382 SmallVectorImpl<Type> &inferredReturnTypes) { 1383 if (operands[0].getType() == operands[1].getType()) 1384 inferredReturnTypes.assign({operands[0].getType()}); 1385 else 1386 inferredReturnTypes.assign({SizeType::get(context)}); 1387 return success(); 1388 } 1389 1390 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1391 if (l.size() != 1 || r.size() != 1) 1392 return false; 1393 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1394 return true; 1395 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1396 return true; 1397 return false; 1398 } 1399 1400 //===----------------------------------------------------------------------===// 1401 // MulOp 1402 //===----------------------------------------------------------------------===// 1403 1404 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1405 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1406 if (!lhs) 1407 return nullptr; 1408 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1409 if (!rhs) 1410 return nullptr; 1411 APInt folded = lhs.getValue() * rhs.getValue(); 1412 Type indexTy = IndexType::get(getContext()); 1413 return IntegerAttr::get(indexTy, folded); 1414 } 1415 1416 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1417 MLIRContext *context, Optional<Location> location, ValueRange operands, 1418 DictionaryAttr attributes, RegionRange regions, 1419 SmallVectorImpl<Type> &inferredReturnTypes) { 1420 if (operands[0].getType().isa<SizeType>() || 1421 operands[1].getType().isa<SizeType>()) 1422 inferredReturnTypes.assign({SizeType::get(context)}); 1423 else 1424 inferredReturnTypes.assign({IndexType::get(context)}); 1425 return success(); 1426 } 1427 1428 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1429 // SizeType is compatible with IndexType. 1430 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1431 } 1432 //===----------------------------------------------------------------------===// 1433 // ShapeOfOp 1434 //===----------------------------------------------------------------------===// 1435 1436 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1437 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1438 if (!type || !type.hasStaticShape()) 1439 return nullptr; 1440 Builder builder(getContext()); 1441 return builder.getIndexTensorAttr(type.getShape()); 1442 } 1443 1444 namespace { 1445 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1446 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1447 1448 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1449 PatternRewriter &rewriter) const override { 1450 if (!op.getArg().getType().isa<ShapedType>()) 1451 return failure(); 1452 if (op.getType().isa<ShapedType>()) 1453 return failure(); 1454 1455 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1456 op.getArg()); 1457 return success(); 1458 } 1459 }; 1460 1461 // Canonicalize 1462 // ``` 1463 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1464 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1465 // ``` 1466 // to 1467 // ``` 1468 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1469 // ``` 1470 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1471 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1472 1473 LogicalResult matchAndRewrite(tensor::CastOp op, 1474 PatternRewriter &rewriter) const override { 1475 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1476 if (!ty || ty.getRank() != 1) 1477 return failure(); 1478 1479 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1480 if (!shapeOfOp) 1481 return failure(); 1482 1483 // Argument type must be ranked and must not conflict. 1484 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1485 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1486 return failure(); 1487 1488 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1489 return success(); 1490 } 1491 }; 1492 } // namespace 1493 1494 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1495 MLIRContext *context) { 1496 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1497 ExtractFromShapeOfExtentTensor>(context); 1498 } 1499 1500 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1501 MLIRContext *context, Optional<Location> location, ValueRange operands, 1502 DictionaryAttr attributes, RegionRange regions, 1503 SmallVectorImpl<Type> &inferredReturnTypes) { 1504 if (operands[0].getType().isa<ValueShapeType>()) 1505 inferredReturnTypes.assign({ShapeType::get(context)}); 1506 else { 1507 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1508 int64_t rank = 1509 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1510 Type indexTy = IndexType::get(context); 1511 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1512 inferredReturnTypes.assign({extentTensorTy}); 1513 } 1514 return success(); 1515 } 1516 1517 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1518 if (l.size() != 1 || r.size() != 1) 1519 return false; 1520 if (l == r) 1521 return true; 1522 1523 Type lhs = l.front(); 1524 Type rhs = r.front(); 1525 1526 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1527 return false; 1528 1529 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1530 // Shape type is compatible with all other valid return types. 1531 return true; 1532 1533 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1534 return true; 1535 return false; 1536 } 1537 1538 //===----------------------------------------------------------------------===// 1539 // SizeToIndexOp 1540 //===----------------------------------------------------------------------===// 1541 1542 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1543 // Constant values of both types, `shape.size` and `index`, are represented as 1544 // `IntegerAttr`s which makes constant folding simple. 1545 if (Attribute arg = operands[0]) 1546 return arg; 1547 return impl::foldCastOp(*this); 1548 } 1549 1550 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1551 MLIRContext *context) { 1552 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1553 } 1554 1555 //===----------------------------------------------------------------------===// 1556 // YieldOp 1557 //===----------------------------------------------------------------------===// 1558 1559 static LogicalResult verify(shape::YieldOp op) { 1560 auto *parentOp = op->getParentOp(); 1561 auto results = parentOp->getResults(); 1562 auto operands = op.getOperands(); 1563 1564 if (parentOp->getNumResults() != op.getNumOperands()) 1565 return op.emitOpError() << "number of operands does not match number of " 1566 "results of its parent"; 1567 for (auto e : llvm::zip(results, operands)) 1568 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1569 return op.emitOpError() 1570 << "types mismatch between yield op and its parent"; 1571 1572 return success(); 1573 } 1574 1575 //===----------------------------------------------------------------------===// 1576 // SplitAtOp 1577 //===----------------------------------------------------------------------===// 1578 1579 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1580 SmallVectorImpl<OpFoldResult> &results) { 1581 if (!operands[0] || !operands[1]) 1582 return failure(); 1583 auto shapeVec = llvm::to_vector<6>( 1584 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1585 auto shape = llvm::makeArrayRef(shapeVec); 1586 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1587 // Verify that the split point is in the correct range. 1588 // TODO: Constant fold to an "error". 1589 int64_t rank = shape.size(); 1590 if (!(-rank <= splitPoint && splitPoint <= rank)) 1591 return failure(); 1592 if (splitPoint < 0) 1593 splitPoint += shape.size(); 1594 Builder builder(operands[0].getContext()); 1595 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1596 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1597 return success(); 1598 } 1599 1600 //===----------------------------------------------------------------------===// 1601 // ToExtentTensorOp 1602 //===----------------------------------------------------------------------===// 1603 1604 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1605 if (!operands[0]) 1606 return impl::foldCastOp(*this); 1607 Builder builder(getContext()); 1608 auto shape = llvm::to_vector<6>( 1609 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1610 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1611 builder.getIndexType()); 1612 return DenseIntElementsAttr::get(type, shape); 1613 } 1614 1615 //===----------------------------------------------------------------------===// 1616 // ReduceOp 1617 //===----------------------------------------------------------------------===// 1618 1619 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1620 ValueRange initVals) { 1621 result.addOperands(shape); 1622 result.addOperands(initVals); 1623 1624 Region *bodyRegion = result.addRegion(); 1625 bodyRegion->push_back(new Block); 1626 Block &bodyBlock = bodyRegion->front(); 1627 bodyBlock.addArgument(builder.getIndexType(), result.location); 1628 1629 Type elementType; 1630 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1631 elementType = tensorType.getElementType(); 1632 else 1633 elementType = SizeType::get(builder.getContext()); 1634 bodyBlock.addArgument(elementType, shape.getLoc()); 1635 1636 for (Value initVal : initVals) { 1637 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1638 result.addTypes(initVal.getType()); 1639 } 1640 } 1641 1642 static LogicalResult verify(ReduceOp op) { 1643 // Verify block arg types. 1644 Block &block = op.getRegion().front(); 1645 1646 // The block takes index, extent, and aggregated values as arguments. 1647 auto blockArgsCount = op.getInitVals().size() + 2; 1648 if (block.getNumArguments() != blockArgsCount) 1649 return op.emitOpError() << "ReduceOp body is expected to have " 1650 << blockArgsCount << " arguments"; 1651 1652 // The first block argument is the index and must always be of type `index`. 1653 if (!block.getArgument(0).getType().isa<IndexType>()) 1654 return op.emitOpError( 1655 "argument 0 of ReduceOp body is expected to be of IndexType"); 1656 1657 // The second block argument is the extent and must be of type `size` or 1658 // `index`, depending on whether the reduce operation is applied to a shape or 1659 // to an extent tensor. 1660 Type extentTy = block.getArgument(1).getType(); 1661 if (op.getShape().getType().isa<ShapeType>()) { 1662 if (!extentTy.isa<SizeType>()) 1663 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1664 "SizeType if the ReduceOp operates on a ShapeType"); 1665 } else { 1666 if (!extentTy.isa<IndexType>()) 1667 return op.emitOpError( 1668 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1669 "ReduceOp operates on an extent tensor"); 1670 } 1671 1672 for (const auto &type : llvm::enumerate(op.getInitVals())) 1673 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1674 return op.emitOpError() 1675 << "type mismatch between argument " << type.index() + 2 1676 << " of ReduceOp body and initial value " << type.index(); 1677 return success(); 1678 } 1679 1680 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1681 // Parse operands. 1682 SmallVector<OpAsmParser::OperandType, 3> operands; 1683 Type shapeOrExtentTensorType; 1684 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1685 OpAsmParser::Delimiter::Paren) || 1686 parser.parseColonType(shapeOrExtentTensorType) || 1687 parser.parseOptionalArrowTypeList(result.types)) 1688 return failure(); 1689 1690 // Resolve operands. 1691 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1692 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1693 result.operands) || 1694 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1695 result.operands)) 1696 return failure(); 1697 1698 // Parse the body. 1699 Region *body = result.addRegion(); 1700 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1701 return failure(); 1702 1703 // Parse attributes. 1704 if (parser.parseOptionalAttrDict(result.attributes)) 1705 return failure(); 1706 1707 return success(); 1708 } 1709 1710 static void print(OpAsmPrinter &p, ReduceOp op) { 1711 p << '(' << op.getShape() << ", " << op.getInitVals() 1712 << ") : " << op.getShape().getType(); 1713 p.printOptionalArrowTypeList(op.getResultTypes()); 1714 p << ' '; 1715 p.printRegion(op.getRegion()); 1716 p.printOptionalAttrDict(op->getAttrs()); 1717 } 1718 1719 #define GET_OP_CLASSES 1720 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1721