1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/IR/Shape.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/CommonFolders.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Dialect/Traits.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/Matchers.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/SmallString.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/raw_ostream.h" 26 27 using namespace mlir; 28 using namespace mlir::shape; 29 30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 31 32 namespace { 33 #include "ShapeCanonicalization.inc" 34 } 35 36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 37 return RankedTensorType::get({rank}, IndexType::get(ctx)); 38 } 39 40 bool shape::isExtentTensorType(Type type) { 41 auto ranked = type.dyn_cast<RankedTensorType>(); 42 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 43 } 44 45 LogicalResult shape::getShapeVec(Value input, 46 SmallVectorImpl<int64_t> &shapeValues) { 47 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 48 auto type = inputOp.getArg().getType().dyn_cast<ShapedType>(); 49 if (!type.hasRank()) 50 return failure(); 51 shapeValues = llvm::to_vector<6>(type.getShape()); 52 return success(); 53 } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 54 shapeValues = 
llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); 55 return success(); 56 } else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 57 shapeValues = llvm::to_vector<6>( 58 inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>()); 59 return success(); 60 } else { 61 return failure(); 62 } 63 } 64 65 static bool isErrorPropagationPossible(TypeRange operandTypes) { 66 return llvm::any_of(operandTypes, [](Type ty) { 67 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 68 }); 69 } 70 71 static LogicalResult verifySizeOrIndexOp(Operation *op) { 72 assert(op != nullptr && op->getNumResults() == 1); 73 Type resultTy = op->getResultTypes().front(); 74 if (isErrorPropagationPossible(op->getOperandTypes())) { 75 if (!resultTy.isa<SizeType>()) 76 return op->emitOpError() 77 << "if at least one of the operands can hold error values then " 78 "the result must be of type `size` to propagate them"; 79 } 80 return success(); 81 } 82 83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 84 assert(op != nullptr && op->getNumResults() == 1); 85 Type resultTy = op->getResultTypes().front(); 86 if (isErrorPropagationPossible(op->getOperandTypes())) { 87 if (!resultTy.isa<ShapeType>()) 88 return op->emitOpError() 89 << "if at least one of the operands can hold error values then " 90 "the result must be of type `shape` to propagate them"; 91 } 92 return success(); 93 } 94 95 template <typename... Ty> 96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 97 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 98 } 99 100 template <typename... Ty, typename... ranges> 101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... 
rs) { 102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // InlinerInterface 107 //===----------------------------------------------------------------------===// 108 109 namespace { 110 /// This class defines the interface for inlining shape dialect ops. 111 struct ShapeInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 // Returns true if the given region 'src' can be inlined into the region 115 // 'dest' that is attached to an operation registered to the current dialect. 116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 // Returns true if the given operation 'op', that is registered to this 122 // dialect, can be inlined into the region 'dest' that is attached to an 123 // operation registered to the current dialect. 124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 125 BlockAndValueMapping &) const final { 126 return true; 127 } 128 }; 129 } // namespace 130 131 void ShapeDialect::initialize() { 132 addOperations< 133 #define GET_OP_LIST 134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 135 >(); 136 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 137 addInterfaces<ShapeInlinerInterface>(); 138 // Allow unknown operations during prototyping and testing. As the dialect is 139 // still evolving it makes it simple to start with an unregistered ops and 140 // try different variants before actually defining the op. 
141 allowUnknownOperations(); 142 } 143 144 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 145 Attribute value, Type type, 146 Location loc) { 147 if (type.isa<ShapeType>() || isExtentTensorType(type)) 148 return builder.create<ConstShapeOp>(loc, type, 149 value.cast<DenseIntElementsAttr>()); 150 if (type.isa<SizeType>()) 151 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 152 if (type.isa<WitnessType>()) 153 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 154 if (arith::ConstantOp::isBuildableWith(value, type)) 155 return builder.create<arith::ConstantOp>(loc, type, value); 156 return nullptr; 157 } 158 159 /// Parse a type registered to this dialect. 160 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 161 StringRef keyword; 162 if (parser.parseKeyword(&keyword)) 163 return Type(); 164 165 if (keyword == "shape") 166 return ShapeType::get(getContext()); 167 if (keyword == "size") 168 return SizeType::get(getContext()); 169 if (keyword == "value_shape") 170 return ValueShapeType::get(getContext()); 171 if (keyword == "witness") 172 return WitnessType::get(getContext()); 173 174 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 175 return Type(); 176 } 177 178 /// Print a type registered to this dialect. 179 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 180 TypeSwitch<Type>(type) 181 .Case<ShapeType>([&](Type) { os << "shape"; }) 182 .Case<SizeType>([&](Type) { os << "size"; }) 183 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 184 .Case<WitnessType>([&](Type) { os << "witness"; }) 185 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 186 } 187 188 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 189 NamedAttribute attribute) { 190 // Verify shape.lib attribute. 
191 if (attribute.getName() == "shape.lib") { 192 if (!op->hasTrait<OpTrait::SymbolTable>()) 193 return op->emitError( 194 "shape.lib attribute may only be on op implementing SymbolTable"); 195 196 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 197 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 198 if (!symbol) 199 return op->emitError("shape function library ") 200 << symbolRef << " not found"; 201 return isa<shape::FunctionLibraryOp>(symbol) 202 ? success() 203 : op->emitError() 204 << symbolRef << " required to be shape function library"; 205 } 206 207 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 208 // Verify all entries are function libraries and mappings in libraries 209 // refer to unique ops. 210 DenseSet<StringAttr> key; 211 for (auto it : arr) { 212 if (!it.isa<SymbolRefAttr>()) 213 return op->emitError( 214 "only SymbolRefAttr allowed in shape.lib attribute array"); 215 216 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 217 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 218 if (!shapeFnLib) 219 return op->emitError() 220 << it << " does not refer to FunctionLibraryOp"; 221 for (auto mapping : shapeFnLib.getMapping()) { 222 if (!key.insert(mapping.getName()).second) { 223 return op->emitError("only one op to shape mapping allowed, found " 224 "multiple for `") 225 << mapping.getName() << "`"; 226 } 227 } 228 } 229 return success(); 230 } 231 232 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 233 "allowed as shape.lib attribute"); 234 } 235 return success(); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // AnyOp 240 //===----------------------------------------------------------------------===// 241 242 // TODO: Canonicalization should be implemented for shapes that can be 243 // determined through mixtures of the known dimensions of the inputs. 
244 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 245 // Only the last operand is checked because AnyOp is commutative. 246 if (operands.back()) 247 return operands.back(); 248 249 return nullptr; 250 } 251 252 //===----------------------------------------------------------------------===// 253 // AssumingOp 254 //===----------------------------------------------------------------------===// 255 256 static ParseResult parseAssumingOp(OpAsmParser &parser, 257 OperationState &result) { 258 result.regions.reserve(1); 259 Region *doRegion = result.addRegion(); 260 261 auto &builder = parser.getBuilder(); 262 OpAsmParser::OperandType cond; 263 if (parser.parseOperand(cond) || 264 parser.resolveOperand(cond, builder.getType<WitnessType>(), 265 result.operands)) 266 return failure(); 267 268 // Parse optional results type list. 269 if (parser.parseOptionalArrowTypeList(result.types)) 270 return failure(); 271 272 // Parse the region and add a terminator if elided. 273 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 274 return failure(); 275 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 276 277 // Parse the optional attribute list. 278 if (parser.parseOptionalAttrDict(result.attributes)) 279 return failure(); 280 return success(); 281 } 282 283 static void print(OpAsmPrinter &p, AssumingOp op) { 284 bool yieldsResults = !op.getResults().empty(); 285 286 p << " " << op.getWitness(); 287 if (yieldsResults) { 288 p << " -> (" << op.getResultTypes() << ")"; 289 } 290 p.printRegion(op.getDoRegion(), 291 /*printEntryBlockArgs=*/false, 292 /*printBlockTerminators=*/yieldsResults); 293 p.printOptionalAttrDict(op->getAttrs()); 294 } 295 296 namespace { 297 // Removes AssumingOp with a passing witness and inlines the region. 
298 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 299 using OpRewritePattern<AssumingOp>::OpRewritePattern; 300 301 LogicalResult matchAndRewrite(AssumingOp op, 302 PatternRewriter &rewriter) const override { 303 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 304 if (!witness || !witness.getPassingAttr()) 305 return failure(); 306 307 AssumingOp::inlineRegionIntoParent(op, rewriter); 308 return success(); 309 } 310 }; 311 312 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 313 using OpRewritePattern<AssumingOp>::OpRewritePattern; 314 315 LogicalResult matchAndRewrite(AssumingOp op, 316 PatternRewriter &rewriter) const override { 317 Block *body = op.getBody(); 318 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 319 320 // Find used values. 321 SmallVector<Value, 4> newYieldOperands; 322 Value opResult, yieldOperand; 323 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 324 std::tie(opResult, yieldOperand) = it; 325 if (!opResult.getUses().empty()) { 326 newYieldOperands.push_back(yieldOperand); 327 } 328 } 329 330 // Rewrite only if redundant results exist. 331 if (newYieldOperands.size() == yieldOp->getNumOperands()) 332 return failure(); 333 334 // Replace yield op in the old assuming op's body and move the entire region 335 // to the new assuming op. 336 rewriter.setInsertionPointToEnd(body); 337 auto newYieldOp = 338 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 339 rewriter.setInsertionPoint(op); 340 auto newOp = rewriter.create<AssumingOp>( 341 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 342 newOp.getDoRegion().takeBody(op.getDoRegion()); 343 344 // Use the new results to replace the previously used ones. 
345 SmallVector<Value, 4> replacementValues; 346 auto src = newOp.getResults().begin(); 347 for (auto it : op.getResults()) { 348 if (it.getUses().empty()) 349 replacementValues.push_back(nullptr); 350 else 351 replacementValues.push_back(*src++); 352 } 353 rewriter.replaceOp(op, replacementValues); 354 return success(); 355 } 356 }; 357 } // namespace 358 359 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 360 MLIRContext *context) { 361 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 362 } 363 364 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 365 void AssumingOp::getSuccessorRegions( 366 Optional<unsigned> index, ArrayRef<Attribute> operands, 367 SmallVectorImpl<RegionSuccessor> ®ions) { 368 // AssumingOp has unconditional control flow into the region and back to the 369 // parent, so return the correct RegionSuccessor purely based on the index 370 // being None or 0. 371 if (index.hasValue()) { 372 regions.push_back(RegionSuccessor(getResults())); 373 return; 374 } 375 376 regions.push_back(RegionSuccessor(&getDoRegion())); 377 } 378 379 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 380 PatternRewriter &rewriter) { 381 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 382 auto *assumingBlock = op.getBody(); 383 auto initPosition = rewriter.getInsertionPoint(); 384 auto *blockAfterAssuming = 385 rewriter.splitBlock(blockBeforeAssuming, initPosition); 386 387 // Remove the AssumingOp and AssumingYieldOp. 388 auto &yieldOp = assumingBlock->back(); 389 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 390 rewriter.replaceOp(op, yieldOp.getOperands()); 391 rewriter.eraseOp(&yieldOp); 392 393 // Merge blocks together as there was no branching behavior from the 394 // AssumingOp. 
395 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 396 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 397 } 398 399 void AssumingOp::build( 400 OpBuilder &builder, OperationState &result, Value witness, 401 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 402 403 result.addOperands(witness); 404 Region *bodyRegion = result.addRegion(); 405 bodyRegion->push_back(new Block); 406 Block &bodyBlock = bodyRegion->front(); 407 408 // Build body. 409 OpBuilder::InsertionGuard guard(builder); 410 builder.setInsertionPointToStart(&bodyBlock); 411 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 412 builder.create<AssumingYieldOp>(result.location, yieldValues); 413 414 SmallVector<Type, 2> assumingTypes; 415 for (Value v : yieldValues) 416 assumingTypes.push_back(v.getType()); 417 result.addTypes(assumingTypes); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // AddOp 422 //===----------------------------------------------------------------------===// 423 424 LogicalResult mlir::shape::AddOp::inferReturnTypes( 425 MLIRContext *context, Optional<Location> location, ValueRange operands, 426 DictionaryAttr attributes, RegionRange regions, 427 SmallVectorImpl<Type> &inferredReturnTypes) { 428 if (operands[0].getType().isa<SizeType>() || 429 operands[1].getType().isa<SizeType>()) 430 inferredReturnTypes.assign({SizeType::get(context)}); 431 else 432 inferredReturnTypes.assign({IndexType::get(context)}); 433 return success(); 434 } 435 436 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 437 // SizeType is compatible with IndexType. 
438 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 439 } 440 441 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 442 // add(x, 0) -> x 443 if (matchPattern(getRhs(), m_Zero())) 444 return getLhs(); 445 446 return constFoldBinaryOp<IntegerAttr>(operands, 447 [](APInt a, APInt b) { return a + b; }); 448 } 449 450 //===----------------------------------------------------------------------===// 451 // AssumingAllOp 452 //===----------------------------------------------------------------------===// 453 454 namespace { 455 struct AssumingAllToCstrEqCanonicalization 456 : public OpRewritePattern<AssumingAllOp> { 457 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 458 459 LogicalResult matchAndRewrite(AssumingAllOp op, 460 PatternRewriter &rewriter) const override { 461 SmallVector<Value, 8> shapes; 462 for (Value w : op.getInputs()) { 463 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 464 if (!cstrEqOp) 465 return failure(); 466 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 467 return llvm::is_contained(shapes, s); 468 }); 469 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 470 return failure(); 471 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 472 } 473 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 474 return success(); 475 } 476 }; 477 478 template <typename OpTy> 479 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 480 using OpRewritePattern<OpTy>::OpRewritePattern; 481 482 LogicalResult matchAndRewrite(OpTy op, 483 PatternRewriter &rewriter) const override { 484 // Find unique operands. 485 SmallVector<Value, 2> unique; 486 for (Value v : op.getOperands()) { 487 if (!llvm::is_contained(unique, v)) 488 unique.push_back(v); 489 } 490 491 // Reduce op to equivalent with unique operands. 
492 if (unique.size() < op.getNumOperands()) { 493 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique, 494 op->getAttrs()); 495 return success(); 496 } 497 498 return failure(); 499 } 500 }; 501 } // namespace 502 503 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 504 MLIRContext *context) { 505 patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization, 506 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 507 } 508 509 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 510 // Iterate in reverse to first handle all constant operands. They are 511 // guaranteed to be the tail of the inputs because this is commutative. 512 for (int idx = operands.size() - 1; idx >= 0; idx--) { 513 Attribute a = operands[idx]; 514 // Cannot fold if any inputs are not constant; 515 if (!a) 516 return nullptr; 517 518 // We do not need to keep statically known values after handling them in 519 // this method. 520 getOperation()->eraseOperand(idx); 521 522 // Always false if any input is statically known false 523 if (!a.cast<BoolAttr>().getValue()) 524 return a; 525 } 526 // If this is reached, all inputs were statically known passing. 
527 return BoolAttr::get(getContext(), true); 528 } 529 530 static LogicalResult verify(AssumingAllOp op) { 531 // Ensure that AssumingAllOp contains at least one operand 532 if (op.getNumOperands() == 0) 533 return op.emitOpError("no operands specified"); 534 535 return success(); 536 } 537 538 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 539 ValueRange inputs) { 540 build(b, state, b.getType<WitnessType>(), inputs); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // BroadcastOp 545 //===----------------------------------------------------------------------===// 546 547 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 548 if (getShapes().size() == 1) { 549 // Otherwise, we need a cast which would be a canonicalization, not folding. 550 if (getShapes().front().getType() != getType()) 551 return nullptr; 552 return getShapes().front(); 553 } 554 555 // TODO: Support folding with more than 2 input shapes 556 if (getShapes().size() > 2) 557 return nullptr; 558 559 if (!operands[0] || !operands[1]) 560 return nullptr; 561 auto lhsShape = llvm::to_vector<6>( 562 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 563 auto rhsShape = llvm::to_vector<6>( 564 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 565 SmallVector<int64_t, 6> resultShape; 566 567 // If the shapes are not compatible, we can't fold it. 568 // TODO: Fold to an "error". 
569 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 570 return nullptr; 571 572 Builder builder(getContext()); 573 return builder.getIndexTensorAttr(resultShape); 574 } 575 576 static LogicalResult verify(BroadcastOp op) { 577 return verifyShapeOrExtentTensorOp(op); 578 } 579 580 namespace { 581 template <typename OpTy> 582 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 583 using OpRewritePattern<OpTy>::OpRewritePattern; 584 585 LogicalResult matchAndRewrite(OpTy op, 586 PatternRewriter &rewriter) const override { 587 auto isPotentiallyNonEmptyShape = [](Value shape) { 588 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 589 if (extentTensorTy.getDimSize(0) == 0) 590 return false; 591 } 592 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 593 if (constShape.getShape().empty()) 594 return false; 595 } 596 return true; 597 }; 598 auto newOperands = llvm::to_vector<8>( 599 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 600 601 // Reduce op to equivalent without empty shape operands. 602 if (newOperands.size() < op.getNumOperands()) { 603 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 604 op->getAttrs()); 605 return success(); 606 } 607 608 return failure(); 609 } 610 }; 611 612 struct BroadcastForwardSingleOperandPattern 613 : public OpRewritePattern<BroadcastOp> { 614 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 615 616 LogicalResult matchAndRewrite(BroadcastOp op, 617 PatternRewriter &rewriter) const override { 618 if (op.getNumOperands() != 1) 619 return failure(); 620 Value replacement = op.getShapes().front(); 621 622 // Insert cast if needed. 
623 if (replacement.getType() != op.getType()) { 624 auto loc = op.getLoc(); 625 if (op.getType().isa<ShapeType>()) { 626 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 627 } else { 628 assert(!op.getType().isa<ShapeType>() && 629 !replacement.getType().isa<ShapeType>() && 630 "expect extent tensor cast"); 631 replacement = 632 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 633 } 634 } 635 636 rewriter.replaceOp(op, replacement); 637 return success(); 638 } 639 }; 640 641 struct BroadcastFoldConstantOperandsPattern 642 : public OpRewritePattern<BroadcastOp> { 643 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 644 645 LogicalResult matchAndRewrite(BroadcastOp op, 646 PatternRewriter &rewriter) const override { 647 SmallVector<int64_t, 8> foldedConstantShape; 648 SmallVector<Value, 8> newShapeOperands; 649 for (Value shape : op.getShapes()) { 650 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 651 SmallVector<int64_t, 8> newFoldedConstantShape; 652 if (OpTrait::util::getBroadcastedShape( 653 foldedConstantShape, 654 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 655 newFoldedConstantShape)) { 656 foldedConstantShape = newFoldedConstantShape; 657 continue; 658 } 659 } 660 newShapeOperands.push_back(shape); 661 } 662 663 // Need at least two constant operands to fold anything. 
664 if (op.getNumOperands() - newShapeOperands.size() < 2) 665 return failure(); 666 667 auto foldedConstantOperandsTy = RankedTensorType::get( 668 {static_cast<int64_t>(foldedConstantShape.size())}, 669 rewriter.getIndexType()); 670 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 671 op.getLoc(), foldedConstantOperandsTy, 672 rewriter.getIndexTensorAttr(foldedConstantShape))); 673 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 674 newShapeOperands); 675 return success(); 676 } 677 }; 678 679 template <typename OpTy> 680 struct CanonicalizeCastExtentTensorOperandsPattern 681 : public OpRewritePattern<OpTy> { 682 using OpRewritePattern<OpTy>::OpRewritePattern; 683 684 LogicalResult matchAndRewrite(OpTy op, 685 PatternRewriter &rewriter) const override { 686 // Canonicalize operands. 687 bool anyChange = false; 688 auto canonicalizeOperand = [&](Value operand) { 689 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 690 // Only eliminate the cast if it holds no shape information. 691 bool isInformationLoosingCast = 692 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 693 if (isInformationLoosingCast) { 694 anyChange = true; 695 return castOp.source(); 696 } 697 } 698 return operand; 699 }; 700 auto newOperands = llvm::to_vector<8>( 701 llvm::map_range(op.getOperands(), canonicalizeOperand)); 702 703 // Rewrite op if any change required. 704 if (!anyChange) 705 return failure(); 706 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 707 return success(); 708 } 709 }; 710 711 struct BroadcastConcretizeResultTypePattern 712 : public OpRewritePattern<BroadcastOp> { 713 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 714 715 LogicalResult matchAndRewrite(BroadcastOp op, 716 PatternRewriter &rewriter) const override { 717 // Only concretize dynamic extent tensor result types. 
718 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 719 if (!resultTy || !resultTy.isDynamicDim(0)) 720 return failure(); 721 722 // Infer resulting shape rank if possible. 723 int64_t maxRank = 0; 724 for (Value shape : op.getShapes()) { 725 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 726 // Cannot infer resulting shape rank if any operand is dynamically 727 // ranked. 728 if (extentTensorTy.isDynamicDim(0)) 729 return failure(); 730 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 731 } 732 } 733 734 auto newOp = rewriter.create<BroadcastOp>( 735 op.getLoc(), getExtentTensorType(getContext(), maxRank), 736 op.getShapes()); 737 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 738 return success(); 739 } 740 }; 741 } // namespace 742 743 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 744 MLIRContext *context) { 745 patterns.add<BroadcastConcretizeResultTypePattern, 746 BroadcastFoldConstantOperandsPattern, 747 BroadcastForwardSingleOperandPattern, 748 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 749 RemoveDuplicateOperandsPattern<BroadcastOp>, 750 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 751 } 752 753 //===----------------------------------------------------------------------===// 754 // ConcatOp 755 //===----------------------------------------------------------------------===// 756 757 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 758 if (!operands[0] || !operands[1]) 759 return nullptr; 760 auto lhsShape = llvm::to_vector<6>( 761 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 762 auto rhsShape = llvm::to_vector<6>( 763 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 764 SmallVector<int64_t, 6> resultShape; 765 resultShape.append(lhsShape.begin(), lhsShape.end()); 766 resultShape.append(rhsShape.begin(), rhsShape.end()); 767 Builder builder(getContext()); 768 return 
builder.getIndexTensorAttr(resultShape); 769 } 770 771 //===----------------------------------------------------------------------===// 772 // ConstShapeOp 773 //===----------------------------------------------------------------------===// 774 775 static void print(OpAsmPrinter &p, ConstShapeOp &op) { 776 p << " "; 777 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); 778 p << "["; 779 interleaveComma(op.getShape().getValues<int64_t>(), p, 780 [&](int64_t i) { p << i; }); 781 p << "] : "; 782 p.printType(op.getType()); 783 } 784 785 static ParseResult parseConstShapeOp(OpAsmParser &parser, 786 OperationState &result) { 787 if (parser.parseOptionalAttrDict(result.attributes)) 788 return failure(); 789 // We piggy-back on ArrayAttr parsing, though we don't internally store the 790 // shape as an ArrayAttr. 791 // TODO: Implement custom parser and maybe make syntax a bit more concise. 792 Attribute extentsRaw; 793 NamedAttrList dummy; 794 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 795 return failure(); 796 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 797 if (!extentsArray) 798 return failure(); 799 SmallVector<int64_t, 6> ints; 800 for (Attribute extent : extentsArray) { 801 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 802 if (!attr) 803 return failure(); 804 ints.push_back(attr.getInt()); 805 } 806 Builder &builder = parser.getBuilder(); 807 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 808 Type resultTy; 809 if (parser.parseColonType(resultTy)) 810 return failure(); 811 result.types.push_back(resultTy); 812 return success(); 813 } 814 815 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 816 817 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 818 MLIRContext *context) { 819 patterns.add<TensorCastConstShape>(context); 820 } 821 822 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 823 MLIRContext *context, Optional<Location> location, 
ValueRange operands, 824 DictionaryAttr attributes, RegionRange regions, 825 SmallVectorImpl<Type> &inferredReturnTypes) { 826 Builder b(context); 827 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 828 if (!shape) 829 return emitOptionalError(location, "missing shape attribute"); 830 inferredReturnTypes.assign({RankedTensorType::get( 831 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 832 return success(); 833 } 834 835 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 836 TypeRange r) { 837 if (l.size() != 1 || r.size() != 1) 838 return false; 839 840 Type lhs = l.front(); 841 Type rhs = r.front(); 842 843 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 844 // Shape type is compatible with all other valid return types. 845 return true; 846 return lhs == rhs; 847 } 848 849 //===----------------------------------------------------------------------===// 850 // CstrBroadcastableOp 851 //===----------------------------------------------------------------------===// 852 853 void CstrBroadcastableOp::getCanonicalizationPatterns( 854 RewritePatternSet &patterns, MLIRContext *context) { 855 // Canonicalization patterns have overlap with the considerations during 856 // folding in case additional shape information is inferred at some point that 857 // does not result in folding. 858 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 859 CstrBroadcastableEqOps, 860 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 861 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 862 } 863 864 // Return true if there is exactly one attribute not representing a scalar 865 // broadcast. 
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    // A non-constant operand (null attribute) or a non-empty extent list is
    // counted as potentially non-scalar.
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // If every operand is a constant shape, check whether the constant extents
  // are statically known to be broadcastable.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp has at least two input shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  // Passing witness when every operand is a constant equal to the first one.
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

// Convenience builder taking the size as a plain integer, stored as an index
// attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }

// Suggests the SSA name "c" followed by the constant value for the result.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

// Forwards the constant value of the first operand; null (no fold) when that
// operand is not constant.
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
990 APInt quotient, remainder; 991 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 992 if (quotient.isNegative() && !remainder.isNullValue()) { 993 quotient -= 1; 994 } 995 996 Type indexTy = IndexType::get(getContext()); 997 return IntegerAttr::get(indexTy, quotient); 998 } 999 1000 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1001 MLIRContext *context, Optional<Location> location, ValueRange operands, 1002 DictionaryAttr attributes, RegionRange regions, 1003 SmallVectorImpl<Type> &inferredReturnTypes) { 1004 if (operands[0].getType().isa<SizeType>() || 1005 operands[1].getType().isa<SizeType>()) 1006 inferredReturnTypes.assign({SizeType::get(context)}); 1007 else 1008 inferredReturnTypes.assign({IndexType::get(context)}); 1009 return success(); 1010 } 1011 1012 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1013 // SizeType is compatible with IndexType. 1014 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // ShapeEqOp 1019 //===----------------------------------------------------------------------===// 1020 1021 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1022 bool allSame = true; 1023 if (!operands.empty() && !operands[0]) 1024 return {}; 1025 for (Attribute operand : operands.drop_front(1)) { 1026 if (!operand) 1027 return {}; 1028 allSame = allSame && operand == operands[0]; 1029 } 1030 return BoolAttr::get(getContext(), allSame); 1031 } 1032 1033 //===----------------------------------------------------------------------===// 1034 // IndexToSizeOp 1035 //===----------------------------------------------------------------------===// 1036 1037 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1038 // Constant values of both types, `shape.size` and `index`, are represented as 1039 // `IntegerAttr`s which makes constant folding simple. 
1040 if (Attribute arg = operands[0]) 1041 return arg; 1042 return {}; 1043 } 1044 1045 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1046 MLIRContext *context) { 1047 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // FromExtentsOp 1052 //===----------------------------------------------------------------------===// 1053 1054 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1055 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1056 return nullptr; 1057 SmallVector<int64_t, 6> extents; 1058 for (auto attr : operands) 1059 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1060 Builder builder(getContext()); 1061 return builder.getIndexTensorAttr(extents); 1062 } 1063 1064 //===----------------------------------------------------------------------===// 1065 // FunctionLibraryOp 1066 //===----------------------------------------------------------------------===// 1067 1068 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1069 StringRef name) { 1070 result.attributes.push_back(builder.getNamedAttr( 1071 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1072 } 1073 1074 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1075 auto attr = getMapping() 1076 .get(op->getName().getIdentifier()) 1077 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1078 if (!attr) 1079 return nullptr; 1080 return lookupSymbol<FuncOp>(attr); 1081 } 1082 1083 ParseResult parseFunctionLibraryOp(OpAsmParser &parser, 1084 OperationState &result) { 1085 // Parse the op name. 
1086 StringAttr nameAttr; 1087 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1088 result.attributes)) 1089 return failure(); 1090 1091 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1092 return failure(); 1093 1094 auto *bodyRegion = result.addRegion(); 1095 if (parser.parseRegion(*bodyRegion)) 1096 return failure(); 1097 1098 if (parser.parseKeyword("mapping")) 1099 return failure(); 1100 1101 DictionaryAttr mappingAttr; 1102 if (parser.parseAttribute(mappingAttr, 1103 parser.getBuilder().getType<NoneType>(), "mapping", 1104 result.attributes)) 1105 return failure(); 1106 return success(); 1107 } 1108 1109 void print(OpAsmPrinter &p, FunctionLibraryOp op) { 1110 p << ' '; 1111 p.printSymbolName(op.getName()); 1112 p.printOptionalAttrDictWithKeyword( 1113 op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); 1114 p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, 1115 /*printBlockTerminators=*/false); 1116 p << " mapping "; 1117 p.printAttributeWithoutType(op.getMappingAttr()); 1118 } 1119 1120 //===----------------------------------------------------------------------===// 1121 // GetExtentOp 1122 //===----------------------------------------------------------------------===// 1123 1124 Optional<int64_t> GetExtentOp::getConstantDim() { 1125 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1126 return constSizeOp.getValue().getLimitedValue(); 1127 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1128 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1129 return llvm::None; 1130 } 1131 1132 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1133 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1134 if (!elements) 1135 return nullptr; 1136 Optional<int64_t> dim = getConstantDim(); 1137 if (!dim.hasValue()) 1138 return nullptr; 1139 if (dim.getValue() >= elements.getNumElements()) 1140 return nullptr; 1141 
return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1142 } 1143 1144 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1145 int64_t dim) { 1146 auto loc = result.location; 1147 auto dimAttr = builder.getIndexAttr(dim); 1148 if (shape.getType().isa<ShapeType>()) { 1149 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1150 build(builder, result, builder.getType<SizeType>(), shape, dim); 1151 } else { 1152 Value dim = 1153 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1154 build(builder, result, builder.getIndexType(), shape, dim); 1155 } 1156 } 1157 1158 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1159 MLIRContext *context, Optional<Location> location, ValueRange operands, 1160 DictionaryAttr attributes, RegionRange regions, 1161 SmallVectorImpl<Type> &inferredReturnTypes) { 1162 inferredReturnTypes.assign({IndexType::get(context)}); 1163 return success(); 1164 } 1165 1166 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1167 TypeRange r) { 1168 // SizeType is compatible with IndexType. 1169 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1170 } 1171 1172 //===----------------------------------------------------------------------===// 1173 // IsBroadcastableOp 1174 //===----------------------------------------------------------------------===// 1175 1176 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1177 MLIRContext *context) { 1178 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1179 } 1180 1181 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1182 // Can always broadcast fewer than two shapes. 
1183 if (operands.size() < 2) { 1184 return BoolAttr::get(getContext(), true); 1185 } 1186 1187 return nullptr; 1188 } 1189 1190 //===----------------------------------------------------------------------===// 1191 // MeetOp 1192 //===----------------------------------------------------------------------===// 1193 1194 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1195 MLIRContext *context, Optional<Location> location, ValueRange operands, 1196 DictionaryAttr attributes, RegionRange regions, 1197 SmallVectorImpl<Type> &inferredReturnTypes) { 1198 inferredReturnTypes.assign({operands[0].getType()}); 1199 return success(); 1200 } 1201 1202 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1203 if (l.size() != 1 || r.size() != 1) 1204 return false; 1205 if (l == r) 1206 return true; 1207 1208 Type lhs = l.front(); 1209 Type rhs = r.front(); 1210 1211 if (lhs != rhs) 1212 return false; 1213 1214 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1215 return true; 1216 1217 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1218 return true; 1219 return false; 1220 } 1221 1222 //===----------------------------------------------------------------------===// 1223 // RankOp 1224 //===----------------------------------------------------------------------===// 1225 1226 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1227 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1228 if (!shape) 1229 return {}; 1230 int64_t rank = shape.getNumElements(); 1231 Builder builder(getContext()); 1232 return builder.getIndexAttr(rank); 1233 } 1234 1235 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1236 /// Constant folding fails in cases where only the rank is constant, not the 1237 /// shape itself. 1238 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 
1239 /// 1240 /// Example: 1241 /// 1242 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1243 /// %rank = shape.rank %shape 1244 /// 1245 /// becomes 1246 /// 1247 /// %rank = shape.const_size 3 1248 1249 namespace { 1250 struct RankShapeOfCanonicalizationPattern 1251 : public OpRewritePattern<shape::RankOp> { 1252 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1253 1254 LogicalResult matchAndRewrite(shape::RankOp op, 1255 PatternRewriter &rewriter) const override { 1256 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1257 if (!shapeOfOp) 1258 return failure(); 1259 auto rankedTensorType = 1260 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1261 if (!rankedTensorType) 1262 return failure(); 1263 int64_t rank = rankedTensorType.getRank(); 1264 if (op.getType().isa<IndexType>()) { 1265 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1266 rank); 1267 } else if (op.getType().isa<shape::SizeType>()) { 1268 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1269 } else { 1270 return failure(); 1271 } 1272 return success(); 1273 } 1274 }; 1275 } // namespace 1276 1277 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1278 MLIRContext *context) { 1279 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1280 } 1281 1282 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1283 MLIRContext *context, Optional<Location> location, ValueRange operands, 1284 DictionaryAttr attributes, RegionRange regions, 1285 SmallVectorImpl<Type> &inferredReturnTypes) { 1286 if (operands[0].getType().isa<ShapeType>()) 1287 inferredReturnTypes.assign({SizeType::get(context)}); 1288 else 1289 inferredReturnTypes.assign({IndexType::get(context)}); 1290 return success(); 1291 } 1292 1293 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1294 // SizeType is compatible with IndexType. 
1295 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1296 } 1297 1298 //===----------------------------------------------------------------------===// 1299 // NumElementsOp 1300 //===----------------------------------------------------------------------===// 1301 1302 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1303 1304 // Fold only when argument constant. 1305 Attribute shape = operands[0]; 1306 if (!shape) 1307 return {}; 1308 1309 APInt product(64, 1); 1310 for (auto value : shape.cast<DenseIntElementsAttr>()) 1311 product *= value; 1312 Builder builder(getContext()); 1313 return builder.getIndexAttr(product.getLimitedValue()); 1314 } 1315 1316 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1317 MLIRContext *context, Optional<Location> location, ValueRange operands, 1318 DictionaryAttr attributes, RegionRange regions, 1319 SmallVectorImpl<Type> &inferredReturnTypes) { 1320 if (operands[0].getType().isa<ShapeType>()) 1321 inferredReturnTypes.assign({SizeType::get(context)}); 1322 else 1323 inferredReturnTypes.assign({IndexType::get(context)}); 1324 return success(); 1325 } 1326 1327 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1328 TypeRange r) { 1329 // SizeType is compatible with IndexType. 1330 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1331 } 1332 1333 //===----------------------------------------------------------------------===// 1334 // MaxOp 1335 //===----------------------------------------------------------------------===// 1336 1337 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1338 // If operands are equal, just propagate one. 
1339 if (getLhs() == getRhs()) 1340 return getLhs(); 1341 return nullptr; 1342 } 1343 1344 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1345 MLIRContext *context, Optional<Location> location, ValueRange operands, 1346 DictionaryAttr attributes, RegionRange regions, 1347 SmallVectorImpl<Type> &inferredReturnTypes) { 1348 if (operands[0].getType() == operands[1].getType()) 1349 inferredReturnTypes.assign({operands[0].getType()}); 1350 else 1351 inferredReturnTypes.assign({SizeType::get(context)}); 1352 return success(); 1353 } 1354 1355 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1356 if (l.size() != 1 || r.size() != 1) 1357 return false; 1358 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1359 return true; 1360 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1361 return true; 1362 return false; 1363 } 1364 1365 //===----------------------------------------------------------------------===// 1366 // MinOp 1367 //===----------------------------------------------------------------------===// 1368 1369 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1370 // If operands are equal, just propagate one. 
1371 if (getLhs() == getRhs()) 1372 return getLhs(); 1373 return nullptr; 1374 } 1375 1376 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1377 MLIRContext *context, Optional<Location> location, ValueRange operands, 1378 DictionaryAttr attributes, RegionRange regions, 1379 SmallVectorImpl<Type> &inferredReturnTypes) { 1380 if (operands[0].getType() == operands[1].getType()) 1381 inferredReturnTypes.assign({operands[0].getType()}); 1382 else 1383 inferredReturnTypes.assign({SizeType::get(context)}); 1384 return success(); 1385 } 1386 1387 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1388 if (l.size() != 1 || r.size() != 1) 1389 return false; 1390 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1391 return true; 1392 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1393 return true; 1394 return false; 1395 } 1396 1397 //===----------------------------------------------------------------------===// 1398 // MulOp 1399 //===----------------------------------------------------------------------===// 1400 1401 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1402 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1403 if (!lhs) 1404 return nullptr; 1405 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1406 if (!rhs) 1407 return nullptr; 1408 APInt folded = lhs.getValue() * rhs.getValue(); 1409 Type indexTy = IndexType::get(getContext()); 1410 return IntegerAttr::get(indexTy, folded); 1411 } 1412 1413 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1414 MLIRContext *context, Optional<Location> location, ValueRange operands, 1415 DictionaryAttr attributes, RegionRange regions, 1416 SmallVectorImpl<Type> &inferredReturnTypes) { 1417 if (operands[0].getType().isa<SizeType>() || 1418 operands[1].getType().isa<SizeType>()) 1419 inferredReturnTypes.assign({SizeType::get(context)}); 1420 else 1421 inferredReturnTypes.assign({IndexType::get(context)}); 1422 return success(); 1423 } 1424 1425 bool 
mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1426 // SizeType is compatible with IndexType. 1427 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1428 } 1429 //===----------------------------------------------------------------------===// 1430 // ShapeOfOp 1431 //===----------------------------------------------------------------------===// 1432 1433 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1434 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1435 if (!type || !type.hasStaticShape()) 1436 return nullptr; 1437 Builder builder(getContext()); 1438 return builder.getIndexTensorAttr(type.getShape()); 1439 } 1440 1441 namespace { 1442 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1443 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1444 1445 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1446 PatternRewriter &rewriter) const override { 1447 if (!op.getArg().getType().isa<ShapedType>()) 1448 return failure(); 1449 if (op.getType().isa<ShapedType>()) 1450 return failure(); 1451 1452 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1453 op.getArg()); 1454 return success(); 1455 } 1456 }; 1457 1458 // Canonicalize 1459 // ``` 1460 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1461 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1462 // ``` 1463 // to 1464 // ``` 1465 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1466 // ``` 1467 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1468 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1469 1470 LogicalResult matchAndRewrite(tensor::CastOp op, 1471 PatternRewriter &rewriter) const override { 1472 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1473 if (!ty || ty.getRank() != 1) 1474 return failure(); 1475 1476 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1477 if (!shapeOfOp) 1478 return failure(); 1479 1480 // 
Argument type must be ranked and must not conflict. 1481 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1482 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1483 return failure(); 1484 1485 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1486 return success(); 1487 } 1488 }; 1489 } // namespace 1490 1491 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1492 MLIRContext *context) { 1493 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1494 ExtractFromShapeOfExtentTensor>(context); 1495 } 1496 1497 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1498 MLIRContext *context, Optional<Location> location, ValueRange operands, 1499 DictionaryAttr attributes, RegionRange regions, 1500 SmallVectorImpl<Type> &inferredReturnTypes) { 1501 if (operands[0].getType().isa<ValueShapeType>()) 1502 inferredReturnTypes.assign({ShapeType::get(context)}); 1503 else { 1504 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1505 int64_t rank = 1506 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1507 Type indexTy = IndexType::get(context); 1508 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1509 inferredReturnTypes.assign({extentTensorTy}); 1510 } 1511 return success(); 1512 } 1513 1514 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1515 if (l.size() != 1 || r.size() != 1) 1516 return false; 1517 if (l == r) 1518 return true; 1519 1520 Type lhs = l.front(); 1521 Type rhs = r.front(); 1522 1523 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1524 return false; 1525 1526 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1527 // Shape type is compatible with all other valid return types. 
1528 return true; 1529 1530 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1531 return true; 1532 return false; 1533 } 1534 1535 //===----------------------------------------------------------------------===// 1536 // SizeToIndexOp 1537 //===----------------------------------------------------------------------===// 1538 1539 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1540 // Constant values of both types, `shape.size` and `index`, are represented as 1541 // `IntegerAttr`s which makes constant folding simple. 1542 if (Attribute arg = operands[0]) 1543 return arg; 1544 return impl::foldCastOp(*this); 1545 } 1546 1547 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1548 MLIRContext *context) { 1549 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1550 } 1551 1552 //===----------------------------------------------------------------------===// 1553 // YieldOp 1554 //===----------------------------------------------------------------------===// 1555 1556 static LogicalResult verify(shape::YieldOp op) { 1557 auto *parentOp = op->getParentOp(); 1558 auto results = parentOp->getResults(); 1559 auto operands = op.getOperands(); 1560 1561 if (parentOp->getNumResults() != op.getNumOperands()) 1562 return op.emitOpError() << "number of operands does not match number of " 1563 "results of its parent"; 1564 for (auto e : llvm::zip(results, operands)) 1565 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1566 return op.emitOpError() 1567 << "types mismatch between yield op and its parent"; 1568 1569 return success(); 1570 } 1571 1572 //===----------------------------------------------------------------------===// 1573 // SplitAtOp 1574 //===----------------------------------------------------------------------===// 1575 1576 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1577 SmallVectorImpl<OpFoldResult> &results) { 1578 if (!operands[0] || !operands[1]) 1579 return failure(); 1580 auto 
shapeVec = llvm::to_vector<6>( 1581 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1582 auto shape = llvm::makeArrayRef(shapeVec); 1583 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1584 // Verify that the split point is in the correct range. 1585 // TODO: Constant fold to an "error". 1586 int64_t rank = shape.size(); 1587 if (!(-rank <= splitPoint && splitPoint <= rank)) 1588 return failure(); 1589 if (splitPoint < 0) 1590 splitPoint += shape.size(); 1591 Builder builder(operands[0].getContext()); 1592 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1593 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1594 return success(); 1595 } 1596 1597 //===----------------------------------------------------------------------===// 1598 // ToExtentTensorOp 1599 //===----------------------------------------------------------------------===// 1600 1601 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1602 if (!operands[0]) 1603 return impl::foldCastOp(*this); 1604 Builder builder(getContext()); 1605 auto shape = llvm::to_vector<6>( 1606 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1607 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1608 builder.getIndexType()); 1609 return DenseIntElementsAttr::get(type, shape); 1610 } 1611 1612 //===----------------------------------------------------------------------===// 1613 // ReduceOp 1614 //===----------------------------------------------------------------------===// 1615 1616 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1617 ValueRange initVals) { 1618 result.addOperands(shape); 1619 result.addOperands(initVals); 1620 1621 Region *bodyRegion = result.addRegion(); 1622 bodyRegion->push_back(new Block); 1623 Block &bodyBlock = bodyRegion->front(); 1624 bodyBlock.addArgument(builder.getIndexType()); 1625 1626 Type elementType; 1627 if (auto tensorType = 
shape.getType().dyn_cast<TensorType>()) 1628 elementType = tensorType.getElementType(); 1629 else 1630 elementType = SizeType::get(builder.getContext()); 1631 bodyBlock.addArgument(elementType); 1632 1633 for (Type initValType : initVals.getTypes()) { 1634 bodyBlock.addArgument(initValType); 1635 result.addTypes(initValType); 1636 } 1637 } 1638 1639 static LogicalResult verify(ReduceOp op) { 1640 // Verify block arg types. 1641 Block &block = op.getRegion().front(); 1642 1643 // The block takes index, extent, and aggregated values as arguments. 1644 auto blockArgsCount = op.getInitVals().size() + 2; 1645 if (block.getNumArguments() != blockArgsCount) 1646 return op.emitOpError() << "ReduceOp body is expected to have " 1647 << blockArgsCount << " arguments"; 1648 1649 // The first block argument is the index and must always be of type `index`. 1650 if (!block.getArgument(0).getType().isa<IndexType>()) 1651 return op.emitOpError( 1652 "argument 0 of ReduceOp body is expected to be of IndexType"); 1653 1654 // The second block argument is the extent and must be of type `size` or 1655 // `index`, depending on whether the reduce operation is applied to a shape or 1656 // to an extent tensor. 
1657 Type extentTy = block.getArgument(1).getType(); 1658 if (op.getShape().getType().isa<ShapeType>()) { 1659 if (!extentTy.isa<SizeType>()) 1660 return op.emitOpError("argument 1 of ReduceOp body is expected to be of " 1661 "SizeType if the ReduceOp operates on a ShapeType"); 1662 } else { 1663 if (!extentTy.isa<IndexType>()) 1664 return op.emitOpError( 1665 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1666 "ReduceOp operates on an extent tensor"); 1667 } 1668 1669 for (auto type : llvm::enumerate(op.getInitVals())) 1670 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1671 return op.emitOpError() 1672 << "type mismatch between argument " << type.index() + 2 1673 << " of ReduceOp body and initial value " << type.index(); 1674 return success(); 1675 } 1676 1677 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { 1678 // Parse operands. 1679 SmallVector<OpAsmParser::OperandType, 3> operands; 1680 Type shapeOrExtentTensorType; 1681 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1682 OpAsmParser::Delimiter::Paren) || 1683 parser.parseColonType(shapeOrExtentTensorType) || 1684 parser.parseOptionalArrowTypeList(result.types)) 1685 return failure(); 1686 1687 // Resolve operands. 1688 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1689 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1690 result.operands) || 1691 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1692 result.operands)) 1693 return failure(); 1694 1695 // Parse the body. 1696 Region *body = result.addRegion(); 1697 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1698 return failure(); 1699 1700 // Parse attributes. 
1701 if (parser.parseOptionalAttrDict(result.attributes)) 1702 return failure(); 1703 1704 return success(); 1705 } 1706 1707 static void print(OpAsmPrinter &p, ReduceOp op) { 1708 p << '(' << op.getShape() << ", " << op.getInitVals() 1709 << ") : " << op.getShape().getType(); 1710 p.printOptionalArrowTypeList(op.getResultTypes()); 1711 p.printRegion(op.getRegion()); 1712 p.printOptionalAttrDict(op->getAttrs()); 1713 } 1714 1715 #define GET_OP_CLASSES 1716 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1717