1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/IR/TypeUtilities.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/SetOperations.h" 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace mlir; 31 using namespace mlir::shape; 32 33 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 34 35 namespace { 36 #include "ShapeCanonicalization.inc" 37 } // namespace 38 39 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 40 return RankedTensorType::get({rank}, IndexType::get(ctx)); 41 } 42 43 bool shape::isExtentTensorType(Type type) { 44 auto ranked = type.dyn_cast<RankedTensorType>(); 45 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 46 } 47 48 LogicalResult shape::getShapeVec(Value input, 49 SmallVectorImpl<int64_t> &shapeValues) { 50 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 51 auto type = inputOp.getArg().getType().cast<ShapedType>(); 52 if (!type.hasRank()) 53 return failure(); 54 llvm::append_range(shapeValues, type.getShape()); 55 return success(); 56 } 57 
DenseIntElementsAttr attr; 58 if (matchPattern(input, m_Constant(&attr))) { 59 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 60 return success(); 61 } 62 return failure(); 63 } 64 65 static bool isErrorPropagationPossible(TypeRange operandTypes) { 66 return llvm::any_of(operandTypes, [](Type ty) { 67 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 68 }); 69 } 70 71 static LogicalResult verifySizeOrIndexOp(Operation *op) { 72 assert(op != nullptr && op->getNumResults() == 1); 73 Type resultTy = op->getResultTypes().front(); 74 if (isErrorPropagationPossible(op->getOperandTypes())) { 75 if (!resultTy.isa<SizeType>()) 76 return op->emitOpError() 77 << "if at least one of the operands can hold error values then " 78 "the result must be of type `size` to propagate them"; 79 } 80 return success(); 81 } 82 83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 84 assert(op != nullptr && op->getNumResults() == 1); 85 Type resultTy = op->getResultTypes().front(); 86 if (isErrorPropagationPossible(op->getOperandTypes())) { 87 if (!resultTy.isa<ShapeType>()) 88 return op->emitOpError() 89 << "if at least one of the operands can hold error values then " 90 "the result must be of type `shape` to propagate them"; 91 } 92 return success(); 93 } 94 95 template <typename... Ty> 96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 97 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 98 } 99 100 template <typename... Ty, typename... ranges> 101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // InlinerInterface 107 //===----------------------------------------------------------------------===// 108 109 namespace { 110 /// This class defines the interface for inlining shape dialect ops. 
111 struct ShapeInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 // Returns true if the given region 'src' can be inlined into the region 115 // 'dest' that is attached to an operation registered to the current dialect. 116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 // Returns true if the given operation 'op', that is registered to this 122 // dialect, can be inlined into the region 'dest' that is attached to an 123 // operation registered to the current dialect. 124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 125 BlockAndValueMapping &) const final { 126 return true; 127 } 128 }; 129 } // namespace 130 131 void ShapeDialect::initialize() { 132 addOperations< 133 #define GET_OP_LIST 134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 135 >(); 136 addTypes< 137 #define GET_TYPEDEF_LIST 138 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc" 139 >(); 140 addInterfaces<ShapeInlinerInterface>(); 141 // Allow unknown operations during prototyping and testing. As the dialect is 142 // still evolving it makes it simple to start with an unregistered ops and 143 // try different variants before actually defining the op. 
144 allowUnknownOperations(); 145 } 146 147 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 148 Attribute value, Type type, 149 Location loc) { 150 if (type.isa<ShapeType>() || isExtentTensorType(type)) 151 return builder.create<ConstShapeOp>(loc, type, 152 value.cast<DenseIntElementsAttr>()); 153 if (type.isa<SizeType>()) 154 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 155 if (type.isa<WitnessType>()) 156 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 157 if (arith::ConstantOp::isBuildableWith(value, type)) 158 return builder.create<arith::ConstantOp>(loc, type, value); 159 return nullptr; 160 } 161 162 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 163 NamedAttribute attribute) { 164 // Verify shape.lib attribute. 165 if (attribute.getName() == "shape.lib") { 166 if (!op->hasTrait<OpTrait::SymbolTable>()) 167 return op->emitError( 168 "shape.lib attribute may only be on op implementing SymbolTable"); 169 170 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 171 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 172 if (!symbol) 173 return op->emitError("shape function library ") 174 << symbolRef << " not found"; 175 return isa<shape::FunctionLibraryOp>(symbol) 176 ? success() 177 : op->emitError() 178 << symbolRef << " required to be shape function library"; 179 } 180 181 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 182 // Verify all entries are function libraries and mappings in libraries 183 // refer to unique ops. 
184 DenseSet<StringAttr> key; 185 for (auto it : arr) { 186 if (!it.isa<SymbolRefAttr>()) 187 return op->emitError( 188 "only SymbolRefAttr allowed in shape.lib attribute array"); 189 190 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 191 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 192 if (!shapeFnLib) 193 return op->emitError() 194 << it << " does not refer to FunctionLibraryOp"; 195 for (auto mapping : shapeFnLib.getMapping()) { 196 if (!key.insert(mapping.getName()).second) { 197 return op->emitError("only one op to shape mapping allowed, found " 198 "multiple for `") 199 << mapping.getName() << "`"; 200 } 201 } 202 } 203 return success(); 204 } 205 206 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 207 "allowed as shape.lib attribute"); 208 } 209 return success(); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // AnyOp 214 //===----------------------------------------------------------------------===// 215 216 // TODO: Canonicalization should be implemented for shapes that can be 217 // determined through mixtures of the known dimensions of the inputs. 218 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 219 // Only the last operand is checked because AnyOp is commutative. 
220 if (operands.back()) 221 return operands.back(); 222 223 return nullptr; 224 } 225 226 //===----------------------------------------------------------------------===// 227 // AssumingOp 228 //===----------------------------------------------------------------------===// 229 230 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 231 result.regions.reserve(1); 232 Region *doRegion = result.addRegion(); 233 234 auto &builder = parser.getBuilder(); 235 OpAsmParser::UnresolvedOperand cond; 236 if (parser.parseOperand(cond) || 237 parser.resolveOperand(cond, builder.getType<WitnessType>(), 238 result.operands)) 239 return failure(); 240 241 // Parse optional results type list. 242 if (parser.parseOptionalArrowTypeList(result.types)) 243 return failure(); 244 245 // Parse the region and add a terminator if elided. 246 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 247 return failure(); 248 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 249 250 // Parse the optional attribute list. 251 if (parser.parseOptionalAttrDict(result.attributes)) 252 return failure(); 253 return success(); 254 } 255 256 void AssumingOp::print(OpAsmPrinter &p) { 257 bool yieldsResults = !getResults().empty(); 258 259 p << " " << getWitness(); 260 if (yieldsResults) 261 p << " -> (" << getResultTypes() << ")"; 262 p << ' '; 263 p.printRegion(getDoRegion(), 264 /*printEntryBlockArgs=*/false, 265 /*printBlockTerminators=*/yieldsResults); 266 p.printOptionalAttrDict((*this)->getAttrs()); 267 } 268 269 namespace { 270 // Removes AssumingOp with a passing witness and inlines the region. 
271 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 272 using OpRewritePattern<AssumingOp>::OpRewritePattern; 273 274 LogicalResult matchAndRewrite(AssumingOp op, 275 PatternRewriter &rewriter) const override { 276 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 277 if (!witness || !witness.getPassingAttr()) 278 return failure(); 279 280 AssumingOp::inlineRegionIntoParent(op, rewriter); 281 return success(); 282 } 283 }; 284 285 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 286 using OpRewritePattern<AssumingOp>::OpRewritePattern; 287 288 LogicalResult matchAndRewrite(AssumingOp op, 289 PatternRewriter &rewriter) const override { 290 Block *body = op.getBody(); 291 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 292 293 // Find used values. 294 SmallVector<Value, 4> newYieldOperands; 295 Value opResult, yieldOperand; 296 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 297 std::tie(opResult, yieldOperand) = it; 298 if (!opResult.getUses().empty()) { 299 newYieldOperands.push_back(yieldOperand); 300 } 301 } 302 303 // Rewrite only if redundant results exist. 304 if (newYieldOperands.size() == yieldOp->getNumOperands()) 305 return failure(); 306 307 // Replace yield op in the old assuming op's body and move the entire region 308 // to the new assuming op. 309 rewriter.setInsertionPointToEnd(body); 310 auto newYieldOp = 311 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 312 rewriter.setInsertionPoint(op); 313 auto newOp = rewriter.create<AssumingOp>( 314 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 315 newOp.getDoRegion().takeBody(op.getDoRegion()); 316 317 // Use the new results to replace the previously used ones. 
318 SmallVector<Value, 4> replacementValues; 319 auto src = newOp.getResults().begin(); 320 for (auto it : op.getResults()) { 321 if (it.getUses().empty()) 322 replacementValues.push_back(nullptr); 323 else 324 replacementValues.push_back(*src++); 325 } 326 rewriter.replaceOp(op, replacementValues); 327 return success(); 328 } 329 }; 330 } // namespace 331 332 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 333 MLIRContext *context) { 334 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 335 } 336 337 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 338 void AssumingOp::getSuccessorRegions( 339 Optional<unsigned> index, ArrayRef<Attribute> operands, 340 SmallVectorImpl<RegionSuccessor> ®ions) { 341 // AssumingOp has unconditional control flow into the region and back to the 342 // parent, so return the correct RegionSuccessor purely based on the index 343 // being None or 0. 344 if (index) { 345 regions.push_back(RegionSuccessor(getResults())); 346 return; 347 } 348 349 regions.push_back(RegionSuccessor(&getDoRegion())); 350 } 351 352 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 353 PatternRewriter &rewriter) { 354 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 355 auto *assumingBlock = op.getBody(); 356 auto initPosition = rewriter.getInsertionPoint(); 357 auto *blockAfterAssuming = 358 rewriter.splitBlock(blockBeforeAssuming, initPosition); 359 360 // Remove the AssumingOp and AssumingYieldOp. 361 auto &yieldOp = assumingBlock->back(); 362 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 363 rewriter.replaceOp(op, yieldOp.getOperands()); 364 rewriter.eraseOp(&yieldOp); 365 366 // Merge blocks together as there was no branching behavior from the 367 // AssumingOp. 
368 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 369 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 370 } 371 372 void AssumingOp::build( 373 OpBuilder &builder, OperationState &result, Value witness, 374 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 375 376 result.addOperands(witness); 377 Region *bodyRegion = result.addRegion(); 378 bodyRegion->push_back(new Block); 379 Block &bodyBlock = bodyRegion->front(); 380 381 // Build body. 382 OpBuilder::InsertionGuard guard(builder); 383 builder.setInsertionPointToStart(&bodyBlock); 384 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 385 builder.create<AssumingYieldOp>(result.location, yieldValues); 386 387 SmallVector<Type, 2> assumingTypes; 388 for (Value v : yieldValues) 389 assumingTypes.push_back(v.getType()); 390 result.addTypes(assumingTypes); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // AddOp 395 //===----------------------------------------------------------------------===// 396 397 LogicalResult mlir::shape::AddOp::inferReturnTypes( 398 MLIRContext *context, Optional<Location> location, ValueRange operands, 399 DictionaryAttr attributes, RegionRange regions, 400 SmallVectorImpl<Type> &inferredReturnTypes) { 401 if (operands[0].getType().isa<SizeType>() || 402 operands[1].getType().isa<SizeType>()) 403 inferredReturnTypes.assign({SizeType::get(context)}); 404 else 405 inferredReturnTypes.assign({IndexType::get(context)}); 406 return success(); 407 } 408 409 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 410 // SizeType is compatible with IndexType. 
411 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 412 } 413 414 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 415 // add(x, 0) -> x 416 if (matchPattern(getRhs(), m_Zero())) 417 return getLhs(); 418 419 return constFoldBinaryOp<IntegerAttr>( 420 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 421 } 422 423 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 424 425 //===----------------------------------------------------------------------===// 426 // AssumingAllOp 427 //===----------------------------------------------------------------------===// 428 429 namespace { 430 431 // Merge multiple `shape.assuming_all` operations together. 432 // 433 // %0 = shape.assuming_all %w0, %w1 434 // %1 = shape.assuming_all %w2, %0 435 // 436 // to: 437 // 438 // %0 = shape.assuming_all %w0, %w2, %w2 439 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 440 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 441 442 LogicalResult matchAndRewrite(AssumingAllOp op, 443 PatternRewriter &rewriter) const override { 444 SmallVector<Value> operands; 445 446 for (Value operand : op.getInputs()) { 447 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>()) 448 operands.append(assumeAll.operand_begin(), assumeAll->operand_end()); 449 else 450 operands.push_back(operand); 451 } 452 453 // We didn't find any other `assuming_all` ops to merge with. 454 if (operands.size() == op.getNumOperands()) 455 return failure(); 456 457 // Replace with a new `assuming_all` operation with merged constraints. 458 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 459 return success(); 460 } 461 }; 462 463 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 464 // are subsumed by others. 
465 // 466 // %0 = shape.cstr_broadcastable %shape0, %shape1 467 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 468 // 469 // %2 = shape.cstr_broadcastable %shape3, %shape4 470 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 471 // 472 // %4 = shape.assuming_all %0, %1, %2, %3 473 // 474 // to: 475 // 476 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 477 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 478 // %2 = shape.assuming_all %0, %1 479 // 480 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 481 // shapes [0, 1] are broadcastable too, and can be removed from the list of 482 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 483 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 484 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 485 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 486 487 LogicalResult matchAndRewrite(AssumingAllOp op, 488 PatternRewriter &rewriter) const override { 489 // Collect all `CstrBroadcastableOp` operands first. 490 SetVector<CstrBroadcastableOp> operands; 491 for (Value operand : op.getInputs()) { 492 // TODO: Apply this optimization if some of the witnesses are not 493 // produced by the `cstr_broadcastable`. 494 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 495 if (!broadcastable) 496 return failure(); 497 498 operands.insert(broadcastable); 499 } 500 501 // Skip trivial `assuming_all` operations. 502 if (operands.size() <= 1) 503 return failure(); 504 505 // Collect shapes checked by `cstr_broadcastable` operands. 506 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 507 for (auto cstr : operands) { 508 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end()); 509 shapes.emplace_back(cstr, std::move(shapesSet)); 510 } 511 512 // Sort by the number of shape operands (larger to smaller). 
513 llvm::sort(shapes, [](auto a, auto b) { 514 return a.first.getNumOperands() > b.first.getNumOperands(); 515 }); 516 517 // We start from the `cst_broadcastable` operations with largest number of 518 // shape operands, and remove redundant `cst_broadcastable` operations. We 519 // do this until we find a set of `cst_broadcastable` operations with 520 // non-overlapping constraints. 521 SmallVector<CstrBroadcastableOp> markedForErase; 522 523 for (unsigned i = 0; i < shapes.size(); ++i) { 524 auto isSubset = [&](auto pair) { 525 return llvm::set_is_subset(pair.second, shapes[i].second); 526 }; 527 528 // Keep redundant `cstr_broadcastable` operations to be erased. 529 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 530 for (auto *it0 = it; it0 < shapes.end(); ++it0) 531 markedForErase.push_back(it0->first); 532 shapes.erase(it, shapes.end()); 533 } 534 535 // We didn't find any operands that could be removed. 536 if (markedForErase.empty()) 537 return failure(); 538 539 // Collect non-overlapping `cst_broadcastable` constraints. 540 SmallVector<Value> uniqueConstraints; 541 for (auto &shape : shapes) 542 uniqueConstraints.push_back(shape.first.getResult()); 543 544 // Replace with a new `assuming_all` operation ... 545 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints); 546 547 // ... and maybe erase `cstr_broadcastable` ops without uses. 
548 for (auto &op : markedForErase) 549 if (op->use_empty()) 550 rewriter.eraseOp(op); 551 552 return success(); 553 } 554 }; 555 556 struct AssumingAllToCstrEqCanonicalization 557 : public OpRewritePattern<AssumingAllOp> { 558 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 559 560 LogicalResult matchAndRewrite(AssumingAllOp op, 561 PatternRewriter &rewriter) const override { 562 SmallVector<Value, 8> shapes; 563 for (Value w : op.getInputs()) { 564 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 565 if (!cstrEqOp) 566 return failure(); 567 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 568 return llvm::is_contained(shapes, s); 569 }); 570 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 571 return failure(); 572 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 573 } 574 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 575 return success(); 576 } 577 }; 578 579 template <typename OpTy> 580 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 581 using OpRewritePattern<OpTy>::OpRewritePattern; 582 583 LogicalResult matchAndRewrite(OpTy op, 584 PatternRewriter &rewriter) const override { 585 // Find unique operands. 586 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 587 588 // Reduce op to equivalent with unique operands. 
589 if (unique.size() < op.getNumOperands()) { 590 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 591 unique.takeVector(), op->getAttrs()); 592 return success(); 593 } 594 595 return failure(); 596 } 597 }; 598 } // namespace 599 600 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 601 MLIRContext *context) { 602 patterns 603 .add<MergeAssumingAllOps, AssumingAllOneOp, 604 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 605 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 606 } 607 608 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 609 // Iterate in reverse to first handle all constant operands. They are 610 // guaranteed to be the tail of the inputs because this is commutative. 611 for (int idx = operands.size() - 1; idx >= 0; idx--) { 612 Attribute a = operands[idx]; 613 // Cannot fold if any inputs are not constant; 614 if (!a) 615 return nullptr; 616 617 // We do not need to keep statically known values after handling them in 618 // this method. 619 getOperation()->eraseOperand(idx); 620 621 // Always false if any input is statically known false 622 if (!a.cast<BoolAttr>().getValue()) 623 return a; 624 } 625 // If this is reached, all inputs were statically known passing. 626 return BoolAttr::get(getContext(), true); 627 } 628 629 LogicalResult AssumingAllOp::verify() { 630 // Ensure that AssumingAllOp contains at least one operand 631 if (getNumOperands() == 0) 632 return emitOpError("no operands specified"); 633 634 return success(); 635 } 636 637 //===----------------------------------------------------------------------===// 638 // BroadcastOp 639 //===----------------------------------------------------------------------===// 640 641 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 642 if (getShapes().size() == 1) { 643 // Otherwise, we need a cast which would be a canonicalization, not folding. 
644 if (getShapes().front().getType() != getType()) 645 return nullptr; 646 return getShapes().front(); 647 } 648 649 // TODO: Support folding with more than 2 input shapes 650 if (getShapes().size() > 2) 651 return nullptr; 652 653 if (!operands[0] || !operands[1]) 654 return nullptr; 655 auto lhsShape = llvm::to_vector<6>( 656 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 657 auto rhsShape = llvm::to_vector<6>( 658 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 659 SmallVector<int64_t, 6> resultShape; 660 661 // If the shapes are not compatible, we can't fold it. 662 // TODO: Fold to an "error". 663 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 664 return nullptr; 665 666 Builder builder(getContext()); 667 return builder.getIndexTensorAttr(resultShape); 668 } 669 670 LogicalResult BroadcastOp::verify() { 671 return verifyShapeOrExtentTensorOp(*this); 672 } 673 674 namespace { 675 template <typename OpTy> 676 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 677 using OpRewritePattern<OpTy>::OpRewritePattern; 678 679 LogicalResult matchAndRewrite(OpTy op, 680 PatternRewriter &rewriter) const override { 681 auto isPotentiallyNonEmptyShape = [](Value shape) { 682 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 683 if (extentTensorTy.getDimSize(0) == 0) 684 return false; 685 } 686 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 687 if (constShape.getShape().empty()) 688 return false; 689 } 690 return true; 691 }; 692 auto newOperands = llvm::to_vector<8>( 693 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 694 695 // Reduce op to equivalent without empty shape operands. 
696 if (newOperands.size() < op.getNumOperands()) { 697 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 698 op->getAttrs()); 699 return success(); 700 } 701 702 return failure(); 703 } 704 }; 705 706 struct BroadcastForwardSingleOperandPattern 707 : public OpRewritePattern<BroadcastOp> { 708 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 709 710 LogicalResult matchAndRewrite(BroadcastOp op, 711 PatternRewriter &rewriter) const override { 712 if (op.getNumOperands() != 1) 713 return failure(); 714 Value replacement = op.getShapes().front(); 715 716 // Insert cast if needed. 717 if (replacement.getType() != op.getType()) { 718 auto loc = op.getLoc(); 719 if (op.getType().isa<ShapeType>()) { 720 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 721 } else { 722 assert(!op.getType().isa<ShapeType>() && 723 !replacement.getType().isa<ShapeType>() && 724 "expect extent tensor cast"); 725 replacement = 726 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 727 } 728 } 729 730 rewriter.replaceOp(op, replacement); 731 return success(); 732 } 733 }; 734 735 struct BroadcastFoldConstantOperandsPattern 736 : public OpRewritePattern<BroadcastOp> { 737 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 738 739 LogicalResult matchAndRewrite(BroadcastOp op, 740 PatternRewriter &rewriter) const override { 741 SmallVector<int64_t, 8> foldedConstantShape; 742 SmallVector<Value, 8> newShapeOperands; 743 for (Value shape : op.getShapes()) { 744 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 745 SmallVector<int64_t, 8> newFoldedConstantShape; 746 if (OpTrait::util::getBroadcastedShape( 747 foldedConstantShape, 748 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 749 newFoldedConstantShape)) { 750 foldedConstantShape = newFoldedConstantShape; 751 continue; 752 } 753 } 754 newShapeOperands.push_back(shape); 755 } 756 757 // Need at least two constant operands to fold anything. 
758 if (op.getNumOperands() - newShapeOperands.size() < 2) 759 return failure(); 760 761 auto foldedConstantOperandsTy = RankedTensorType::get( 762 {static_cast<int64_t>(foldedConstantShape.size())}, 763 rewriter.getIndexType()); 764 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 765 op.getLoc(), foldedConstantOperandsTy, 766 rewriter.getIndexTensorAttr(foldedConstantShape))); 767 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 768 newShapeOperands); 769 return success(); 770 } 771 }; 772 773 template <typename OpTy> 774 struct CanonicalizeCastExtentTensorOperandsPattern 775 : public OpRewritePattern<OpTy> { 776 using OpRewritePattern<OpTy>::OpRewritePattern; 777 778 LogicalResult matchAndRewrite(OpTy op, 779 PatternRewriter &rewriter) const override { 780 // Canonicalize operands. 781 bool anyChange = false; 782 auto canonicalizeOperand = [&](Value operand) { 783 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 784 // Only eliminate the cast if it holds no shape information. 785 bool isInformationLoosingCast = 786 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 787 if (isInformationLoosingCast) { 788 anyChange = true; 789 return castOp.getSource(); 790 } 791 } 792 return operand; 793 }; 794 auto newOperands = llvm::to_vector<8>( 795 llvm::map_range(op.getOperands(), canonicalizeOperand)); 796 797 // Rewrite op if any change required. 798 if (!anyChange) 799 return failure(); 800 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 801 return success(); 802 } 803 }; 804 805 struct BroadcastConcretizeResultTypePattern 806 : public OpRewritePattern<BroadcastOp> { 807 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 808 809 LogicalResult matchAndRewrite(BroadcastOp op, 810 PatternRewriter &rewriter) const override { 811 // Only concretize dynamic extent tensor result types. 
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

/// Registers the canonicalization patterns for `shape.broadcast`.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

/// Folds `shape.concat` when both operand attributes are available: the result
/// is the lhs extents followed by the rhs extents, as an index tensor
/// attribute.
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  // Bail out unless both operand attributes are non-null.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

/// Prints the op as ` {attrs} [e0, e1, ...] : type`, eliding the `shape`
/// attribute from the attribute dictionary since it is printed inline.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

/// Parses the form produced by `print`: an optional attribute dictionary, a
/// bracketed list of integer extents, then `:` and the result type. The
/// extents are stored as an index tensor attribute named `shape`.
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every element of the parsed array must be an IntegerAttr.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

/// Always folds to the op's `shape` attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }

/// Registers the canonicalization patterns for `shape.const_shape`.
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

/// Infers `tensor<Nxindex>`, where N is the number of elements in the `shape`
/// attribute. Emits an error (if a location is given) when the attribute is
/// missing.
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

/// Requires exactly one type on each side. `!shape.shape` on either side is
/// accepted; otherwise the two types must be identical.
bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

/// Registers the canonicalization patterns for `shape.cstr_broadcastable`.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if at most one attribute is not known to represent a scalar
// (empty, rank-0) shape. A null attribute counts as potentially non-scalar.
958 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 959 bool nonScalarSeen = false; 960 for (Attribute a : attributes) { 961 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 962 if (nonScalarSeen) 963 return false; 964 nonScalarSeen = true; 965 } 966 } 967 return true; 968 } 969 970 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 971 // No broadcasting is needed if all operands but one are scalar. 972 if (hasAtMostSingleNonScalar(operands)) 973 return BoolAttr::get(getContext(), true); 974 975 if ([&] { 976 SmallVector<SmallVector<int64_t, 6>, 6> extents; 977 for (const auto &operand : operands) { 978 if (!operand) 979 return false; 980 extents.push_back(llvm::to_vector<6>( 981 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 982 } 983 return OpTrait::util::staticallyKnownBroadcastable(extents); 984 }()) 985 return BoolAttr::get(getContext(), true); 986 987 // Lastly, see if folding can be completed based on what constraints are known 988 // on the input shapes. 989 if ([&] { 990 SmallVector<SmallVector<int64_t, 6>, 6> extents; 991 for (auto shapeValue : getShapes()) { 992 extents.emplace_back(); 993 if (failed(getShapeVec(shapeValue, extents.back()))) 994 return false; 995 } 996 return OpTrait::util::staticallyKnownBroadcastable(extents); 997 }()) 998 return BoolAttr::get(getContext(), true); 999 1000 // Because a failing witness result here represents an eventual assertion 1001 // failure, we do not replace it with a constant witness. 
1002 return nullptr; 1003 } 1004 1005 LogicalResult CstrBroadcastableOp::verify() { 1006 // Ensure that CstrBroadcastableOp contains at least two operands 1007 if (getNumOperands() < 2) 1008 return emitOpError("required at least 2 input shapes"); 1009 return success(); 1010 } 1011 1012 //===----------------------------------------------------------------------===// 1013 // CstrEqOp 1014 //===----------------------------------------------------------------------===// 1015 1016 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1017 MLIRContext *context) { 1018 // If inputs are equal, return passing witness 1019 patterns.add<CstrEqEqOps>(context); 1020 } 1021 1022 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1023 if (llvm::all_of(operands, 1024 [&](Attribute a) { return a && a == operands[0]; })) 1025 return BoolAttr::get(getContext(), true); 1026 1027 // Because a failing witness result here represents an eventual assertion 1028 // failure, we do not try to replace it with a constant witness. Similarly, we 1029 // cannot if there are any non-const inputs. 
1030 return nullptr; 1031 } 1032 1033 //===----------------------------------------------------------------------===// 1034 // ConstSizeOp 1035 //===----------------------------------------------------------------------===// 1036 1037 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1038 int64_t value) { 1039 build(builder, result, builder.getIndexAttr(value)); 1040 } 1041 1042 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1043 1044 void ConstSizeOp::getAsmResultNames( 1045 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1046 SmallString<4> buffer; 1047 llvm::raw_svector_ostream os(buffer); 1048 os << "c" << getValue(); 1049 setNameFn(getResult(), os.str()); 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // ConstWitnessOp 1054 //===----------------------------------------------------------------------===// 1055 1056 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1057 return getPassingAttr(); 1058 } 1059 1060 //===----------------------------------------------------------------------===// 1061 // CstrRequireOp 1062 //===----------------------------------------------------------------------===// 1063 1064 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1065 return operands[0]; 1066 } 1067 1068 //===----------------------------------------------------------------------===// 1069 // DivOp 1070 //===----------------------------------------------------------------------===// 1071 1072 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1073 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1074 if (!lhs) 1075 return nullptr; 1076 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1077 if (!rhs) 1078 return nullptr; 1079 1080 // Division in APInt does not follow floor(lhs, rhs) when the result is 1081 // negative. Rather, APInt rounds toward zero. 
1082 APInt quotient, remainder; 1083 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1084 if (quotient.isNegative() && !remainder.isNullValue()) { 1085 quotient -= 1; 1086 } 1087 1088 Type indexTy = IndexType::get(getContext()); 1089 return IntegerAttr::get(indexTy, quotient); 1090 } 1091 1092 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1093 MLIRContext *context, Optional<Location> location, ValueRange operands, 1094 DictionaryAttr attributes, RegionRange regions, 1095 SmallVectorImpl<Type> &inferredReturnTypes) { 1096 if (operands[0].getType().isa<SizeType>() || 1097 operands[1].getType().isa<SizeType>()) 1098 inferredReturnTypes.assign({SizeType::get(context)}); 1099 else 1100 inferredReturnTypes.assign({IndexType::get(context)}); 1101 return success(); 1102 } 1103 1104 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1105 // SizeType is compatible with IndexType. 1106 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1107 } 1108 1109 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1110 1111 //===----------------------------------------------------------------------===// 1112 // ShapeEqOp 1113 //===----------------------------------------------------------------------===// 1114 1115 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1116 bool allSame = true; 1117 if (!operands.empty() && !operands[0]) 1118 return {}; 1119 for (Attribute operand : operands.drop_front(1)) { 1120 if (!operand) 1121 return {}; 1122 allSame = allSame && operand == operands[0]; 1123 } 1124 return BoolAttr::get(getContext(), allSame); 1125 } 1126 1127 //===----------------------------------------------------------------------===// 1128 // IndexToSizeOp 1129 //===----------------------------------------------------------------------===// 1130 1131 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1132 // Constant values of both types, `shape.size` and `index`, are represented as 
1133 // `IntegerAttr`s which makes constant folding simple. 1134 if (Attribute arg = operands[0]) 1135 return arg; 1136 return {}; 1137 } 1138 1139 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1140 MLIRContext *context) { 1141 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1142 } 1143 1144 //===----------------------------------------------------------------------===// 1145 // FromExtentsOp 1146 //===----------------------------------------------------------------------===// 1147 1148 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1149 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1150 return nullptr; 1151 SmallVector<int64_t, 6> extents; 1152 for (auto attr : operands) 1153 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1154 Builder builder(getContext()); 1155 return builder.getIndexTensorAttr(extents); 1156 } 1157 1158 //===----------------------------------------------------------------------===// 1159 // FunctionLibraryOp 1160 //===----------------------------------------------------------------------===// 1161 1162 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1163 StringRef name) { 1164 result.attributes.push_back(builder.getNamedAttr( 1165 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1166 } 1167 1168 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1169 auto attr = getMapping() 1170 .get(op->getName().getIdentifier()) 1171 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1172 if (!attr) 1173 return nullptr; 1174 return lookupSymbol<FuncOp>(attr); 1175 } 1176 1177 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1178 OperationState &result) { 1179 // Parse the op name. 
1180 StringAttr nameAttr; 1181 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1182 result.attributes)) 1183 return failure(); 1184 1185 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1186 return failure(); 1187 1188 auto *bodyRegion = result.addRegion(); 1189 if (parser.parseRegion(*bodyRegion)) 1190 return failure(); 1191 1192 if (parser.parseKeyword("mapping")) 1193 return failure(); 1194 1195 DictionaryAttr mappingAttr; 1196 if (parser.parseAttribute(mappingAttr, 1197 parser.getBuilder().getType<NoneType>(), "mapping", 1198 result.attributes)) 1199 return failure(); 1200 return success(); 1201 } 1202 1203 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1204 p << ' '; 1205 p.printSymbolName(getName()); 1206 p.printOptionalAttrDictWithKeyword( 1207 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1208 p << ' '; 1209 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1210 /*printBlockTerminators=*/false); 1211 p << " mapping "; 1212 p.printAttributeWithoutType(getMappingAttr()); 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // FuncOp 1217 //===----------------------------------------------------------------------===// 1218 1219 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 1220 auto buildFuncType = 1221 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 1222 function_interface_impl::VariadicFlag, 1223 std::string &) { return builder.getFunctionType(argTypes, results); }; 1224 1225 return function_interface_impl::parseFunctionOp( 1226 parser, result, /*allowVariadic=*/false, buildFuncType); 1227 } 1228 1229 void FuncOp::print(OpAsmPrinter &p) { 1230 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 1231 } 1232 1233 //===----------------------------------------------------------------------===// 1234 // GetExtentOp 1235 
//===----------------------------------------------------------------------===// 1236 1237 Optional<int64_t> GetExtentOp::getConstantDim() { 1238 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1239 return constSizeOp.getValue().getLimitedValue(); 1240 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1241 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1242 return llvm::None; 1243 } 1244 1245 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1246 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1247 if (!elements) 1248 return nullptr; 1249 Optional<int64_t> dim = getConstantDim(); 1250 if (!dim.has_value()) 1251 return nullptr; 1252 if (dim.value() >= elements.getNumElements()) 1253 return nullptr; 1254 return elements.getValues<Attribute>()[(uint64_t)dim.value()]; 1255 } 1256 1257 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1258 int64_t dim) { 1259 auto loc = result.location; 1260 auto dimAttr = builder.getIndexAttr(dim); 1261 if (shape.getType().isa<ShapeType>()) { 1262 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1263 build(builder, result, builder.getType<SizeType>(), shape, dim); 1264 } else { 1265 Value dim = 1266 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1267 build(builder, result, builder.getIndexType(), shape, dim); 1268 } 1269 } 1270 1271 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1272 MLIRContext *context, Optional<Location> location, ValueRange operands, 1273 DictionaryAttr attributes, RegionRange regions, 1274 SmallVectorImpl<Type> &inferredReturnTypes) { 1275 inferredReturnTypes.assign({IndexType::get(context)}); 1276 return success(); 1277 } 1278 1279 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1280 TypeRange r) { 1281 // SizeType is compatible with IndexType. 
1282 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1283 } 1284 1285 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1286 1287 //===----------------------------------------------------------------------===// 1288 // IsBroadcastableOp 1289 //===----------------------------------------------------------------------===// 1290 1291 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1292 MLIRContext *context) { 1293 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1294 } 1295 1296 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1297 // Can always broadcast fewer than two shapes. 1298 if (operands.size() < 2) { 1299 return BoolAttr::get(getContext(), true); 1300 } 1301 1302 return nullptr; 1303 } 1304 1305 //===----------------------------------------------------------------------===// 1306 // MeetOp 1307 //===----------------------------------------------------------------------===// 1308 1309 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1310 MLIRContext *context, Optional<Location> location, ValueRange operands, 1311 DictionaryAttr attributes, RegionRange regions, 1312 SmallVectorImpl<Type> &inferredReturnTypes) { 1313 inferredReturnTypes.assign({operands[0].getType()}); 1314 return success(); 1315 } 1316 1317 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1318 if (l.size() != 1 || r.size() != 1) 1319 return false; 1320 if (l == r) 1321 return true; 1322 1323 Type lhs = l.front(); 1324 Type rhs = r.front(); 1325 1326 if (lhs != rhs) 1327 return false; 1328 1329 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1330 return true; 1331 1332 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1333 return true; 1334 return false; 1335 } 1336 1337 //===----------------------------------------------------------------------===// 1338 // RankOp 1339 //===----------------------------------------------------------------------===// 
1340 1341 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1342 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1343 if (!shape) 1344 return {}; 1345 int64_t rank = shape.getNumElements(); 1346 Builder builder(getContext()); 1347 return builder.getIndexAttr(rank); 1348 } 1349 1350 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1351 /// Constant folding fails in cases where only the rank is constant, not the 1352 /// shape itself. 1353 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1354 /// 1355 /// Example: 1356 /// 1357 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1358 /// %rank = shape.rank %shape 1359 /// 1360 /// becomes 1361 /// 1362 /// %rank = shape.const_size 3 1363 1364 namespace { 1365 struct RankShapeOfCanonicalizationPattern 1366 : public OpRewritePattern<shape::RankOp> { 1367 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1368 1369 LogicalResult matchAndRewrite(shape::RankOp op, 1370 PatternRewriter &rewriter) const override { 1371 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1372 if (!shapeOfOp) 1373 return failure(); 1374 auto rankedTensorType = 1375 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1376 if (!rankedTensorType) 1377 return failure(); 1378 int64_t rank = rankedTensorType.getRank(); 1379 if (op.getType().isa<IndexType>()) { 1380 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1381 rank); 1382 } else if (op.getType().isa<shape::SizeType>()) { 1383 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1384 } else { 1385 return failure(); 1386 } 1387 return success(); 1388 } 1389 }; 1390 } // namespace 1391 1392 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1393 MLIRContext *context) { 1394 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1395 } 1396 1397 LogicalResult mlir::shape::RankOp::inferReturnTypes( 
1398 MLIRContext *context, Optional<Location> location, ValueRange operands, 1399 DictionaryAttr attributes, RegionRange regions, 1400 SmallVectorImpl<Type> &inferredReturnTypes) { 1401 if (operands[0].getType().isa<ShapeType>()) 1402 inferredReturnTypes.assign({SizeType::get(context)}); 1403 else 1404 inferredReturnTypes.assign({IndexType::get(context)}); 1405 return success(); 1406 } 1407 1408 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1409 // SizeType is compatible with IndexType. 1410 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1411 } 1412 1413 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } 1414 1415 //===----------------------------------------------------------------------===// 1416 // NumElementsOp 1417 //===----------------------------------------------------------------------===// 1418 1419 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1420 1421 // Fold only when argument constant. 1422 Attribute shape = operands[0]; 1423 if (!shape) 1424 return {}; 1425 1426 APInt product(64, 1); 1427 for (auto value : shape.cast<DenseIntElementsAttr>()) 1428 product *= value; 1429 Builder builder(getContext()); 1430 return builder.getIndexAttr(product.getLimitedValue()); 1431 } 1432 1433 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1434 MLIRContext *context, Optional<Location> location, ValueRange operands, 1435 DictionaryAttr attributes, RegionRange regions, 1436 SmallVectorImpl<Type> &inferredReturnTypes) { 1437 if (operands[0].getType().isa<ShapeType>()) 1438 inferredReturnTypes.assign({SizeType::get(context)}); 1439 else 1440 inferredReturnTypes.assign({IndexType::get(context)}); 1441 return success(); 1442 } 1443 1444 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1445 TypeRange r) { 1446 // SizeType is compatible with IndexType. 
1447 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1448 } 1449 1450 LogicalResult shape::NumElementsOp::verify() { 1451 return verifySizeOrIndexOp(*this); 1452 } 1453 1454 //===----------------------------------------------------------------------===// 1455 // MaxOp 1456 //===----------------------------------------------------------------------===// 1457 1458 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1459 // If operands are equal, just propagate one. 1460 if (getLhs() == getRhs()) 1461 return getLhs(); 1462 return nullptr; 1463 } 1464 1465 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1466 MLIRContext *context, Optional<Location> location, ValueRange operands, 1467 DictionaryAttr attributes, RegionRange regions, 1468 SmallVectorImpl<Type> &inferredReturnTypes) { 1469 if (operands[0].getType() == operands[1].getType()) 1470 inferredReturnTypes.assign({operands[0].getType()}); 1471 else 1472 inferredReturnTypes.assign({SizeType::get(context)}); 1473 return success(); 1474 } 1475 1476 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1477 if (l.size() != 1 || r.size() != 1) 1478 return false; 1479 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1480 return true; 1481 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1482 return true; 1483 return false; 1484 } 1485 1486 //===----------------------------------------------------------------------===// 1487 // MinOp 1488 //===----------------------------------------------------------------------===// 1489 1490 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1491 // If operands are equal, just propagate one. 
1492 if (getLhs() == getRhs()) 1493 return getLhs(); 1494 return nullptr; 1495 } 1496 1497 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1498 MLIRContext *context, Optional<Location> location, ValueRange operands, 1499 DictionaryAttr attributes, RegionRange regions, 1500 SmallVectorImpl<Type> &inferredReturnTypes) { 1501 if (operands[0].getType() == operands[1].getType()) 1502 inferredReturnTypes.assign({operands[0].getType()}); 1503 else 1504 inferredReturnTypes.assign({SizeType::get(context)}); 1505 return success(); 1506 } 1507 1508 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1509 if (l.size() != 1 || r.size() != 1) 1510 return false; 1511 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1512 return true; 1513 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1514 return true; 1515 return false; 1516 } 1517 1518 //===----------------------------------------------------------------------===// 1519 // MulOp 1520 //===----------------------------------------------------------------------===// 1521 1522 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1523 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1524 if (!lhs) 1525 return nullptr; 1526 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1527 if (!rhs) 1528 return nullptr; 1529 APInt folded = lhs.getValue() * rhs.getValue(); 1530 Type indexTy = IndexType::get(getContext()); 1531 return IntegerAttr::get(indexTy, folded); 1532 } 1533 1534 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1535 MLIRContext *context, Optional<Location> location, ValueRange operands, 1536 DictionaryAttr attributes, RegionRange regions, 1537 SmallVectorImpl<Type> &inferredReturnTypes) { 1538 if (operands[0].getType().isa<SizeType>() || 1539 operands[1].getType().isa<SizeType>()) 1540 inferredReturnTypes.assign({SizeType::get(context)}); 1541 else 1542 inferredReturnTypes.assign({IndexType::get(context)}); 1543 return success(); 1544 } 1545 1546 bool 
mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1547 // SizeType is compatible with IndexType. 1548 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1549 } 1550 1551 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } 1552 1553 //===----------------------------------------------------------------------===// 1554 // ShapeOfOp 1555 //===----------------------------------------------------------------------===// 1556 1557 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1558 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1559 if (!type || !type.hasStaticShape()) 1560 return nullptr; 1561 Builder builder(getContext()); 1562 return builder.getIndexTensorAttr(type.getShape()); 1563 } 1564 1565 namespace { 1566 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1567 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1568 1569 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1570 PatternRewriter &rewriter) const override { 1571 if (!op.getArg().getType().isa<ShapedType>()) 1572 return failure(); 1573 if (op.getType().isa<ShapedType>()) 1574 return failure(); 1575 1576 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1577 op.getArg()); 1578 return success(); 1579 } 1580 }; 1581 1582 // Canonicalize 1583 // ``` 1584 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1585 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1586 // ``` 1587 // to 1588 // ``` 1589 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1590 // ``` 1591 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1592 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1593 1594 LogicalResult matchAndRewrite(tensor::CastOp op, 1595 PatternRewriter &rewriter) const override { 1596 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1597 if (!ty || ty.getRank() != 1) 1598 return failure(); 1599 1600 auto shapeOfOp = 
op.getSource().getDefiningOp<ShapeOfOp>(); 1601 if (!shapeOfOp) 1602 return failure(); 1603 1604 // Argument type must be ranked and must not conflict. 1605 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1606 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1607 return failure(); 1608 1609 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1610 return success(); 1611 } 1612 }; 1613 } // namespace 1614 1615 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1616 MLIRContext *context) { 1617 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1618 ExtractFromShapeOfExtentTensor>(context); 1619 } 1620 1621 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1622 MLIRContext *context, Optional<Location> location, ValueRange operands, 1623 DictionaryAttr attributes, RegionRange regions, 1624 SmallVectorImpl<Type> &inferredReturnTypes) { 1625 if (operands[0].getType().isa<ValueShapeType>()) 1626 inferredReturnTypes.assign({ShapeType::get(context)}); 1627 else { 1628 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1629 int64_t rank = 1630 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1631 Type indexTy = IndexType::get(context); 1632 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1633 inferredReturnTypes.assign({extentTensorTy}); 1634 } 1635 return success(); 1636 } 1637 1638 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1639 if (l.size() != 1 || r.size() != 1) 1640 return false; 1641 if (l == r) 1642 return true; 1643 1644 Type lhs = l.front(); 1645 Type rhs = r.front(); 1646 1647 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1648 return false; 1649 1650 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1651 // Shape type is compatible with all other valid return types. 
1652 return true; 1653 1654 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1655 return true; 1656 return false; 1657 } 1658 1659 LogicalResult shape::ShapeOfOp::verify() { 1660 return verifyShapeOrExtentTensorOp(*this); 1661 } 1662 1663 //===----------------------------------------------------------------------===// 1664 // SizeToIndexOp 1665 //===----------------------------------------------------------------------===// 1666 1667 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1668 // Constant values of both types, `shape.size` and `index`, are represented as 1669 // `IntegerAttr`s which makes constant folding simple. 1670 if (Attribute arg = operands[0]) 1671 return arg; 1672 return OpFoldResult(); 1673 } 1674 1675 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1676 MLIRContext *context) { 1677 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1678 } 1679 1680 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1681 if (inputs.size() != 1 || outputs.size() != 1) 1682 return false; 1683 return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>(); 1684 } 1685 1686 //===----------------------------------------------------------------------===// 1687 // YieldOp 1688 //===----------------------------------------------------------------------===// 1689 1690 LogicalResult shape::YieldOp::verify() { 1691 auto *parentOp = (*this)->getParentOp(); 1692 auto results = parentOp->getResults(); 1693 auto operands = getOperands(); 1694 1695 if (parentOp->getNumResults() != getNumOperands()) 1696 return emitOpError() << "number of operands does not match number of " 1697 "results of its parent"; 1698 for (auto e : llvm::zip(results, operands)) 1699 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1700 return emitOpError() << "types mismatch between yield op and its parent"; 1701 1702 return success(); 1703 } 1704 1705 
//===----------------------------------------------------------------------===// 1706 // SplitAtOp 1707 //===----------------------------------------------------------------------===// 1708 1709 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1710 SmallVectorImpl<OpFoldResult> &results) { 1711 if (!operands[0] || !operands[1]) 1712 return failure(); 1713 auto shapeVec = llvm::to_vector<6>( 1714 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1715 auto shape = llvm::makeArrayRef(shapeVec); 1716 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1717 // Verify that the split point is in the correct range. 1718 // TODO: Constant fold to an "error". 1719 int64_t rank = shape.size(); 1720 if (-rank > splitPoint || splitPoint > rank) 1721 return failure(); 1722 if (splitPoint < 0) 1723 splitPoint += shape.size(); 1724 Builder builder(operands[0].getContext()); 1725 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1726 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1727 return success(); 1728 } 1729 1730 //===----------------------------------------------------------------------===// 1731 // ToExtentTensorOp 1732 //===----------------------------------------------------------------------===// 1733 1734 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1735 if (!operands[0]) 1736 return OpFoldResult(); 1737 Builder builder(getContext()); 1738 auto shape = llvm::to_vector<6>( 1739 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1740 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1741 builder.getIndexType()); 1742 return DenseIntElementsAttr::get(type, shape); 1743 } 1744 1745 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1746 if (inputs.size() != 1 || outputs.size() != 1) 1747 return false; 1748 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1749 if 
(!inputTensor.getElementType().isa<IndexType>() || 1750 inputTensor.getRank() != 1) 1751 return false; 1752 } else if (!inputs[0].isa<ShapeType>()) { 1753 return false; 1754 } 1755 1756 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1757 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1758 } 1759 1760 //===----------------------------------------------------------------------===// 1761 // ReduceOp 1762 //===----------------------------------------------------------------------===// 1763 1764 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1765 ValueRange initVals) { 1766 result.addOperands(shape); 1767 result.addOperands(initVals); 1768 1769 Region *bodyRegion = result.addRegion(); 1770 bodyRegion->push_back(new Block); 1771 Block &bodyBlock = bodyRegion->front(); 1772 bodyBlock.addArgument(builder.getIndexType(), result.location); 1773 1774 Type elementType; 1775 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1776 elementType = tensorType.getElementType(); 1777 else 1778 elementType = SizeType::get(builder.getContext()); 1779 bodyBlock.addArgument(elementType, shape.getLoc()); 1780 1781 for (Value initVal : initVals) { 1782 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1783 result.addTypes(initVal.getType()); 1784 } 1785 } 1786 1787 LogicalResult ReduceOp::verify() { 1788 // Verify block arg types. 1789 Block &block = getRegion().front(); 1790 1791 // The block takes index, extent, and aggregated values as arguments. 1792 auto blockArgsCount = getInitVals().size() + 2; 1793 if (block.getNumArguments() != blockArgsCount) 1794 return emitOpError() << "ReduceOp body is expected to have " 1795 << blockArgsCount << " arguments"; 1796 1797 // The first block argument is the index and must always be of type `index`. 
1798 if (!block.getArgument(0).getType().isa<IndexType>()) 1799 return emitOpError( 1800 "argument 0 of ReduceOp body is expected to be of IndexType"); 1801 1802 // The second block argument is the extent and must be of type `size` or 1803 // `index`, depending on whether the reduce operation is applied to a shape or 1804 // to an extent tensor. 1805 Type extentTy = block.getArgument(1).getType(); 1806 if (getShape().getType().isa<ShapeType>()) { 1807 if (!extentTy.isa<SizeType>()) 1808 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1809 "SizeType if the ReduceOp operates on a ShapeType"); 1810 } else { 1811 if (!extentTy.isa<IndexType>()) 1812 return emitOpError( 1813 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1814 "ReduceOp operates on an extent tensor"); 1815 } 1816 1817 for (const auto &type : llvm::enumerate(getInitVals())) 1818 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1819 return emitOpError() << "type mismatch between argument " 1820 << type.index() + 2 1821 << " of ReduceOp body and initial value " 1822 << type.index(); 1823 return success(); 1824 } 1825 1826 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1827 // Parse operands. 1828 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1829 Type shapeOrExtentTensorType; 1830 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1831 OpAsmParser::Delimiter::Paren) || 1832 parser.parseColonType(shapeOrExtentTensorType) || 1833 parser.parseOptionalArrowTypeList(result.types)) 1834 return failure(); 1835 1836 // Resolve operands. 1837 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1838 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1839 result.operands) || 1840 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1841 result.operands)) 1842 return failure(); 1843 1844 // Parse the body. 
1845 Region *body = result.addRegion(); 1846 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1847 return failure(); 1848 1849 // Parse attributes. 1850 if (parser.parseOptionalAttrDict(result.attributes)) 1851 return failure(); 1852 1853 return success(); 1854 } 1855 1856 void ReduceOp::print(OpAsmPrinter &p) { 1857 p << '(' << getShape() << ", " << getInitVals() 1858 << ") : " << getShape().getType(); 1859 p.printOptionalArrowTypeList(getResultTypes()); 1860 p << ' '; 1861 p.printRegion(getRegion()); 1862 p.printOptionalAttrDict((*this)->getAttrs()); 1863 } 1864 1865 #define GET_OP_CLASSES 1866 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1867 1868 #define GET_TYPEDEF_CLASSES 1869 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc" 1870