1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/SetOperations.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::shape; 31 32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 33 34 namespace { 35 #include "ShapeCanonicalization.inc" 36 } // namespace 37 38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 39 return RankedTensorType::get({rank}, IndexType::get(ctx)); 40 } 41 42 bool shape::isExtentTensorType(Type type) { 43 auto ranked = type.dyn_cast<RankedTensorType>(); 44 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 45 } 46 47 LogicalResult shape::getShapeVec(Value input, 48 SmallVectorImpl<int64_t> &shapeValues) { 49 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 50 auto type = inputOp.getArg().getType().cast<ShapedType>(); 51 if (!type.hasRank()) 52 return failure(); 53 llvm::append_range(shapeValues, type.getShape()); 54 return success(); 55 } 56 DenseIntElementsAttr attr; 57 if (matchPattern(input, m_Constant(&attr))) { 58 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 59 return success(); 60 } 61 return failure(); 62 } 63 64 static bool isErrorPropagationPossible(TypeRange operandTypes) { 65 return llvm::any_of(operandTypes, [](Type ty) { 66 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 67 }); 68 } 69 70 static LogicalResult verifySizeOrIndexOp(Operation *op) { 71 assert(op != nullptr && op->getNumResults() == 1); 72 Type resultTy = op->getResultTypes().front(); 73 if (isErrorPropagationPossible(op->getOperandTypes())) { 74 if (!resultTy.isa<SizeType>()) 75 return op->emitOpError() 76 << "if at least one of the operands can hold error values then " 77 "the result must be of type `size` to propagate them"; 78 } 79 return success(); 80 } 81 82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 83 assert(op != nullptr && op->getNumResults() == 1); 84 Type resultTy = op->getResultTypes().front(); 85 if (isErrorPropagationPossible(op->getOperandTypes())) { 86 if (!resultTy.isa<ShapeType>()) 87 return op->emitOpError() 88 << "if at least one of the operands can hold error values then " 89 "the result must be of type `shape` to propagate them"; 90 } 91 return success(); 92 } 93 94 template <typename... Ty> 95 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 96 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 97 } 98 99 template <typename... Ty, typename... ranges> 100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 101 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // InlinerInterface 106 //===----------------------------------------------------------------------===// 107 108 namespace { 109 /// This class defines the interface for inlining shape dialect ops. 110 struct ShapeInlinerInterface : public DialectInlinerInterface { 111 using DialectInlinerInterface::DialectInlinerInterface; 112 113 // Returns true if the given region 'src' can be inlined into the region 114 // 'dest' that is attached to an operation registered to the current dialect. 115 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 116 BlockAndValueMapping &) const final { 117 return true; 118 } 119 120 // Returns true if the given operation 'op', that is registered to this 121 // dialect, can be inlined into the region 'dest' that is attached to an 122 // operation registered to the current dialect. 123 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 124 BlockAndValueMapping &) const final { 125 return true; 126 } 127 }; 128 } // namespace 129 130 void ShapeDialect::initialize() { 131 addOperations< 132 #define GET_OP_LIST 133 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 134 >(); 135 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 136 addInterfaces<ShapeInlinerInterface>(); 137 // Allow unknown operations during prototyping and testing. As the dialect is 138 // still evolving it makes it simple to start with an unregistered ops and 139 // try different variants before actually defining the op. 140 allowUnknownOperations(); 141 } 142 143 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 144 Attribute value, Type type, 145 Location loc) { 146 if (type.isa<ShapeType>() || isExtentTensorType(type)) 147 return builder.create<ConstShapeOp>(loc, type, 148 value.cast<DenseIntElementsAttr>()); 149 if (type.isa<SizeType>()) 150 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 151 if (type.isa<WitnessType>()) 152 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 153 if (arith::ConstantOp::isBuildableWith(value, type)) 154 return builder.create<arith::ConstantOp>(loc, type, value); 155 return nullptr; 156 } 157 158 /// Parse a type registered to this dialect. 159 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 160 StringRef keyword; 161 if (parser.parseKeyword(&keyword)) 162 return Type(); 163 164 if (keyword == "shape") 165 return ShapeType::get(getContext()); 166 if (keyword == "size") 167 return SizeType::get(getContext()); 168 if (keyword == "value_shape") 169 return ValueShapeType::get(getContext()); 170 if (keyword == "witness") 171 return WitnessType::get(getContext()); 172 173 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 174 return Type(); 175 } 176 177 /// Print a type registered to this dialect. 178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 179 TypeSwitch<Type>(type) 180 .Case<ShapeType>([&](Type) { os << "shape"; }) 181 .Case<SizeType>([&](Type) { os << "size"; }) 182 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 183 .Case<WitnessType>([&](Type) { os << "witness"; }) 184 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 185 } 186 187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 188 NamedAttribute attribute) { 189 // Verify shape.lib attribute. 190 if (attribute.getName() == "shape.lib") { 191 if (!op->hasTrait<OpTrait::SymbolTable>()) 192 return op->emitError( 193 "shape.lib attribute may only be on op implementing SymbolTable"); 194 195 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 196 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 197 if (!symbol) 198 return op->emitError("shape function library ") 199 << symbolRef << " not found"; 200 return isa<shape::FunctionLibraryOp>(symbol) 201 ? success() 202 : op->emitError() 203 << symbolRef << " required to be shape function library"; 204 } 205 206 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 207 // Verify all entries are function libraries and mappings in libraries 208 // refer to unique ops. 209 DenseSet<StringAttr> key; 210 for (auto it : arr) { 211 if (!it.isa<SymbolRefAttr>()) 212 return op->emitError( 213 "only SymbolRefAttr allowed in shape.lib attribute array"); 214 215 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 216 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 217 if (!shapeFnLib) 218 return op->emitError() 219 << it << " does not refer to FunctionLibraryOp"; 220 for (auto mapping : shapeFnLib.getMapping()) { 221 if (!key.insert(mapping.getName()).second) { 222 return op->emitError("only one op to shape mapping allowed, found " 223 "multiple for `") 224 << mapping.getName() << "`"; 225 } 226 } 227 } 228 return success(); 229 } 230 231 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 232 "allowed as shape.lib attribute"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AnyOp 239 //===----------------------------------------------------------------------===// 240 241 // TODO: Canonicalization should be implemented for shapes that can be 242 // determined through mixtures of the known dimensions of the inputs. 243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 244 // Only the last operand is checked because AnyOp is commutative. 245 if (operands.back()) 246 return operands.back(); 247 248 return nullptr; 249 } 250 251 //===----------------------------------------------------------------------===// 252 // AssumingOp 253 //===----------------------------------------------------------------------===// 254 255 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 256 result.regions.reserve(1); 257 Region *doRegion = result.addRegion(); 258 259 auto &builder = parser.getBuilder(); 260 OpAsmParser::OperandType cond; 261 if (parser.parseOperand(cond) || 262 parser.resolveOperand(cond, builder.getType<WitnessType>(), 263 result.operands)) 264 return failure(); 265 266 // Parse optional results type list. 267 if (parser.parseOptionalArrowTypeList(result.types)) 268 return failure(); 269 270 // Parse the region and add a terminator if elided. 271 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 272 return failure(); 273 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 274 275 // Parse the optional attribute list. 276 if (parser.parseOptionalAttrDict(result.attributes)) 277 return failure(); 278 return success(); 279 } 280 281 void AssumingOp::print(OpAsmPrinter &p) { 282 bool yieldsResults = !getResults().empty(); 283 284 p << " " << getWitness(); 285 if (yieldsResults) 286 p << " -> (" << getResultTypes() << ")"; 287 p << ' '; 288 p.printRegion(getDoRegion(), 289 /*printEntryBlockArgs=*/false, 290 /*printBlockTerminators=*/yieldsResults); 291 p.printOptionalAttrDict((*this)->getAttrs()); 292 } 293 294 namespace { 295 // Removes AssumingOp with a passing witness and inlines the region. 296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 297 using OpRewritePattern<AssumingOp>::OpRewritePattern; 298 299 LogicalResult matchAndRewrite(AssumingOp op, 300 PatternRewriter &rewriter) const override { 301 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 302 if (!witness || !witness.getPassingAttr()) 303 return failure(); 304 305 AssumingOp::inlineRegionIntoParent(op, rewriter); 306 return success(); 307 } 308 }; 309 310 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 311 using OpRewritePattern<AssumingOp>::OpRewritePattern; 312 313 LogicalResult matchAndRewrite(AssumingOp op, 314 PatternRewriter &rewriter) const override { 315 Block *body = op.getBody(); 316 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 317 318 // Find used values. 319 SmallVector<Value, 4> newYieldOperands; 320 Value opResult, yieldOperand; 321 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 322 std::tie(opResult, yieldOperand) = it; 323 if (!opResult.getUses().empty()) { 324 newYieldOperands.push_back(yieldOperand); 325 } 326 } 327 328 // Rewrite only if redundant results exist. 329 if (newYieldOperands.size() == yieldOp->getNumOperands()) 330 return failure(); 331 332 // Replace yield op in the old assuming op's body and move the entire region 333 // to the new assuming op. 334 rewriter.setInsertionPointToEnd(body); 335 auto newYieldOp = 336 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 337 rewriter.setInsertionPoint(op); 338 auto newOp = rewriter.create<AssumingOp>( 339 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 340 newOp.getDoRegion().takeBody(op.getDoRegion()); 341 342 // Use the new results to replace the previously used ones. 343 SmallVector<Value, 4> replacementValues; 344 auto src = newOp.getResults().begin(); 345 for (auto it : op.getResults()) { 346 if (it.getUses().empty()) 347 replacementValues.push_back(nullptr); 348 else 349 replacementValues.push_back(*src++); 350 } 351 rewriter.replaceOp(op, replacementValues); 352 return success(); 353 } 354 }; 355 } // namespace 356 357 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 358 MLIRContext *context) { 359 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 360 } 361 362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 363 void AssumingOp::getSuccessorRegions( 364 Optional<unsigned> index, ArrayRef<Attribute> operands, 365 SmallVectorImpl<RegionSuccessor> ®ions) { 366 // AssumingOp has unconditional control flow into the region and back to the 367 // parent, so return the correct RegionSuccessor purely based on the index 368 // being None or 0. 369 if (index.hasValue()) { 370 regions.push_back(RegionSuccessor(getResults())); 371 return; 372 } 373 374 regions.push_back(RegionSuccessor(&getDoRegion())); 375 } 376 377 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 378 PatternRewriter &rewriter) { 379 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 380 auto *assumingBlock = op.getBody(); 381 auto initPosition = rewriter.getInsertionPoint(); 382 auto *blockAfterAssuming = 383 rewriter.splitBlock(blockBeforeAssuming, initPosition); 384 385 // Remove the AssumingOp and AssumingYieldOp. 386 auto &yieldOp = assumingBlock->back(); 387 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 388 rewriter.replaceOp(op, yieldOp.getOperands()); 389 rewriter.eraseOp(&yieldOp); 390 391 // Merge blocks together as there was no branching behavior from the 392 // AssumingOp. 393 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 394 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 395 } 396 397 void AssumingOp::build( 398 OpBuilder &builder, OperationState &result, Value witness, 399 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 400 401 result.addOperands(witness); 402 Region *bodyRegion = result.addRegion(); 403 bodyRegion->push_back(new Block); 404 Block &bodyBlock = bodyRegion->front(); 405 406 // Build body. 407 OpBuilder::InsertionGuard guard(builder); 408 builder.setInsertionPointToStart(&bodyBlock); 409 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 410 builder.create<AssumingYieldOp>(result.location, yieldValues); 411 412 SmallVector<Type, 2> assumingTypes; 413 for (Value v : yieldValues) 414 assumingTypes.push_back(v.getType()); 415 result.addTypes(assumingTypes); 416 } 417 418 LogicalResult AssumingOp::verify() { 419 return RegionBranchOpInterface::verifyTypes(*this); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // AddOp 424 //===----------------------------------------------------------------------===// 425 426 LogicalResult mlir::shape::AddOp::inferReturnTypes( 427 MLIRContext *context, Optional<Location> location, ValueRange operands, 428 DictionaryAttr attributes, RegionRange regions, 429 SmallVectorImpl<Type> &inferredReturnTypes) { 430 if (operands[0].getType().isa<SizeType>() || 431 operands[1].getType().isa<SizeType>()) 432 inferredReturnTypes.assign({SizeType::get(context)}); 433 else 434 inferredReturnTypes.assign({IndexType::get(context)}); 435 return success(); 436 } 437 438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 439 // SizeType is compatible with IndexType. 440 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 441 } 442 443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 444 // add(x, 0) -> x 445 if (matchPattern(getRhs(), m_Zero())) 446 return getLhs(); 447 448 return constFoldBinaryOp<IntegerAttr>( 449 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 450 } 451 452 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 453 454 //===----------------------------------------------------------------------===// 455 // AssumingAllOp 456 //===----------------------------------------------------------------------===// 457 458 namespace { 459 460 // Merge multiple `shape.assuming_all` operations together. 461 // 462 // %0 = shape.assuming_all %w0, %w1 463 // %1 = shape.assuming_all %w2, %0 464 // 465 // to: 466 // 467 // %0 = shape.assuming_all %w0, %w2, %w2 468 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 469 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 470 471 LogicalResult matchAndRewrite(AssumingAllOp op, 472 PatternRewriter &rewriter) const override { 473 SmallVector<Value> operands; 474 475 for (Value operand : op.getInputs()) { 476 if (auto assume_all = operand.getDefiningOp<AssumingAllOp>()) 477 operands.append(assume_all.operand_begin(), assume_all->operand_end()); 478 else 479 operands.push_back(operand); 480 } 481 482 // We didn't find any other `assuming_all` ops to merge with. 483 if (operands.size() == op.getNumOperands()) 484 return failure(); 485 486 // Replace with a new `assuming_all` operation with merged constraints. 487 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 488 return success(); 489 } 490 }; 491 492 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 493 // are subsumed by others. 494 // 495 // %0 = shape.cstr_broadcastable %shape0, %shape1 496 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 497 // 498 // %2 = shape.cstr_broadcastable %shape3, %shape4 499 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 500 // 501 // %4 = shape.assuming_all %0, %1, %2, %3 502 // 503 // to: 504 // 505 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 506 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 507 // %2 = shape.assuming_all %0, %1 508 // 509 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 510 // shapes [0, 1] are broadcastable too, and can be removed from the list of 511 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 512 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 513 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 514 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 515 516 LogicalResult matchAndRewrite(AssumingAllOp op, 517 PatternRewriter &rewriter) const override { 518 // Collect all `CstrBroadcastableOp` operands first. 519 SetVector<CstrBroadcastableOp> operands; 520 for (Value operand : op.getInputs()) { 521 // TODO: Apply this optimization if some of the witnesses are not 522 // produced by the `cstr_broadcastable`. 523 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 524 if (!broadcastable) 525 return failure(); 526 527 operands.insert(broadcastable); 528 } 529 530 // Skip trivial `assuming_all` operations. 531 if (operands.size() <= 1) 532 return failure(); 533 534 // Collect shapes checked by `cstr_broadcastable` operands. 535 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 536 for (auto cstr : operands) { 537 DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end()); 538 shapes.emplace_back(cstr, std::move(shapes_set)); 539 } 540 541 // Sort by the number of shape operands (larger to smaller). 542 llvm::sort(shapes, [](auto a, auto b) { 543 return a.first.getNumOperands() > b.first.getNumOperands(); 544 }); 545 546 // We start from the `cst_broadcastable` operations with largest number of 547 // shape operands, and remove redundant `cst_broadcastable` operations. We 548 // do this until we find a set of `cst_broadcastable` operations with 549 // non-overlapping constraints. 550 SmallVector<CstrBroadcastableOp> marked_for_erase; 551 552 for (unsigned i = 0; i < shapes.size(); ++i) { 553 auto isSubset = [&](auto pair) { 554 return llvm::set_is_subset(pair.second, shapes[i].second); 555 }; 556 557 // Keep redundant `cstr_broadcastable` operations to be erased. 558 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 559 for (auto *it0 = it; it0 < shapes.end(); ++it0) 560 marked_for_erase.push_back(it0->first); 561 shapes.erase(it, shapes.end()); 562 } 563 564 // We didn't find any operands that could be removed. 565 if (marked_for_erase.empty()) 566 return failure(); 567 568 // Collect non-overlapping `cst_broadcastable` constraints. 569 SmallVector<Value> unique_constraints; 570 for (auto &shape : shapes) 571 unique_constraints.push_back(shape.first.getResult()); 572 573 // Replace with a new `assuming_all` operation ... 574 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints); 575 576 // ... and maybe erase `cstr_broadcastable` ops without uses. 577 for (auto &op : marked_for_erase) 578 if (op->use_empty()) 579 rewriter.eraseOp(op); 580 581 return success(); 582 } 583 }; 584 585 struct AssumingAllToCstrEqCanonicalization 586 : public OpRewritePattern<AssumingAllOp> { 587 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 588 589 LogicalResult matchAndRewrite(AssumingAllOp op, 590 PatternRewriter &rewriter) const override { 591 SmallVector<Value, 8> shapes; 592 for (Value w : op.getInputs()) { 593 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 594 if (!cstrEqOp) 595 return failure(); 596 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 597 return llvm::is_contained(shapes, s); 598 }); 599 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 600 return failure(); 601 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 602 } 603 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 604 return success(); 605 } 606 }; 607 608 template <typename OpTy> 609 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 610 using OpRewritePattern<OpTy>::OpRewritePattern; 611 612 LogicalResult matchAndRewrite(OpTy op, 613 PatternRewriter &rewriter) const override { 614 // Find unique operands. 615 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 616 617 // Reduce op to equivalent with unique operands. 618 if (unique.size() < op.getNumOperands()) { 619 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 620 unique.takeVector(), op->getAttrs()); 621 return success(); 622 } 623 624 return failure(); 625 } 626 }; 627 } // namespace 628 629 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 630 MLIRContext *context) { 631 patterns 632 .add<MergeAssumingAllOps, AssumingAllOneOp, 633 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 634 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 635 } 636 637 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 638 // Iterate in reverse to first handle all constant operands. They are 639 // guaranteed to be the tail of the inputs because this is commutative. 640 for (int idx = operands.size() - 1; idx >= 0; idx--) { 641 Attribute a = operands[idx]; 642 // Cannot fold if any inputs are not constant; 643 if (!a) 644 return nullptr; 645 646 // We do not need to keep statically known values after handling them in 647 // this method. 648 getOperation()->eraseOperand(idx); 649 650 // Always false if any input is statically known false 651 if (!a.cast<BoolAttr>().getValue()) 652 return a; 653 } 654 // If this is reached, all inputs were statically known passing. 655 return BoolAttr::get(getContext(), true); 656 } 657 658 LogicalResult AssumingAllOp::verify() { 659 // Ensure that AssumingAllOp contains at least one operand 660 if (getNumOperands() == 0) 661 return emitOpError("no operands specified"); 662 663 return success(); 664 } 665 666 //===----------------------------------------------------------------------===// 667 // BroadcastOp 668 //===----------------------------------------------------------------------===// 669 670 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 671 if (getShapes().size() == 1) { 672 // Otherwise, we need a cast which would be a canonicalization, not folding. 673 if (getShapes().front().getType() != getType()) 674 return nullptr; 675 return getShapes().front(); 676 } 677 678 // TODO: Support folding with more than 2 input shapes 679 if (getShapes().size() > 2) 680 return nullptr; 681 682 if (!operands[0] || !operands[1]) 683 return nullptr; 684 auto lhsShape = llvm::to_vector<6>( 685 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 686 auto rhsShape = llvm::to_vector<6>( 687 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 688 SmallVector<int64_t, 6> resultShape; 689 690 // If the shapes are not compatible, we can't fold it. 691 // TODO: Fold to an "error". 692 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 693 return nullptr; 694 695 Builder builder(getContext()); 696 return builder.getIndexTensorAttr(resultShape); 697 } 698 699 LogicalResult BroadcastOp::verify() { 700 return verifyShapeOrExtentTensorOp(*this); 701 } 702 703 namespace { 704 template <typename OpTy> 705 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 706 using OpRewritePattern<OpTy>::OpRewritePattern; 707 708 LogicalResult matchAndRewrite(OpTy op, 709 PatternRewriter &rewriter) const override { 710 auto isPotentiallyNonEmptyShape = [](Value shape) { 711 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 712 if (extentTensorTy.getDimSize(0) == 0) 713 return false; 714 } 715 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 716 if (constShape.getShape().empty()) 717 return false; 718 } 719 return true; 720 }; 721 auto newOperands = llvm::to_vector<8>( 722 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 723 724 // Reduce op to equivalent without empty shape operands. 725 if (newOperands.size() < op.getNumOperands()) { 726 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 727 op->getAttrs()); 728 return success(); 729 } 730 731 return failure(); 732 } 733 }; 734 735 struct BroadcastForwardSingleOperandPattern 736 : public OpRewritePattern<BroadcastOp> { 737 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 738 739 LogicalResult matchAndRewrite(BroadcastOp op, 740 PatternRewriter &rewriter) const override { 741 if (op.getNumOperands() != 1) 742 return failure(); 743 Value replacement = op.getShapes().front(); 744 745 // Insert cast if needed. 746 if (replacement.getType() != op.getType()) { 747 auto loc = op.getLoc(); 748 if (op.getType().isa<ShapeType>()) { 749 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 750 } else { 751 assert(!op.getType().isa<ShapeType>() && 752 !replacement.getType().isa<ShapeType>() && 753 "expect extent tensor cast"); 754 replacement = 755 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 756 } 757 } 758 759 rewriter.replaceOp(op, replacement); 760 return success(); 761 } 762 }; 763 764 struct BroadcastFoldConstantOperandsPattern 765 : public OpRewritePattern<BroadcastOp> { 766 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 767 768 LogicalResult matchAndRewrite(BroadcastOp op, 769 PatternRewriter &rewriter) const override { 770 SmallVector<int64_t, 8> foldedConstantShape; 771 SmallVector<Value, 8> newShapeOperands; 772 for (Value shape : op.getShapes()) { 773 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 774 SmallVector<int64_t, 8> newFoldedConstantShape; 775 if (OpTrait::util::getBroadcastedShape( 776 foldedConstantShape, 777 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 778 newFoldedConstantShape)) { 779 foldedConstantShape = newFoldedConstantShape; 780 continue; 781 } 782 } 783 newShapeOperands.push_back(shape); 784 } 785 786 // Need at least two constant operands to fold anything. 787 if (op.getNumOperands() - newShapeOperands.size() < 2) 788 return failure(); 789 790 auto foldedConstantOperandsTy = RankedTensorType::get( 791 {static_cast<int64_t>(foldedConstantShape.size())}, 792 rewriter.getIndexType()); 793 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 794 op.getLoc(), foldedConstantOperandsTy, 795 rewriter.getIndexTensorAttr(foldedConstantShape))); 796 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 797 newShapeOperands); 798 return success(); 799 } 800 }; 801 802 template <typename OpTy> 803 struct CanonicalizeCastExtentTensorOperandsPattern 804 : public OpRewritePattern<OpTy> { 805 using OpRewritePattern<OpTy>::OpRewritePattern; 806 807 LogicalResult matchAndRewrite(OpTy op, 808 PatternRewriter &rewriter) const override { 809 // Canonicalize operands. 810 bool anyChange = false; 811 auto canonicalizeOperand = [&](Value operand) { 812 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 813 // Only eliminate the cast if it holds no shape information. 814 bool isInformationLoosingCast = 815 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 816 if (isInformationLoosingCast) { 817 anyChange = true; 818 return castOp.source(); 819 } 820 } 821 return operand; 822 }; 823 auto newOperands = llvm::to_vector<8>( 824 llvm::map_range(op.getOperands(), canonicalizeOperand)); 825 826 // Rewrite op if any change required. 827 if (!anyChange) 828 return failure(); 829 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 830 return success(); 831 } 832 }; 833 834 struct BroadcastConcretizeResultTypePattern 835 : public OpRewritePattern<BroadcastOp> { 836 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 837 838 LogicalResult matchAndRewrite(BroadcastOp op, 839 PatternRewriter &rewriter) const override { 840 // Only concretize dynamic extent tensor result types. 841 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 842 if (!resultTy || !resultTy.isDynamicDim(0)) 843 return failure(); 844 845 // Infer resulting shape rank if possible. 846 int64_t maxRank = 0; 847 for (Value shape : op.getShapes()) { 848 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 849 // Cannot infer resulting shape rank if any operand is dynamically 850 // ranked. 851 if (extentTensorTy.isDynamicDim(0)) 852 return failure(); 853 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 854 } 855 } 856 857 auto newOp = rewriter.create<BroadcastOp>( 858 op.getLoc(), getExtentTensorType(getContext(), maxRank), 859 op.getShapes()); 860 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 861 return success(); 862 } 863 }; 864 } // namespace 865 866 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 867 MLIRContext *context) { 868 patterns.add<BroadcastConcretizeResultTypePattern, 869 BroadcastFoldConstantOperandsPattern, 870 BroadcastForwardSingleOperandPattern, 871 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 872 RemoveDuplicateOperandsPattern<BroadcastOp>, 873 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 874 } 875 876 //===----------------------------------------------------------------------===// 877 // ConcatOp 878 //===----------------------------------------------------------------------===// 879 880 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 881 if (!operands[0] || !operands[1]) 882 return nullptr; 883 auto lhsShape = llvm::to_vector<6>( 884 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 885 auto rhsShape = llvm::to_vector<6>( 886 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 887 SmallVector<int64_t, 6> resultShape; 888 resultShape.append(lhsShape.begin(), lhsShape.end()); 889 resultShape.append(rhsShape.begin(), rhsShape.end()); 890 Builder builder(getContext()); 891 return builder.getIndexTensorAttr(resultShape); 892 } 893 894 //===----------------------------------------------------------------------===// 895 // ConstShapeOp 896 //===----------------------------------------------------------------------===// 897 898 void ConstShapeOp::print(OpAsmPrinter &p) { 899 p << " "; 900 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 901 p << "["; 902 interleaveComma(getShape().getValues<int64_t>(), p); 903 p << "] : "; 904 p.printType(getType()); 905 } 906 907 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 908 if (parser.parseOptionalAttrDict(result.attributes)) 909 return failure(); 910 // We piggy-back on ArrayAttr parsing, though we don't internally store the 911 // shape as an ArrayAttr. 912 // TODO: Implement custom parser and maybe make syntax a bit more concise. 913 Attribute extentsRaw; 914 NamedAttrList dummy; 915 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 916 return failure(); 917 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 918 if (!extentsArray) 919 return failure(); 920 SmallVector<int64_t, 6> ints; 921 for (Attribute extent : extentsArray) { 922 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 923 if (!attr) 924 return failure(); 925 ints.push_back(attr.getInt()); 926 } 927 Builder &builder = parser.getBuilder(); 928 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 929 Type resultTy; 930 if (parser.parseColonType(resultTy)) 931 return failure(); 932 result.types.push_back(resultTy); 933 return success(); 934 } 935 936 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 937 938 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 939 MLIRContext *context) { 940 patterns.add<TensorCastConstShape>(context); 941 } 942 943 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 944 MLIRContext *context, Optional<Location> location, ValueRange operands, 945 DictionaryAttr attributes, RegionRange regions, 946 SmallVectorImpl<Type> &inferredReturnTypes) { 947 Builder b(context); 948 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 949 if (!shape) 950 return emitOptionalError(location, "missing shape attribute"); 951 inferredReturnTypes.assign({RankedTensorType::get( 952 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 953 return success(); 954 } 955 956 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 957 TypeRange r) { 958 if (l.size() != 1 || r.size() != 1) 959 return false; 960 961 Type lhs = l.front(); 962 Type rhs = r.front(); 963 964 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 965 // Shape type is compatible with all other valid return types. 966 return true; 967 return lhs == rhs; 968 } 969 970 //===----------------------------------------------------------------------===// 971 // CstrBroadcastableOp 972 //===----------------------------------------------------------------------===// 973 974 void CstrBroadcastableOp::getCanonicalizationPatterns( 975 RewritePatternSet &patterns, MLIRContext *context) { 976 // Canonicalization patterns have overlap with the considerations during 977 // folding in case additional shape information is inferred at some point that 978 // does not result in folding. 979 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 980 CstrBroadcastableEqOps, 981 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 982 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 983 } 984 985 // Return true if there is exactly one attribute not representing a scalar 986 // broadcast. 987 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 988 bool nonScalarSeen = false; 989 for (Attribute a : attributes) { 990 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 991 if (nonScalarSeen) 992 return false; 993 nonScalarSeen = true; 994 } 995 } 996 return true; 997 } 998 999 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1000 // No broadcasting is needed if all operands but one are scalar. 1001 if (hasAtMostSingleNonScalar(operands)) 1002 return BoolAttr::get(getContext(), true); 1003 1004 if ([&] { 1005 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1006 for (const auto &operand : operands) { 1007 if (!operand) 1008 return false; 1009 extents.push_back(llvm::to_vector<6>( 1010 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 1011 } 1012 return OpTrait::util::staticallyKnownBroadcastable(extents); 1013 }()) 1014 return BoolAttr::get(getContext(), true); 1015 1016 // Lastly, see if folding can be completed based on what constraints are known 1017 // on the input shapes. 1018 if ([&] { 1019 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1020 for (auto shapeValue : getShapes()) { 1021 extents.emplace_back(); 1022 if (failed(getShapeVec(shapeValue, extents.back()))) 1023 return false; 1024 } 1025 return OpTrait::util::staticallyKnownBroadcastable(extents); 1026 }()) 1027 return BoolAttr::get(getContext(), true); 1028 1029 // Because a failing witness result here represents an eventual assertion 1030 // failure, we do not replace it with a constant witness. 1031 return nullptr; 1032 } 1033 1034 LogicalResult CstrBroadcastableOp::verify() { 1035 // Ensure that CstrBroadcastableOp contains at least two operands 1036 if (getNumOperands() < 2) 1037 return emitOpError("required at least 2 input shapes"); 1038 return success(); 1039 } 1040 1041 //===----------------------------------------------------------------------===// 1042 // CstrEqOp 1043 //===----------------------------------------------------------------------===// 1044 1045 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1046 MLIRContext *context) { 1047 // If inputs are equal, return passing witness 1048 patterns.add<CstrEqEqOps>(context); 1049 } 1050 1051 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1052 if (llvm::all_of(operands, 1053 [&](Attribute a) { return a && a == operands[0]; })) 1054 return BoolAttr::get(getContext(), true); 1055 1056 // Because a failing witness result here represents an eventual assertion 1057 // failure, we do not try to replace it with a constant witness. Similarly, we 1058 // cannot if there are any non-const inputs. 1059 return nullptr; 1060 } 1061 1062 //===----------------------------------------------------------------------===// 1063 // ConstSizeOp 1064 //===----------------------------------------------------------------------===// 1065 1066 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1067 int64_t value) { 1068 build(builder, result, builder.getIndexAttr(value)); 1069 } 1070 1071 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1072 1073 void ConstSizeOp::getAsmResultNames( 1074 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1075 SmallString<4> buffer; 1076 llvm::raw_svector_ostream os(buffer); 1077 os << "c" << getValue(); 1078 setNameFn(getResult(), os.str()); 1079 } 1080 1081 //===----------------------------------------------------------------------===// 1082 // ConstWitnessOp 1083 //===----------------------------------------------------------------------===// 1084 1085 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1086 return getPassingAttr(); 1087 } 1088 1089 //===----------------------------------------------------------------------===// 1090 // CstrRequireOp 1091 //===----------------------------------------------------------------------===// 1092 1093 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1094 return operands[0]; 1095 } 1096 1097 //===----------------------------------------------------------------------===// 1098 // DivOp 1099 //===----------------------------------------------------------------------===// 1100 1101 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1102 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1103 if (!lhs) 1104 return nullptr; 1105 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1106 if (!rhs) 1107 return nullptr; 1108 1109 // Division in APInt does not follow floor(lhs, rhs) when the result is 1110 // negative. Rather, APInt rounds toward zero. 1111 APInt quotient, remainder; 1112 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1113 if (quotient.isNegative() && !remainder.isNullValue()) { 1114 quotient -= 1; 1115 } 1116 1117 Type indexTy = IndexType::get(getContext()); 1118 return IntegerAttr::get(indexTy, quotient); 1119 } 1120 1121 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1122 MLIRContext *context, Optional<Location> location, ValueRange operands, 1123 DictionaryAttr attributes, RegionRange regions, 1124 SmallVectorImpl<Type> &inferredReturnTypes) { 1125 if (operands[0].getType().isa<SizeType>() || 1126 operands[1].getType().isa<SizeType>()) 1127 inferredReturnTypes.assign({SizeType::get(context)}); 1128 else 1129 inferredReturnTypes.assign({IndexType::get(context)}); 1130 return success(); 1131 } 1132 1133 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1134 // SizeType is compatible with IndexType. 1135 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1136 } 1137 1138 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1139 1140 //===----------------------------------------------------------------------===// 1141 // ShapeEqOp 1142 //===----------------------------------------------------------------------===// 1143 1144 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1145 bool allSame = true; 1146 if (!operands.empty() && !operands[0]) 1147 return {}; 1148 for (Attribute operand : operands.drop_front(1)) { 1149 if (!operand) 1150 return {}; 1151 allSame = allSame && operand == operands[0]; 1152 } 1153 return BoolAttr::get(getContext(), allSame); 1154 } 1155 1156 //===----------------------------------------------------------------------===// 1157 // IndexToSizeOp 1158 //===----------------------------------------------------------------------===// 1159 1160 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1161 // Constant values of both types, `shape.size` and `index`, are represented as 1162 // `IntegerAttr`s which makes constant folding simple. 1163 if (Attribute arg = operands[0]) 1164 return arg; 1165 return {}; 1166 } 1167 1168 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1169 MLIRContext *context) { 1170 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1171 } 1172 1173 //===----------------------------------------------------------------------===// 1174 // FromExtentsOp 1175 //===----------------------------------------------------------------------===// 1176 1177 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1178 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1179 return nullptr; 1180 SmallVector<int64_t, 6> extents; 1181 for (auto attr : operands) 1182 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1183 Builder builder(getContext()); 1184 return builder.getIndexTensorAttr(extents); 1185 } 1186 1187 //===----------------------------------------------------------------------===// 1188 // FunctionLibraryOp 1189 //===----------------------------------------------------------------------===// 1190 1191 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1192 StringRef name) { 1193 result.attributes.push_back(builder.getNamedAttr( 1194 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1195 } 1196 1197 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1198 auto attr = getMapping() 1199 .get(op->getName().getIdentifier()) 1200 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1201 if (!attr) 1202 return nullptr; 1203 return lookupSymbol<FuncOp>(attr); 1204 } 1205 1206 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1207 OperationState &result) { 1208 // Parse the op name. 1209 StringAttr nameAttr; 1210 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1211 result.attributes)) 1212 return failure(); 1213 1214 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1215 return failure(); 1216 1217 auto *bodyRegion = result.addRegion(); 1218 if (parser.parseRegion(*bodyRegion)) 1219 return failure(); 1220 1221 if (parser.parseKeyword("mapping")) 1222 return failure(); 1223 1224 DictionaryAttr mappingAttr; 1225 if (parser.parseAttribute(mappingAttr, 1226 parser.getBuilder().getType<NoneType>(), "mapping", 1227 result.attributes)) 1228 return failure(); 1229 return success(); 1230 } 1231 1232 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1233 p << ' '; 1234 p.printSymbolName(getName()); 1235 p.printOptionalAttrDictWithKeyword( 1236 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1237 p << ' '; 1238 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1239 /*printBlockTerminators=*/false); 1240 p << " mapping "; 1241 p.printAttributeWithoutType(getMappingAttr()); 1242 } 1243 1244 //===----------------------------------------------------------------------===// 1245 // GetExtentOp 1246 //===----------------------------------------------------------------------===// 1247 1248 Optional<int64_t> GetExtentOp::getConstantDim() { 1249 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1250 return constSizeOp.getValue().getLimitedValue(); 1251 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1252 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1253 return llvm::None; 1254 } 1255 1256 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1257 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1258 if (!elements) 1259 return nullptr; 1260 Optional<int64_t> dim = getConstantDim(); 1261 if (!dim.hasValue()) 1262 return nullptr; 1263 if (dim.getValue() >= elements.getNumElements()) 1264 return nullptr; 1265 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1266 } 1267 1268 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1269 int64_t dim) { 1270 auto loc = result.location; 1271 auto dimAttr = builder.getIndexAttr(dim); 1272 if (shape.getType().isa<ShapeType>()) { 1273 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1274 build(builder, result, builder.getType<SizeType>(), shape, dim); 1275 } else { 1276 Value dim = 1277 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1278 build(builder, result, builder.getIndexType(), shape, dim); 1279 } 1280 } 1281 1282 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1283 MLIRContext *context, Optional<Location> location, ValueRange operands, 1284 DictionaryAttr attributes, RegionRange regions, 1285 SmallVectorImpl<Type> &inferredReturnTypes) { 1286 inferredReturnTypes.assign({IndexType::get(context)}); 1287 return success(); 1288 } 1289 1290 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1291 TypeRange r) { 1292 // SizeType is compatible with IndexType. 1293 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1294 } 1295 1296 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1297 1298 //===----------------------------------------------------------------------===// 1299 // IsBroadcastableOp 1300 //===----------------------------------------------------------------------===// 1301 1302 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1303 MLIRContext *context) { 1304 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1305 } 1306 1307 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1308 // Can always broadcast fewer than two shapes. 1309 if (operands.size() < 2) { 1310 return BoolAttr::get(getContext(), true); 1311 } 1312 1313 return nullptr; 1314 } 1315 1316 //===----------------------------------------------------------------------===// 1317 // MeetOp 1318 //===----------------------------------------------------------------------===// 1319 1320 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1321 MLIRContext *context, Optional<Location> location, ValueRange operands, 1322 DictionaryAttr attributes, RegionRange regions, 1323 SmallVectorImpl<Type> &inferredReturnTypes) { 1324 inferredReturnTypes.assign({operands[0].getType()}); 1325 return success(); 1326 } 1327 1328 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1329 if (l.size() != 1 || r.size() != 1) 1330 return false; 1331 if (l == r) 1332 return true; 1333 1334 Type lhs = l.front(); 1335 Type rhs = r.front(); 1336 1337 if (lhs != rhs) 1338 return false; 1339 1340 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1341 return true; 1342 1343 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1344 return true; 1345 return false; 1346 } 1347 1348 //===----------------------------------------------------------------------===// 1349 // RankOp 1350 //===----------------------------------------------------------------------===// 1351 1352 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1353 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1354 if (!shape) 1355 return {}; 1356 int64_t rank = shape.getNumElements(); 1357 Builder builder(getContext()); 1358 return builder.getIndexAttr(rank); 1359 } 1360 1361 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1362 /// Constant folding fails in cases where only the rank is constant, not the 1363 /// shape itself. 1364 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1365 /// 1366 /// Example: 1367 /// 1368 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1369 /// %rank = shape.rank %shape 1370 /// 1371 /// becomes 1372 /// 1373 /// %rank = shape.const_size 3 1374 1375 namespace { 1376 struct RankShapeOfCanonicalizationPattern 1377 : public OpRewritePattern<shape::RankOp> { 1378 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1379 1380 LogicalResult matchAndRewrite(shape::RankOp op, 1381 PatternRewriter &rewriter) const override { 1382 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1383 if (!shapeOfOp) 1384 return failure(); 1385 auto rankedTensorType = 1386 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1387 if (!rankedTensorType) 1388 return failure(); 1389 int64_t rank = rankedTensorType.getRank(); 1390 if (op.getType().isa<IndexType>()) { 1391 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1392 rank); 1393 } else if (op.getType().isa<shape::SizeType>()) { 1394 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1395 } else { 1396 return failure(); 1397 } 1398 return success(); 1399 } 1400 }; 1401 } // namespace 1402 1403 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1404 MLIRContext *context) { 1405 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1406 } 1407 1408 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1409 MLIRContext *context, Optional<Location> location, ValueRange operands, 1410 DictionaryAttr attributes, RegionRange regions, 1411 SmallVectorImpl<Type> &inferredReturnTypes) { 1412 if (operands[0].getType().isa<ShapeType>()) 1413 inferredReturnTypes.assign({SizeType::get(context)}); 1414 else 1415 inferredReturnTypes.assign({IndexType::get(context)}); 1416 return success(); 1417 } 1418 1419 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1420 // SizeType is compatible with IndexType. 1421 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1422 } 1423 1424 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } 1425 1426 //===----------------------------------------------------------------------===// 1427 // NumElementsOp 1428 //===----------------------------------------------------------------------===// 1429 1430 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1431 1432 // Fold only when argument constant. 1433 Attribute shape = operands[0]; 1434 if (!shape) 1435 return {}; 1436 1437 APInt product(64, 1); 1438 for (auto value : shape.cast<DenseIntElementsAttr>()) 1439 product *= value; 1440 Builder builder(getContext()); 1441 return builder.getIndexAttr(product.getLimitedValue()); 1442 } 1443 1444 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1445 MLIRContext *context, Optional<Location> location, ValueRange operands, 1446 DictionaryAttr attributes, RegionRange regions, 1447 SmallVectorImpl<Type> &inferredReturnTypes) { 1448 if (operands[0].getType().isa<ShapeType>()) 1449 inferredReturnTypes.assign({SizeType::get(context)}); 1450 else 1451 inferredReturnTypes.assign({IndexType::get(context)}); 1452 return success(); 1453 } 1454 1455 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1456 TypeRange r) { 1457 // SizeType is compatible with IndexType. 1458 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1459 } 1460 1461 LogicalResult shape::NumElementsOp::verify() { 1462 return verifySizeOrIndexOp(*this); 1463 } 1464 1465 //===----------------------------------------------------------------------===// 1466 // MaxOp 1467 //===----------------------------------------------------------------------===// 1468 1469 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1470 // If operands are equal, just propagate one. 1471 if (getLhs() == getRhs()) 1472 return getLhs(); 1473 return nullptr; 1474 } 1475 1476 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1477 MLIRContext *context, Optional<Location> location, ValueRange operands, 1478 DictionaryAttr attributes, RegionRange regions, 1479 SmallVectorImpl<Type> &inferredReturnTypes) { 1480 if (operands[0].getType() == operands[1].getType()) 1481 inferredReturnTypes.assign({operands[0].getType()}); 1482 else 1483 inferredReturnTypes.assign({SizeType::get(context)}); 1484 return success(); 1485 } 1486 1487 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1488 if (l.size() != 1 || r.size() != 1) 1489 return false; 1490 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1491 return true; 1492 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1493 return true; 1494 return false; 1495 } 1496 1497 //===----------------------------------------------------------------------===// 1498 // MinOp 1499 //===----------------------------------------------------------------------===// 1500 1501 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1502 // If operands are equal, just propagate one. 1503 if (getLhs() == getRhs()) 1504 return getLhs(); 1505 return nullptr; 1506 } 1507 1508 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1509 MLIRContext *context, Optional<Location> location, ValueRange operands, 1510 DictionaryAttr attributes, RegionRange regions, 1511 SmallVectorImpl<Type> &inferredReturnTypes) { 1512 if (operands[0].getType() == operands[1].getType()) 1513 inferredReturnTypes.assign({operands[0].getType()}); 1514 else 1515 inferredReturnTypes.assign({SizeType::get(context)}); 1516 return success(); 1517 } 1518 1519 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1520 if (l.size() != 1 || r.size() != 1) 1521 return false; 1522 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1523 return true; 1524 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1525 return true; 1526 return false; 1527 } 1528 1529 //===----------------------------------------------------------------------===// 1530 // MulOp 1531 //===----------------------------------------------------------------------===// 1532 1533 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1534 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1535 if (!lhs) 1536 return nullptr; 1537 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1538 if (!rhs) 1539 return nullptr; 1540 APInt folded = lhs.getValue() * rhs.getValue(); 1541 Type indexTy = IndexType::get(getContext()); 1542 return IntegerAttr::get(indexTy, folded); 1543 } 1544 1545 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1546 MLIRContext *context, Optional<Location> location, ValueRange operands, 1547 DictionaryAttr attributes, RegionRange regions, 1548 SmallVectorImpl<Type> &inferredReturnTypes) { 1549 if (operands[0].getType().isa<SizeType>() || 1550 operands[1].getType().isa<SizeType>()) 1551 inferredReturnTypes.assign({SizeType::get(context)}); 1552 else 1553 inferredReturnTypes.assign({IndexType::get(context)}); 1554 return success(); 1555 } 1556 1557 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1558 // SizeType is compatible with IndexType. 1559 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1560 } 1561 1562 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } 1563 1564 //===----------------------------------------------------------------------===// 1565 // ShapeOfOp 1566 //===----------------------------------------------------------------------===// 1567 1568 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1569 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1570 if (!type || !type.hasStaticShape()) 1571 return nullptr; 1572 Builder builder(getContext()); 1573 return builder.getIndexTensorAttr(type.getShape()); 1574 } 1575 1576 namespace { 1577 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1578 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1579 1580 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1581 PatternRewriter &rewriter) const override { 1582 if (!op.getArg().getType().isa<ShapedType>()) 1583 return failure(); 1584 if (op.getType().isa<ShapedType>()) 1585 return failure(); 1586 1587 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1588 op.getArg()); 1589 return success(); 1590 } 1591 }; 1592 1593 // Canonicalize 1594 // ``` 1595 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1596 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1597 // ``` 1598 // to 1599 // ``` 1600 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1601 // ``` 1602 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1603 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1604 1605 LogicalResult matchAndRewrite(tensor::CastOp op, 1606 PatternRewriter &rewriter) const override { 1607 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1608 if (!ty || ty.getRank() != 1) 1609 return failure(); 1610 1611 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1612 if (!shapeOfOp) 1613 return failure(); 1614 1615 // Argument type must be ranked and must not conflict. 1616 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1617 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1618 return failure(); 1619 1620 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1621 return success(); 1622 } 1623 }; 1624 } // namespace 1625 1626 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1627 MLIRContext *context) { 1628 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1629 ExtractFromShapeOfExtentTensor>(context); 1630 } 1631 1632 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1633 MLIRContext *context, Optional<Location> location, ValueRange operands, 1634 DictionaryAttr attributes, RegionRange regions, 1635 SmallVectorImpl<Type> &inferredReturnTypes) { 1636 if (operands[0].getType().isa<ValueShapeType>()) 1637 inferredReturnTypes.assign({ShapeType::get(context)}); 1638 else { 1639 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1640 int64_t rank = 1641 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1642 Type indexTy = IndexType::get(context); 1643 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1644 inferredReturnTypes.assign({extentTensorTy}); 1645 } 1646 return success(); 1647 } 1648 1649 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1650 if (l.size() != 1 || r.size() != 1) 1651 return false; 1652 if (l == r) 1653 return true; 1654 1655 Type lhs = l.front(); 1656 Type rhs = r.front(); 1657 1658 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1659 return false; 1660 1661 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1662 // Shape type is compatible with all other valid return types. 1663 return true; 1664 1665 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1666 return true; 1667 return false; 1668 } 1669 1670 LogicalResult shape::ShapeOfOp::verify() { 1671 return verifyShapeOrExtentTensorOp(*this); 1672 } 1673 1674 //===----------------------------------------------------------------------===// 1675 // SizeToIndexOp 1676 //===----------------------------------------------------------------------===// 1677 1678 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1679 // Constant values of both types, `shape.size` and `index`, are represented as 1680 // `IntegerAttr`s which makes constant folding simple. 1681 if (Attribute arg = operands[0]) 1682 return arg; 1683 return OpFoldResult(); 1684 } 1685 1686 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1687 MLIRContext *context) { 1688 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1689 } 1690 1691 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1692 if (inputs.size() != 1 || outputs.size() != 1) 1693 return false; 1694 return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>(); 1695 } 1696 1697 //===----------------------------------------------------------------------===// 1698 // YieldOp 1699 //===----------------------------------------------------------------------===// 1700 1701 LogicalResult shape::YieldOp::verify() { 1702 auto *parentOp = (*this)->getParentOp(); 1703 auto results = parentOp->getResults(); 1704 auto operands = getOperands(); 1705 1706 if (parentOp->getNumResults() != getNumOperands()) 1707 return emitOpError() << "number of operands does not match number of " 1708 "results of its parent"; 1709 for (auto e : llvm::zip(results, operands)) 1710 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1711 return emitOpError() << "types mismatch between yield op and its parent"; 1712 1713 return success(); 1714 } 1715 1716 //===----------------------------------------------------------------------===// 1717 // SplitAtOp 1718 //===----------------------------------------------------------------------===// 1719 1720 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1721 SmallVectorImpl<OpFoldResult> &results) { 1722 if (!operands[0] || !operands[1]) 1723 return failure(); 1724 auto shapeVec = llvm::to_vector<6>( 1725 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1726 auto shape = llvm::makeArrayRef(shapeVec); 1727 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1728 // Verify that the split point is in the correct range. 1729 // TODO: Constant fold to an "error". 1730 int64_t rank = shape.size(); 1731 if (!(-rank <= splitPoint && splitPoint <= rank)) 1732 return failure(); 1733 if (splitPoint < 0) 1734 splitPoint += shape.size(); 1735 Builder builder(operands[0].getContext()); 1736 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1737 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1738 return success(); 1739 } 1740 1741 //===----------------------------------------------------------------------===// 1742 // ToExtentTensorOp 1743 //===----------------------------------------------------------------------===// 1744 1745 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1746 if (!operands[0]) 1747 return OpFoldResult(); 1748 Builder builder(getContext()); 1749 auto shape = llvm::to_vector<6>( 1750 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1751 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1752 builder.getIndexType()); 1753 return DenseIntElementsAttr::get(type, shape); 1754 } 1755 1756 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1757 if (inputs.size() != 1 || outputs.size() != 1) 1758 return false; 1759 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1760 if (!inputTensor.getElementType().isa<IndexType>() || 1761 inputTensor.getRank() != 1) 1762 return false; 1763 } else if (!inputs[0].isa<ShapeType>()) { 1764 return false; 1765 } 1766 1767 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1768 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1769 } 1770 1771 //===----------------------------------------------------------------------===// 1772 // ReduceOp 1773 //===----------------------------------------------------------------------===// 1774 1775 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1776 ValueRange initVals) { 1777 result.addOperands(shape); 1778 result.addOperands(initVals); 1779 1780 Region *bodyRegion = result.addRegion(); 1781 bodyRegion->push_back(new Block); 1782 Block &bodyBlock = bodyRegion->front(); 1783 bodyBlock.addArgument(builder.getIndexType(), result.location); 1784 1785 Type elementType; 1786 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1787 elementType = tensorType.getElementType(); 1788 else 1789 elementType = SizeType::get(builder.getContext()); 1790 bodyBlock.addArgument(elementType, shape.getLoc()); 1791 1792 for (Value initVal : initVals) { 1793 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1794 result.addTypes(initVal.getType()); 1795 } 1796 } 1797 1798 LogicalResult ReduceOp::verify() { 1799 // Verify block arg types. 1800 Block &block = getRegion().front(); 1801 1802 // The block takes index, extent, and aggregated values as arguments. 1803 auto blockArgsCount = getInitVals().size() + 2; 1804 if (block.getNumArguments() != blockArgsCount) 1805 return emitOpError() << "ReduceOp body is expected to have " 1806 << blockArgsCount << " arguments"; 1807 1808 // The first block argument is the index and must always be of type `index`. 1809 if (!block.getArgument(0).getType().isa<IndexType>()) 1810 return emitOpError( 1811 "argument 0 of ReduceOp body is expected to be of IndexType"); 1812 1813 // The second block argument is the extent and must be of type `size` or 1814 // `index`, depending on whether the reduce operation is applied to a shape or 1815 // to an extent tensor. 1816 Type extentTy = block.getArgument(1).getType(); 1817 if (getShape().getType().isa<ShapeType>()) { 1818 if (!extentTy.isa<SizeType>()) 1819 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1820 "SizeType if the ReduceOp operates on a ShapeType"); 1821 } else { 1822 if (!extentTy.isa<IndexType>()) 1823 return emitOpError( 1824 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1825 "ReduceOp operates on an extent tensor"); 1826 } 1827 1828 for (const auto &type : llvm::enumerate(getInitVals())) 1829 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1830 return emitOpError() << "type mismatch between argument " 1831 << type.index() + 2 1832 << " of ReduceOp body and initial value " 1833 << type.index(); 1834 return success(); 1835 } 1836 1837 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1838 // Parse operands. 1839 SmallVector<OpAsmParser::OperandType, 3> operands; 1840 Type shapeOrExtentTensorType; 1841 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1842 OpAsmParser::Delimiter::Paren) || 1843 parser.parseColonType(shapeOrExtentTensorType) || 1844 parser.parseOptionalArrowTypeList(result.types)) 1845 return failure(); 1846 1847 // Resolve operands. 1848 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1849 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1850 result.operands) || 1851 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1852 result.operands)) 1853 return failure(); 1854 1855 // Parse the body. 1856 Region *body = result.addRegion(); 1857 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1858 return failure(); 1859 1860 // Parse attributes. 1861 if (parser.parseOptionalAttrDict(result.attributes)) 1862 return failure(); 1863 1864 return success(); 1865 } 1866 1867 void ReduceOp::print(OpAsmPrinter &p) { 1868 p << '(' << getShape() << ", " << getInitVals() 1869 << ") : " << getShape().getType(); 1870 p.printOptionalArrowTypeList(getResultTypes()); 1871 p << ' '; 1872 p.printRegion(getRegion()); 1873 p.printOptionalAttrDict((*this)->getAttrs()); 1874 } 1875 1876 #define GET_OP_CLASSES 1877 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1878