1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Dialect/Traits.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/SmallString.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/raw_ostream.h" 26 27 using namespace mlir; 28 using namespace mlir::shape; 29 30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 31 32 namespace { 33 #include "ShapeCanonicalization.inc" 34 } // namespace 35 36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 37 return RankedTensorType::get({rank}, IndexType::get(ctx)); 38 } 39 40 bool shape::isExtentTensorType(Type type) { 41 auto ranked = type.dyn_cast<RankedTensorType>(); 42 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 43 } 44 45 LogicalResult shape::getShapeVec(Value input, 46 SmallVectorImpl<int64_t> &shapeValues) { 47 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 48 auto type = inputOp.getArg().getType().dyn_cast<ShapedType>(); 49 if (!type.hasRank()) 50 return failure(); 51 shapeValues = llvm::to_vector<6>(type.getShape()); 52 return success(); 53 } 54 if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 55 shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); 56 return success(); 57 } 58 if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 59 shapeValues = llvm::to_vector<6>( 60 inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>()); 61 return success(); 62 } 63 return failure(); 64 } 65 66 static bool isErrorPropagationPossible(TypeRange operandTypes) { 67 return llvm::any_of(operandTypes, [](Type ty) { 68 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 69 }); 70 } 71 72 static LogicalResult verifySizeOrIndexOp(Operation *op) { 73 assert(op != nullptr && op->getNumResults() == 1); 74 Type resultTy = op->getResultTypes().front(); 75 if (isErrorPropagationPossible(op->getOperandTypes())) { 76 if (!resultTy.isa<SizeType>()) 77 return op->emitOpError() 78 << "if at least one of the operands can hold error values then " 79 "the result must be of type `size` to propagate them"; 80 } 81 return success(); 82 } 83 84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 85 assert(op != nullptr && op->getNumResults() == 1); 86 Type resultTy = op->getResultTypes().front(); 87 if (isErrorPropagationPossible(op->getOperandTypes())) { 88 if (!resultTy.isa<ShapeType>()) 89 return op->emitOpError() 90 << "if at least one of the operands can hold error values then " 91 "the result must be of type `shape` to propagate them"; 92 } 93 return success(); 94 } 95 96 template <typename... Ty> 97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 98 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 99 } 100 101 template <typename... Ty, typename... ranges> 102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 103 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // InlinerInterface 108 //===----------------------------------------------------------------------===// 109 110 namespace { 111 /// This class defines the interface for inlining shape dialect ops. 112 struct ShapeInlinerInterface : public DialectInlinerInterface { 113 using DialectInlinerInterface::DialectInlinerInterface; 114 115 // Returns true if the given region 'src' can be inlined into the region 116 // 'dest' that is attached to an operation registered to the current dialect. 117 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 118 BlockAndValueMapping &) const final { 119 return true; 120 } 121 122 // Returns true if the given operation 'op', that is registered to this 123 // dialect, can be inlined into the region 'dest' that is attached to an 124 // operation registered to the current dialect. 125 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 126 BlockAndValueMapping &) const final { 127 return true; 128 } 129 }; 130 } // namespace 131 132 void ShapeDialect::initialize() { 133 addOperations< 134 #define GET_OP_LIST 135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 136 >(); 137 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 138 addInterfaces<ShapeInlinerInterface>(); 139 // Allow unknown operations during prototyping and testing. As the dialect is 140 // still evolving it makes it simple to start with an unregistered ops and 141 // try different variants before actually defining the op. 142 allowUnknownOperations(); 143 } 144 145 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 146 Attribute value, Type type, 147 Location loc) { 148 if (type.isa<ShapeType>() || isExtentTensorType(type)) 149 return builder.create<ConstShapeOp>(loc, type, 150 value.cast<DenseIntElementsAttr>()); 151 if (type.isa<SizeType>()) 152 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 153 if (type.isa<WitnessType>()) 154 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 155 if (arith::ConstantOp::isBuildableWith(value, type)) 156 return builder.create<arith::ConstantOp>(loc, type, value); 157 return nullptr; 158 } 159 160 /// Parse a type registered to this dialect. 161 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 162 StringRef keyword; 163 if (parser.parseKeyword(&keyword)) 164 return Type(); 165 166 if (keyword == "shape") 167 return ShapeType::get(getContext()); 168 if (keyword == "size") 169 return SizeType::get(getContext()); 170 if (keyword == "value_shape") 171 return ValueShapeType::get(getContext()); 172 if (keyword == "witness") 173 return WitnessType::get(getContext()); 174 175 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 176 return Type(); 177 } 178 179 /// Print a type registered to this dialect. 180 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 181 TypeSwitch<Type>(type) 182 .Case<ShapeType>([&](Type) { os << "shape"; }) 183 .Case<SizeType>([&](Type) { os << "size"; }) 184 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 185 .Case<WitnessType>([&](Type) { os << "witness"; }) 186 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 187 } 188 189 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 190 NamedAttribute attribute) { 191 // Verify shape.lib attribute. 192 if (attribute.getName() == "shape.lib") { 193 if (!op->hasTrait<OpTrait::SymbolTable>()) 194 return op->emitError( 195 "shape.lib attribute may only be on op implementing SymbolTable"); 196 197 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 198 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 199 if (!symbol) 200 return op->emitError("shape function library ") 201 << symbolRef << " not found"; 202 return isa<shape::FunctionLibraryOp>(symbol) 203 ? success() 204 : op->emitError() 205 << symbolRef << " required to be shape function library"; 206 } 207 208 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 209 // Verify all entries are function libraries and mappings in libraries 210 // refer to unique ops. 211 DenseSet<StringAttr> key; 212 for (auto it : arr) { 213 if (!it.isa<SymbolRefAttr>()) 214 return op->emitError( 215 "only SymbolRefAttr allowed in shape.lib attribute array"); 216 217 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 218 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 219 if (!shapeFnLib) 220 return op->emitError() 221 << it << " does not refer to FunctionLibraryOp"; 222 for (auto mapping : shapeFnLib.getMapping()) { 223 if (!key.insert(mapping.getName()).second) { 224 return op->emitError("only one op to shape mapping allowed, found " 225 "multiple for `") 226 << mapping.getName() << "`"; 227 } 228 } 229 } 230 return success(); 231 } 232 233 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 234 "allowed as shape.lib attribute"); 235 } 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // AnyOp 241 //===----------------------------------------------------------------------===// 242 243 // TODO: Canonicalization should be implemented for shapes that can be 244 // determined through mixtures of the known dimensions of the inputs. 245 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 246 // Only the last operand is checked because AnyOp is commutative. 247 if (operands.back()) 248 return operands.back(); 249 250 return nullptr; 251 } 252 253 //===----------------------------------------------------------------------===// 254 // AssumingOp 255 //===----------------------------------------------------------------------===// 256 257 static ParseResult parseAssumingOp(OpAsmParser &parser, 258 OperationState &result) { 259 result.regions.reserve(1); 260 Region *doRegion = result.addRegion(); 261 262 auto &builder = parser.getBuilder(); 263 OpAsmParser::OperandType cond; 264 if (parser.parseOperand(cond) || 265 parser.resolveOperand(cond, builder.getType<WitnessType>(), 266 result.operands)) 267 return failure(); 268 269 // Parse optional results type list. 270 if (parser.parseOptionalArrowTypeList(result.types)) 271 return failure(); 272 273 // Parse the region and add a terminator if elided. 274 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 275 return failure(); 276 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 277 278 // Parse the optional attribute list. 279 if (parser.parseOptionalAttrDict(result.attributes)) 280 return failure(); 281 return success(); 282 } 283 284 static void print(OpAsmPrinter &p, AssumingOp op) { 285 bool yieldsResults = !op.getResults().empty(); 286 287 p << " " << op.getWitness(); 288 if (yieldsResults) { 289 p << " -> (" << op.getResultTypes() << ")"; 290 } 291 p.printRegion(op.getDoRegion(), 292 /*printEntryBlockArgs=*/false, 293 /*printBlockTerminators=*/yieldsResults); 294 p.printOptionalAttrDict(op->getAttrs()); 295 } 296 297 namespace { 298 // Removes AssumingOp with a passing witness and inlines the region. 299 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 300 using OpRewritePattern<AssumingOp>::OpRewritePattern; 301 302 LogicalResult matchAndRewrite(AssumingOp op, 303 PatternRewriter &rewriter) const override { 304 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 305 if (!witness || !witness.getPassingAttr()) 306 return failure(); 307 308 AssumingOp::inlineRegionIntoParent(op, rewriter); 309 return success(); 310 } 311 }; 312 313 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 314 using OpRewritePattern<AssumingOp>::OpRewritePattern; 315 316 LogicalResult matchAndRewrite(AssumingOp op, 317 PatternRewriter &rewriter) const override { 318 Block *body = op.getBody(); 319 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 320 321 // Find used values. 322 SmallVector<Value, 4> newYieldOperands; 323 Value opResult, yieldOperand; 324 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 325 std::tie(opResult, yieldOperand) = it; 326 if (!opResult.getUses().empty()) { 327 newYieldOperands.push_back(yieldOperand); 328 } 329 } 330 331 // Rewrite only if redundant results exist. 332 if (newYieldOperands.size() == yieldOp->getNumOperands()) 333 return failure(); 334 335 // Replace yield op in the old assuming op's body and move the entire region 336 // to the new assuming op. 337 rewriter.setInsertionPointToEnd(body); 338 auto newYieldOp = 339 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 340 rewriter.setInsertionPoint(op); 341 auto newOp = rewriter.create<AssumingOp>( 342 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 343 newOp.getDoRegion().takeBody(op.getDoRegion()); 344 345 // Use the new results to replace the previously used ones. 346 SmallVector<Value, 4> replacementValues; 347 auto src = newOp.getResults().begin(); 348 for (auto it : op.getResults()) { 349 if (it.getUses().empty()) 350 replacementValues.push_back(nullptr); 351 else 352 replacementValues.push_back(*src++); 353 } 354 rewriter.replaceOp(op, replacementValues); 355 return success(); 356 } 357 }; 358 } // namespace 359 360 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 361 MLIRContext *context) { 362 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 363 } 364 365 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 366 void AssumingOp::getSuccessorRegions( 367 Optional<unsigned> index, ArrayRef<Attribute> operands, 368 SmallVectorImpl<RegionSuccessor> ®ions) { 369 // AssumingOp has unconditional control flow into the region and back to the 370 // parent, so return the correct RegionSuccessor purely based on the index 371 // being None or 0. 372 if (index.hasValue()) { 373 regions.push_back(RegionSuccessor(getResults())); 374 return; 375 } 376 377 regions.push_back(RegionSuccessor(&getDoRegion())); 378 } 379 380 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 381 PatternRewriter &rewriter) { 382 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 383 auto *assumingBlock = op.getBody(); 384 auto initPosition = rewriter.getInsertionPoint(); 385 auto *blockAfterAssuming = 386 rewriter.splitBlock(blockBeforeAssuming, initPosition); 387 388 // Remove the AssumingOp and AssumingYieldOp. 389 auto &yieldOp = assumingBlock->back(); 390 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 391 rewriter.replaceOp(op, yieldOp.getOperands()); 392 rewriter.eraseOp(&yieldOp); 393 394 // Merge blocks together as there was no branching behavior from the 395 // AssumingOp. 396 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 397 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 398 } 399 400 void AssumingOp::build( 401 OpBuilder &builder, OperationState &result, Value witness, 402 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 403 404 result.addOperands(witness); 405 Region *bodyRegion = result.addRegion(); 406 bodyRegion->push_back(new Block); 407 Block &bodyBlock = bodyRegion->front(); 408 409 // Build body. 410 OpBuilder::InsertionGuard guard(builder); 411 builder.setInsertionPointToStart(&bodyBlock); 412 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 413 builder.create<AssumingYieldOp>(result.location, yieldValues); 414 415 SmallVector<Type, 2> assumingTypes; 416 for (Value v : yieldValues) 417 assumingTypes.push_back(v.getType()); 418 result.addTypes(assumingTypes); 419 } 420 421 //===----------------------------------------------------------------------===// 422 // AddOp 423 //===----------------------------------------------------------------------===// 424 425 LogicalResult mlir::shape::AddOp::inferReturnTypes( 426 MLIRContext *context, Optional<Location> location, ValueRange operands, 427 DictionaryAttr attributes, RegionRange regions, 428 SmallVectorImpl<Type> &inferredReturnTypes) { 429 if (operands[0].getType().isa<SizeType>() || 430 operands[1].getType().isa<SizeType>()) 431 inferredReturnTypes.assign({SizeType::get(context)}); 432 else 433 inferredReturnTypes.assign({IndexType::get(context)}); 434 return success(); 435 } 436 437 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 438 // SizeType is compatible with IndexType. 439 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 440 } 441 442 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 443 // add(x, 0) -> x 444 if (matchPattern(getRhs(), m_Zero())) 445 return getLhs(); 446 447 return constFoldBinaryOp<IntegerAttr>(operands, 448 [](APInt a, APInt b) { return a + b; }); 449 } 450 451 //===----------------------------------------------------------------------===// 452 // AssumingAllOp 453 //===----------------------------------------------------------------------===// 454 455 namespace { 456 struct AssumingAllToCstrEqCanonicalization 457 : public OpRewritePattern<AssumingAllOp> { 458 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 459 460 LogicalResult matchAndRewrite(AssumingAllOp op, 461 PatternRewriter &rewriter) const override { 462 SmallVector<Value, 8> shapes; 463 for (Value w : op.getInputs()) { 464 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 465 if (!cstrEqOp) 466 return failure(); 467 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 468 return llvm::is_contained(shapes, s); 469 }); 470 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 471 return failure(); 472 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 473 } 474 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 475 return success(); 476 } 477 }; 478 479 template <typename OpTy> 480 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 481 using OpRewritePattern<OpTy>::OpRewritePattern; 482 483 LogicalResult matchAndRewrite(OpTy op, 484 PatternRewriter &rewriter) const override { 485 // Find unique operands. 486 SmallVector<Value, 2> unique; 487 for (Value v : op.getOperands()) { 488 if (!llvm::is_contained(unique, v)) 489 unique.push_back(v); 490 } 491 492 // Reduce op to equivalent with unique operands. 493 if (unique.size() < op.getNumOperands()) { 494 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 495 op->getAttrs()); 496 return success(); 497 } 498 499 return failure(); 500 } 501 }; 502 } // namespace 503 504 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 505 MLIRContext *context) { 506 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 507 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 508 } 509 510 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 511 // Iterate in reverse to first handle all constant operands. They are 512 // guaranteed to be the tail of the inputs because this is commutative. 513 for (int idx = operands.size() - 1; idx >= 0; idx--) { 514 Attribute a = operands[idx]; 515 // Cannot fold if any inputs are not constant; 516 if (!a) 517 return nullptr; 518 519 // We do not need to keep statically known values after handling them in 520 // this method. 521 getOperation()->eraseOperand(idx); 522 523 // Always false if any input is statically known false 524 if (!a.cast<BoolAttr>().getValue()) 525 return a; 526 } 527 // If this is reached, all inputs were statically known passing. 528 return BoolAttr::get(getContext(), true); 529 } 530 531 static LogicalResult verify(AssumingAllOp op) { 532 // Ensure that AssumingAllOp contains at least one operand 533 if (op.getNumOperands() == 0) 534 return op.emitOpError("no operands specified"); 535 536 return success(); 537 } 538 539 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 540 ValueRange inputs) { 541 build(b, state, b.getType<WitnessType>(), inputs); 542 } 543 544 //===----------------------------------------------------------------------===// 545 // BroadcastOp 546 //===----------------------------------------------------------------------===// 547 548 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 549 if (getShapes().size() == 1) { 550 // Otherwise, we need a cast which would be a canonicalization, not folding. 551 if (getShapes().front().getType() != getType()) 552 return nullptr; 553 return getShapes().front(); 554 } 555 556 // TODO: Support folding with more than 2 input shapes 557 if (getShapes().size() > 2) 558 return nullptr; 559 560 if (!operands[0] || !operands[1]) 561 return nullptr; 562 auto lhsShape = llvm::to_vector<6>( 563 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 564 auto rhsShape = llvm::to_vector<6>( 565 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 566 SmallVector<int64_t, 6> resultShape; 567 568 // If the shapes are not compatible, we can't fold it. 569 // TODO: Fold to an "error". 570 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 571 return nullptr; 572 573 Builder builder(getContext()); 574 return builder.getIndexTensorAttr(resultShape); 575 } 576 577 static LogicalResult verify(BroadcastOp op) { 578 return verifyShapeOrExtentTensorOp(op); 579 } 580 581 namespace { 582 template <typename OpTy> 583 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 584 using OpRewritePattern<OpTy>::OpRewritePattern; 585 586 LogicalResult matchAndRewrite(OpTy op, 587 PatternRewriter &rewriter) const override { 588 auto isPotentiallyNonEmptyShape = [](Value shape) { 589 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 590 if (extentTensorTy.getDimSize(0) == 0) 591 return false; 592 } 593 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 594 if (constShape.getShape().empty()) 595 return false; 596 } 597 return true; 598 }; 599 auto newOperands = llvm::to_vector<8>( 600 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 601 602 // Reduce op to equivalent without empty shape operands. 603 if (newOperands.size() < op.getNumOperands()) { 604 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 605 op->getAttrs()); 606 return success(); 607 } 608 609 return failure(); 610 } 611 }; 612 613 struct BroadcastForwardSingleOperandPattern 614 : public OpRewritePattern<BroadcastOp> { 615 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 616 617 LogicalResult matchAndRewrite(BroadcastOp op, 618 PatternRewriter &rewriter) const override { 619 if (op.getNumOperands() != 1) 620 return failure(); 621 Value replacement = op.getShapes().front(); 622 623 // Insert cast if needed. 624 if (replacement.getType() != op.getType()) { 625 auto loc = op.getLoc(); 626 if (op.getType().isa<ShapeType>()) { 627 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 628 } else { 629 assert(!op.getType().isa<ShapeType>() && 630 !replacement.getType().isa<ShapeType>() && 631 "expect extent tensor cast"); 632 replacement = 633 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 634 } 635 } 636 637 rewriter.replaceOp(op, replacement); 638 return success(); 639 } 640 }; 641 642 struct BroadcastFoldConstantOperandsPattern 643 : public OpRewritePattern<BroadcastOp> { 644 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 645 646 LogicalResult matchAndRewrite(BroadcastOp op, 647 PatternRewriter &rewriter) const override { 648 SmallVector<int64_t, 8> foldedConstantShape; 649 SmallVector<Value, 8> newShapeOperands; 650 for (Value shape : op.getShapes()) { 651 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 652 SmallVector<int64_t, 8> newFoldedConstantShape; 653 if (OpTrait::util::getBroadcastedShape( 654 foldedConstantShape, 655 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 656 newFoldedConstantShape)) { 657 foldedConstantShape = newFoldedConstantShape; 658 continue; 659 } 660 } 661 newShapeOperands.push_back(shape); 662 } 663 664 // Need at least two constant operands to fold anything. 665 if (op.getNumOperands() - newShapeOperands.size() < 2) 666 return failure(); 667 668 auto foldedConstantOperandsTy = RankedTensorType::get( 669 {static_cast<int64_t>(foldedConstantShape.size())}, 670 rewriter.getIndexType()); 671 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 672 op.getLoc(), foldedConstantOperandsTy, 673 rewriter.getIndexTensorAttr(foldedConstantShape))); 674 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 675 newShapeOperands); 676 return success(); 677 } 678 }; 679 680 template <typename OpTy> 681 struct CanonicalizeCastExtentTensorOperandsPattern 682 : public OpRewritePattern<OpTy> { 683 using OpRewritePattern<OpTy>::OpRewritePattern; 684 685 LogicalResult matchAndRewrite(OpTy op, 686 PatternRewriter &rewriter) const override { 687 // Canonicalize operands. 688 bool anyChange = false; 689 auto canonicalizeOperand = [&](Value operand) { 690 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 691 // Only eliminate the cast if it holds no shape information. 692 bool isInformationLoosingCast = 693 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 694 if (isInformationLoosingCast) { 695 anyChange = true; 696 return castOp.source(); 697 } 698 } 699 return operand; 700 }; 701 auto newOperands = llvm::to_vector<8>( 702 llvm::map_range(op.getOperands(), canonicalizeOperand)); 703 704 // Rewrite op if any change required. 705 if (!anyChange) 706 return failure(); 707 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 708 return success(); 709 } 710 }; 711 712 struct BroadcastConcretizeResultTypePattern 713 : public OpRewritePattern<BroadcastOp> { 714 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 715 716 LogicalResult matchAndRewrite(BroadcastOp op, 717 PatternRewriter &rewriter) const override { 718 // Only concretize dynamic extent tensor result types. 719 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 720 if (!resultTy || !resultTy.isDynamicDim(0)) 721 return failure(); 722 723 // Infer resulting shape rank if possible. 724 int64_t maxRank = 0; 725 for (Value shape : op.getShapes()) { 726 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 727 // Cannot infer resulting shape rank if any operand is dynamically 728 // ranked. 729 if (extentTensorTy.isDynamicDim(0)) 730 return failure(); 731 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 732 } 733 } 734 735 auto newOp = rewriter.create<BroadcastOp>( 736 op.getLoc(), getExtentTensorType(getContext(), maxRank), 737 op.getShapes()); 738 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 739 return success(); 740 } 741 }; 742 } // namespace 743 744 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 745 MLIRContext *context) { 746 patterns.add<BroadcastConcretizeResultTypePattern, 747 BroadcastFoldConstantOperandsPattern, 748 BroadcastForwardSingleOperandPattern, 749 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 750 RemoveDuplicateOperandsPattern<BroadcastOp>, 751 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 752 } 753 754 //===----------------------------------------------------------------------===// 755 // ConcatOp 756 //===----------------------------------------------------------------------===// 757 758 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 759 if (!operands[0] || !operands[1]) 760 return nullptr; 761 auto lhsShape = llvm::to_vector<6>( 762 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 763 auto rhsShape = llvm::to_vector<6>( 764 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 765 SmallVector<int64_t, 6> resultShape; 766 resultShape.append(lhsShape.begin(), lhsShape.end()); 767 resultShape.append(rhsShape.begin(), rhsShape.end()); 768 Builder builder(getContext()); 769 return builder.getIndexTensorAttr(resultShape); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // ConstShapeOp 774 //===----------------------------------------------------------------------===// 775 776 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 777 p << " "; 778 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 779 p << "["; 780 interleaveComma(op.getShape().getValues<int64_t>(), p, 781 [&](int64_t i) { p << i; }); 782 p << "] : "; 783 p.printType(op.getType()); 784 } 785 786 static ParseResult parseConstShapeOp(OpAsmParser &parser, 787 OperationState &result) { 788 if (parser.parseOptionalAttrDict(result.attributes)) 789 return failure(); 790 // We piggy-back on ArrayAttr parsing, though we don't internally store the 791 // shape as an ArrayAttr. 792 // TODO: Implement custom parser and maybe make syntax a bit more concise. 793 Attribute extentsRaw; 794 NamedAttrList dummy; 795 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 796 return failure(); 797 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 798 if (!extentsArray) 799 return failure(); 800 SmallVector<int64_t, 6> ints; 801 for (Attribute extent : extentsArray) { 802 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 803 if (!attr) 804 return failure(); 805 ints.push_back(attr.getInt()); 806 } 807 Builder &builder = parser.getBuilder(); 808 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 809 Type resultTy; 810 if (parser.parseColonType(resultTy)) 811 return failure(); 812 result.types.push_back(resultTy); 813 return success(); 814 } 815 816 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 817 818 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 819 MLIRContext *context) { 820 patterns.add<TensorCastConstShape>(context); 821 } 822 823 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 824 MLIRContext *context, Optional<Location> location, ValueRange operands, 825 DictionaryAttr attributes, RegionRange regions, 826 SmallVectorImpl<Type> &inferredReturnTypes) { 827 Builder b(context); 828 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 829 if (!shape) 830 return emitOptionalError(location, "missing shape attribute"); 831 inferredReturnTypes.assign({RankedTensorType::get( 832 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 833 return success(); 834 } 835 836 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 837 TypeRange r) { 838 if (l.size() != 1 || r.size() != 1) 839 return false; 840 841 Type lhs = l.front(); 842 Type rhs = r.front(); 843 844 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 845 // Shape type is compatible with all other valid return types. 846 return true; 847 return lhs == rhs; 848 } 849 850 //===----------------------------------------------------------------------===// 851 // CstrBroadcastableOp 852 //===----------------------------------------------------------------------===// 853 854 void CstrBroadcastableOp::getCanonicalizationPatterns( 855 RewritePatternSet &patterns, MLIRContext *context) { 856 // Canonicalization patterns have overlap with the considerations during 857 // folding in case additional shape information is inferred at some point that 858 // does not result in folding. 859 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 860 CstrBroadcastableEqOps, 861 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 862 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 863 } 864 865 // Return true if there is exactly one attribute not representing a scalar 866 // broadcast. 867 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 868 bool nonScalarSeen = false; 869 for (Attribute a : attributes) { 870 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 871 if (nonScalarSeen) 872 return false; 873 nonScalarSeen = true; 874 } 875 } 876 return true; 877 } 878 879 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 880 // No broadcasting is needed if all operands but one are scalar. 881 if (hasAtMostSingleNonScalar(operands)) 882 return BoolAttr::get(getContext(), true); 883 884 if ([&] { 885 SmallVector<SmallVector<int64_t, 6>, 6> extents; 886 for (const auto &operand : operands) { 887 if (!operand) 888 return false; 889 extents.push_back(llvm::to_vector<6>( 890 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 891 } 892 return OpTrait::util::staticallyKnownBroadcastable(extents); 893 }()) 894 return BoolAttr::get(getContext(), true); 895 896 // Lastly, see if folding can be completed based on what constraints are known 897 // on the input shapes. 898 if ([&] { 899 SmallVector<SmallVector<int64_t, 6>, 6> extents; 900 for (auto shapeValue : getShapes()) { 901 extents.emplace_back(); 902 if (failed(getShapeVec(shapeValue, extents.back()))) 903 return false; 904 } 905 return OpTrait::util::staticallyKnownBroadcastable(extents); 906 }()) 907 return BoolAttr::get(getContext(), true); 908 909 // Because a failing witness result here represents an eventual assertion 910 // failure, we do not replace it with a constant witness. 911 return nullptr; 912 } 913 914 static LogicalResult verify(CstrBroadcastableOp op) { 915 // Ensure that AssumingAllOp contains at least one operand 916 if (op.getNumOperands() < 2) 917 return op.emitOpError("required at least 2 input shapes"); 918 return success(); 919 } 920 921 //===----------------------------------------------------------------------===// 922 // CstrEqOp 923 //===----------------------------------------------------------------------===// 924 925 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 926 MLIRContext *context) { 927 // If inputs are equal, return passing witness 928 patterns.add<CstrEqEqOps>(context); 929 } 930 931 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 932 if (llvm::all_of(operands, 933 [&](Attribute a) { return a && a == operands[0]; })) 934 return BoolAttr::get(getContext(), true); 935 936 // Because a failing witness result here represents an eventual assertion 937 // failure, we do not try to replace it with a constant witness. Similarly, we 938 // cannot if there are any non-const inputs. 939 return nullptr; 940 } 941 942 //===----------------------------------------------------------------------===// 943 // ConstSizeOp 944 //===----------------------------------------------------------------------===// 945 946 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 947 int64_t value) { 948 build(builder, result, builder.getIndexAttr(value)); 949 } 950 951 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 952 953 void ConstSizeOp::getAsmResultNames( 954 llvm::function_ref<void(Value, StringRef)> setNameFn) { 955 SmallString<4> buffer; 956 llvm::raw_svector_ostream os(buffer); 957 os << "c" << getValue(); 958 setNameFn(getResult(), os.str()); 959 } 960 961 //===----------------------------------------------------------------------===// 962 // ConstWitnessOp 963 //===----------------------------------------------------------------------===// 964 965 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 966 return getPassingAttr(); 967 } 968 969 //===----------------------------------------------------------------------===// 970 // CstrRequireOp 971 //===----------------------------------------------------------------------===// 972 973 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 974 return operands[0]; 975 } 976 977 //===----------------------------------------------------------------------===// 978 // DivOp 979 //===----------------------------------------------------------------------===// 980 981 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 982 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 983 if (!lhs) 984 return nullptr; 985 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 986 if (!rhs) 987 return nullptr; 988 989 // Division in APInt does not follow floor(lhs, rhs) when the result is 990 // negative. Rather, APInt rounds toward zero. 991 APInt quotient, remainder; 992 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 993 if (quotient.isNegative() && !remainder.isNullValue()) { 994 quotient -= 1; 995 } 996 997 Type indexTy = IndexType::get(getContext()); 998 return IntegerAttr::get(indexTy, quotient); 999 } 1000 1001 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1002 MLIRContext *context, Optional<Location> location, ValueRange operands, 1003 DictionaryAttr attributes, RegionRange regions, 1004 SmallVectorImpl<Type> &inferredReturnTypes) { 1005 if (operands[0].getType().isa<SizeType>() || 1006 operands[1].getType().isa<SizeType>()) 1007 inferredReturnTypes.assign({SizeType::get(context)}); 1008 else 1009 inferredReturnTypes.assign({IndexType::get(context)}); 1010 return success(); 1011 } 1012 1013 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1014 // SizeType is compatible with IndexType. 1015 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1016 } 1017 1018 //===----------------------------------------------------------------------===// 1019 // ShapeEqOp 1020 //===----------------------------------------------------------------------===// 1021 1022 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1023 bool allSame = true; 1024 if (!operands.empty() && !operands[0]) 1025 return {}; 1026 for (Attribute operand : operands.drop_front(1)) { 1027 if (!operand) 1028 return {}; 1029 allSame = allSame && operand == operands[0]; 1030 } 1031 return BoolAttr::get(getContext(), allSame); 1032 } 1033 1034 //===----------------------------------------------------------------------===// 1035 // IndexToSizeOp 1036 //===----------------------------------------------------------------------===// 1037 1038 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1039 // Constant values of both types, `shape.size` and `index`, are represented as 1040 // `IntegerAttr`s which makes constant folding simple. 1041 if (Attribute arg = operands[0]) 1042 return arg; 1043 return {}; 1044 } 1045 1046 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1047 MLIRContext *context) { 1048 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1049 } 1050 1051 //===----------------------------------------------------------------------===// 1052 // FromExtentsOp 1053 //===----------------------------------------------------------------------===// 1054 1055 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1056 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1057 return nullptr; 1058 SmallVector<int64_t, 6> extents; 1059 for (auto attr : operands) 1060 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1061 Builder builder(getContext()); 1062 return builder.getIndexTensorAttr(extents); 1063 } 1064 1065 //===----------------------------------------------------------------------===// 1066 // FunctionLibraryOp 1067 //===----------------------------------------------------------------------===// 1068 1069 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1070 StringRef name) { 1071 result.attributes.push_back(builder.getNamedAttr( 1072 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1073 } 1074 1075 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1076 auto attr = getMapping() 1077 .get(op->getName().getIdentifier()) 1078 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1079 if (!attr) 1080 return nullptr; 1081 return lookupSymbol<FuncOp>(attr); 1082 } 1083 1084 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1085 OperationState &result) { 1086 // Parse the op name. 1087 StringAttr nameAttr; 1088 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1089 result.attributes)) 1090 return failure(); 1091 1092 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1093 return failure(); 1094 1095 auto *bodyRegion = result.addRegion(); 1096 if (parser.parseRegion(*bodyRegion)) 1097 return failure(); 1098 1099 if (parser.parseKeyword("mapping")) 1100 return failure(); 1101 1102 DictionaryAttr mappingAttr; 1103 if (parser.parseAttribute(mappingAttr, 1104 parser.getBuilder().getType<NoneType>(), "mapping", 1105 result.attributes)) 1106 return failure(); 1107 return success(); 1108 } 1109 1110 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1111 p << ' '; 1112 p.printSymbolName(op.getName()); 1113 p.printOptionalAttrDictWithKeyword( 1114 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1115 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1116 /*printBlockTerminators=*/false); 1117 p << " mapping "; 1118 p.printAttributeWithoutType(op.getMappingAttr()); 1119 } 1120 1121 //===----------------------------------------------------------------------===// 1122 // GetExtentOp 1123 //===----------------------------------------------------------------------===// 1124 1125 Optional<int64_t> GetExtentOp::getConstantDim() { 1126 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1127 return constSizeOp.getValue().getLimitedValue(); 1128 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1129 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1130 return llvm::None; 1131 } 1132 1133 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1134 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1135 if (!elements) 1136 return nullptr; 1137 Optional<int64_t> dim = getConstantDim(); 1138 if (!dim.hasValue()) 1139 return nullptr; 1140 if (dim.getValue() >= elements.getNumElements()) 1141 return nullptr; 1142 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1143 } 1144 1145 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1146 int64_t dim) { 1147 auto loc = result.location; 1148 auto dimAttr = builder.getIndexAttr(dim); 1149 if (shape.getType().isa<ShapeType>()) { 1150 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1151 build(builder, result, builder.getType<SizeType>(), shape, dim); 1152 } else { 1153 Value dim = 1154 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1155 build(builder, result, builder.getIndexType(), shape, dim); 1156 } 1157 } 1158 1159 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1160 MLIRContext *context, Optional<Location> location, ValueRange operands, 1161 DictionaryAttr attributes, RegionRange regions, 1162 SmallVectorImpl<Type> &inferredReturnTypes) { 1163 inferredReturnTypes.assign({IndexType::get(context)}); 1164 return success(); 1165 } 1166 1167 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1168 TypeRange r) { 1169 // SizeType is compatible with IndexType. 1170 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1171 } 1172 1173 //===----------------------------------------------------------------------===// 1174 // IsBroadcastableOp 1175 //===----------------------------------------------------------------------===// 1176 1177 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1178 MLIRContext *context) { 1179 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1180 } 1181 1182 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1183 // Can always broadcast fewer than two shapes. 1184 if (operands.size() < 2) { 1185 return BoolAttr::get(getContext(), true); 1186 } 1187 1188 return nullptr; 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // MeetOp 1193 //===----------------------------------------------------------------------===// 1194 1195 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1196 MLIRContext *context, Optional<Location> location, ValueRange operands, 1197 DictionaryAttr attributes, RegionRange regions, 1198 SmallVectorImpl<Type> &inferredReturnTypes) { 1199 inferredReturnTypes.assign({operands[0].getType()}); 1200 return success(); 1201 } 1202 1203 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1204 if (l.size() != 1 || r.size() != 1) 1205 return false; 1206 if (l == r) 1207 return true; 1208 1209 Type lhs = l.front(); 1210 Type rhs = r.front(); 1211 1212 if (lhs != rhs) 1213 return false; 1214 1215 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1216 return true; 1217 1218 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1219 return true; 1220 return false; 1221 } 1222 1223 //===----------------------------------------------------------------------===// 1224 // RankOp 1225 //===----------------------------------------------------------------------===// 1226 1227 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1228 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1229 if (!shape) 1230 return {}; 1231 int64_t rank = shape.getNumElements(); 1232 Builder builder(getContext()); 1233 return builder.getIndexAttr(rank); 1234 } 1235 1236 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1237 /// Constant folding fails in cases where only the rank is constant, not the 1238 /// shape itself. 1239 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1240 /// 1241 /// Example: 1242 /// 1243 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1244 /// %rank = shape.rank %shape 1245 /// 1246 /// becomes 1247 /// 1248 /// %rank = shape.const_size 3 1249 1250 namespace { 1251 struct RankShapeOfCanonicalizationPattern 1252 : public OpRewritePattern<shape::RankOp> { 1253 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1254 1255 LogicalResult matchAndRewrite(shape::RankOp op, 1256 PatternRewriter &rewriter) const override { 1257 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1258 if (!shapeOfOp) 1259 return failure(); 1260 auto rankedTensorType = 1261 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1262 if (!rankedTensorType) 1263 return failure(); 1264 int64_t rank = rankedTensorType.getRank(); 1265 if (op.getType().isa<IndexType>()) { 1266 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1267 rank); 1268 } else if (op.getType().isa<shape::SizeType>()) { 1269 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1270 } else { 1271 return failure(); 1272 } 1273 return success(); 1274 } 1275 }; 1276 } // namespace 1277 1278 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1279 MLIRContext *context) { 1280 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1281 } 1282 1283 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1284 MLIRContext *context, Optional<Location> location, ValueRange operands, 1285 DictionaryAttr attributes, RegionRange regions, 1286 SmallVectorImpl<Type> &inferredReturnTypes) { 1287 if (operands[0].getType().isa<ShapeType>()) 1288 inferredReturnTypes.assign({SizeType::get(context)}); 1289 else 1290 inferredReturnTypes.assign({IndexType::get(context)}); 1291 return success(); 1292 } 1293 1294 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1295 // SizeType is compatible with IndexType. 1296 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // NumElementsOp 1301 //===----------------------------------------------------------------------===// 1302 1303 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1304 1305 // Fold only when argument constant. 1306 Attribute shape = operands[0]; 1307 if (!shape) 1308 return {}; 1309 1310 APInt product(64, 1); 1311 for (auto value : shape.cast<DenseIntElementsAttr>()) 1312 product *= value; 1313 Builder builder(getContext()); 1314 return builder.getIndexAttr(product.getLimitedValue()); 1315 } 1316 1317 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1318 MLIRContext *context, Optional<Location> location, ValueRange operands, 1319 DictionaryAttr attributes, RegionRange regions, 1320 SmallVectorImpl<Type> &inferredReturnTypes) { 1321 if (operands[0].getType().isa<ShapeType>()) 1322 inferredReturnTypes.assign({SizeType::get(context)}); 1323 else 1324 inferredReturnTypes.assign({IndexType::get(context)}); 1325 return success(); 1326 } 1327 1328 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1329 TypeRange r) { 1330 // SizeType is compatible with IndexType. 1331 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1332 } 1333 1334 //===----------------------------------------------------------------------===// 1335 // MaxOp 1336 //===----------------------------------------------------------------------===// 1337 1338 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1339 // If operands are equal, just propagate one. 1340 if (getLhs() == getRhs()) 1341 return getLhs(); 1342 return nullptr; 1343 } 1344 1345 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1346 MLIRContext *context, Optional<Location> location, ValueRange operands, 1347 DictionaryAttr attributes, RegionRange regions, 1348 SmallVectorImpl<Type> &inferredReturnTypes) { 1349 if (operands[0].getType() == operands[1].getType()) 1350 inferredReturnTypes.assign({operands[0].getType()}); 1351 else 1352 inferredReturnTypes.assign({SizeType::get(context)}); 1353 return success(); 1354 } 1355 1356 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1357 if (l.size() != 1 || r.size() != 1) 1358 return false; 1359 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1360 return true; 1361 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1362 return true; 1363 return false; 1364 } 1365 1366 //===----------------------------------------------------------------------===// 1367 // MinOp 1368 //===----------------------------------------------------------------------===// 1369 1370 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1371 // If operands are equal, just propagate one. 1372 if (getLhs() == getRhs()) 1373 return getLhs(); 1374 return nullptr; 1375 } 1376 1377 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1378 MLIRContext *context, Optional<Location> location, ValueRange operands, 1379 DictionaryAttr attributes, RegionRange regions, 1380 SmallVectorImpl<Type> &inferredReturnTypes) { 1381 if (operands[0].getType() == operands[1].getType()) 1382 inferredReturnTypes.assign({operands[0].getType()}); 1383 else 1384 inferredReturnTypes.assign({SizeType::get(context)}); 1385 return success(); 1386 } 1387 1388 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1389 if (l.size() != 1 || r.size() != 1) 1390 return false; 1391 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1392 return true; 1393 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1394 return true; 1395 return false; 1396 } 1397 1398 //===----------------------------------------------------------------------===// 1399 // MulOp 1400 //===----------------------------------------------------------------------===// 1401 1402 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1403 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1404 if (!lhs) 1405 return nullptr; 1406 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1407 if (!rhs) 1408 return nullptr; 1409 APInt folded = lhs.getValue() * rhs.getValue(); 1410 Type indexTy = IndexType::get(getContext()); 1411 return IntegerAttr::get(indexTy, folded); 1412 } 1413 1414 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1415 MLIRContext *context, Optional<Location> location, ValueRange operands, 1416 DictionaryAttr attributes, RegionRange regions, 1417 SmallVectorImpl<Type> &inferredReturnTypes) { 1418 if (operands[0].getType().isa<SizeType>() || 1419 operands[1].getType().isa<SizeType>()) 1420 inferredReturnTypes.assign({SizeType::get(context)}); 1421 else 1422 inferredReturnTypes.assign({IndexType::get(context)}); 1423 return success(); 1424 } 1425 1426 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1427 // SizeType is compatible with IndexType. 1428 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1429 } 1430 //===----------------------------------------------------------------------===// 1431 // ShapeOfOp 1432 //===----------------------------------------------------------------------===// 1433 1434 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1435 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1436 if (!type || !type.hasStaticShape()) 1437 return nullptr; 1438 Builder builder(getContext()); 1439 return builder.getIndexTensorAttr(type.getShape()); 1440 } 1441 1442 namespace { 1443 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1444 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1445 1446 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1447 PatternRewriter &rewriter) const override { 1448 if (!op.getArg().getType().isa<ShapedType>()) 1449 return failure(); 1450 if (op.getType().isa<ShapedType>()) 1451 return failure(); 1452 1453 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1454 op.getArg()); 1455 return success(); 1456 } 1457 }; 1458 1459 // Canonicalize 1460 // ``` 1461 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1462 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1463 // ``` 1464 // to 1465 // ``` 1466 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1467 // ``` 1468 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1469 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1470 1471 LogicalResult matchAndRewrite(tensor::CastOp op, 1472 PatternRewriter &rewriter) const override { 1473 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1474 if (!ty || ty.getRank() != 1) 1475 return failure(); 1476 1477 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1478 if (!shapeOfOp) 1479 return failure(); 1480 1481 // Argument type must be ranked and must not conflict. 1482 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1483 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1484 return failure(); 1485 1486 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1487 return success(); 1488 } 1489 }; 1490 } // namespace 1491 1492 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1493 MLIRContext *context) { 1494 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1495 ExtractFromShapeOfExtentTensor>(context); 1496 } 1497 1498 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1499 MLIRContext *context, Optional<Location> location, ValueRange operands, 1500 DictionaryAttr attributes, RegionRange regions, 1501 SmallVectorImpl<Type> &inferredReturnTypes) { 1502 if (operands[0].getType().isa<ValueShapeType>()) 1503 inferredReturnTypes.assign({ShapeType::get(context)}); 1504 else { 1505 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1506 int64_t rank = 1507 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1508 Type indexTy = IndexType::get(context); 1509 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1510 inferredReturnTypes.assign({extentTensorTy}); 1511 } 1512 return success(); 1513 } 1514 1515 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1516 if (l.size() != 1 || r.size() != 1) 1517 return false; 1518 if (l == r) 1519 return true; 1520 1521 Type lhs = l.front(); 1522 Type rhs = r.front(); 1523 1524 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1525 return false; 1526 1527 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1528 // Shape type is compatible with all other valid return types. 1529 return true; 1530 1531 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1532 return true; 1533 return false; 1534 } 1535 1536 //===----------------------------------------------------------------------===// 1537 // SizeToIndexOp 1538 //===----------------------------------------------------------------------===// 1539 1540 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1541 // Constant values of both types, `shape.size` and `index`, are represented as 1542 // `IntegerAttr`s which makes constant folding simple. 1543 if (Attribute arg = operands[0]) 1544 return arg; 1545 return impl::foldCastOp(*this); 1546 } 1547 1548 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1549 MLIRContext *context) { 1550 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1551 } 1552 1553 //===----------------------------------------------------------------------===// 1554 // YieldOp 1555 //===----------------------------------------------------------------------===// 1556 1557 static LogicalResult verify(shape::YieldOp op) { 1558 auto *parentOp = op->getParentOp(); 1559 auto results = parentOp->getResults(); 1560 auto operands = op.getOperands(); 1561 1562 if (parentOp->getNumResults() != op.getNumOperands()) 1563 return op.emitOpError() << "number of operands does not match number of " 1564 "results of its parent"; 1565 for (auto e : llvm::zip(results, operands)) 1566 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1567 return op.emitOpError() 1568 << "types mismatch between yield op and its parent"; 1569 1570 return success(); 1571 } 1572 1573 //===----------------------------------------------------------------------===// 1574 // SplitAtOp 1575 //===----------------------------------------------------------------------===// 1576 1577 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1578 SmallVectorImpl<OpFoldResult> &results) { 1579 if (!operands[0] || !operands[1]) 1580 return failure(); 1581 auto shapeVec = llvm::to_vector<6>( 1582 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1583 auto shape = llvm::makeArrayRef(shapeVec); 1584 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1585 // Verify that the split point is in the correct range. 1586 // TODO: Constant fold to an "error". 1587 int64_t rank = shape.size(); 1588 if (!(-rank <= splitPoint && splitPoint <= rank)) 1589 return failure(); 1590 if (splitPoint < 0) 1591 splitPoint += shape.size(); 1592 Builder builder(operands[0].getContext()); 1593 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1594 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1595 return success(); 1596 } 1597 1598 //===----------------------------------------------------------------------===// 1599 // ToExtentTensorOp 1600 //===----------------------------------------------------------------------===// 1601 1602 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1603 if (!operands[0]) 1604 return impl::foldCastOp(*this); 1605 Builder builder(getContext()); 1606 auto shape = llvm::to_vector<6>( 1607 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1608 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1609 builder.getIndexType()); 1610 return DenseIntElementsAttr::get(type, shape); 1611 } 1612 1613 //===----------------------------------------------------------------------===// 1614 // ReduceOp 1615 //===----------------------------------------------------------------------===// 1616 1617 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1618 ValueRange initVals) { 1619 result.addOperands(shape); 1620 result.addOperands(initVals); 1621 1622 Region *bodyRegion = result.addRegion(); 1623 bodyRegion->push_back(new Block); 1624 Block &bodyBlock = bodyRegion->front(); 1625 bodyBlock.addArgument(builder.getIndexType()); 1626 1627 Type elementType; 1628 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1629 elementType = tensorType.getElementType(); 1630 else 1631 elementType = SizeType::get(builder.getContext()); 1632 bodyBlock.addArgument(elementType); 1633 1634 for (Type initValType : initVals.getTypes()) { 1635 bodyBlock.addArgument(initValType); 1636 result.addTypes(initValType); 1637 } 1638 } 1639 1640 static LogicalResult verify(ReduceOp op) { 1641 // Verify block arg types. 1642 Block &block = op.getRegion().front(); 1643 1644 // The block takes index, extent, and aggregated values as arguments. 1645 auto blockArgsCount = op.getInitVals().size() + 2; 1646 if (block.getNumArguments() != blockArgsCount) 1647 return op.emitOpError() << "ReduceOp body is expected to have " 1648 << blockArgsCount << " arguments"; 1649 1650 // The first block argument is the index and must always be of type `index`. 1651 if (!block.getArgument(0).getType().isa<IndexType>()) 1652 return op.emitOpError( 1653 "argument 0 of ReduceOp body is expected to be of IndexType"); 1654 1655 // The second block argument is the extent and must be of type `size` or 1656 // `index`, depending on whether the reduce operation is applied to a shape or 1657 // to an extent tensor. 1658 Type extentTy = block.getArgument(1).getType(); 1659 if (op.getShape().getType().isa<ShapeType>()) { 1660 if (!extentTy.isa<SizeType>()) 1661 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1662 "SizeType if the ReduceOp operates on a ShapeType"); 1663 } else { 1664 if (!extentTy.isa<IndexType>()) 1665 return op.emitOpError( 1666 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1667 "ReduceOp operates on an extent tensor"); 1668 } 1669 1670 for (auto type : llvm::enumerate(op.getInitVals())) 1671 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1672 return op.emitOpError() 1673 << "type mismatch between argument " << type.index() + 2 1674 << " of ReduceOp body and initial value " << type.index(); 1675 return success(); 1676 } 1677 1678 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1679 // Parse operands. 1680 SmallVector<OpAsmParser::OperandType, 3> operands; 1681 Type shapeOrExtentTensorType; 1682 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1683 OpAsmParser::Delimiter::Paren) || 1684 parser.parseColonType(shapeOrExtentTensorType) || 1685 parser.parseOptionalArrowTypeList(result.types)) 1686 return failure(); 1687 1688 // Resolve operands. 1689 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1690 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1691 result.operands) || 1692 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1693 result.operands)) 1694 return failure(); 1695 1696 // Parse the body. 1697 Region *body = result.addRegion(); 1698 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1699 return failure(); 1700 1701 // Parse attributes. 1702 if (parser.parseOptionalAttrDict(result.attributes)) 1703 return failure(); 1704 1705 return success(); 1706 } 1707 1708 static void print(OpAsmPrinter &p, ReduceOp op) { 1709 p << '(' << op.getShape() << ", " << op.getInitVals() 1710 << ") : " << op.getShape().getType(); 1711 p.printOptionalArrowTypeList(op.getResultTypes()); 1712 p.printRegion(op.getRegion()); 1713 p.printOptionalAttrDict(op->getAttrs()); 1714 } 1715 1716 #define GET_OP_CLASSES 1717 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1718