1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/SetOperations.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::shape; 31 32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 33 34 namespace { 35 #include "ShapeCanonicalization.inc" 36 } // namespace 37 38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 39 return RankedTensorType::get({rank}, IndexType::get(ctx)); 40 } 41 42 bool shape::isExtentTensorType(Type type) { 43 auto ranked = type.dyn_cast<RankedTensorType>(); 44 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 45 } 46 47 LogicalResult shape::getShapeVec(Value input, 48 SmallVectorImpl<int64_t> &shapeValues) { 49 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 50 auto type = inputOp.getArg().getType().cast<ShapedType>(); 51 if (!type.hasRank()) 52 return failure(); 53 llvm::append_range(shapeValues, type.getShape()); 54 return success(); 55 } 56 DenseIntElementsAttr attr; 57 if (matchPattern(input, m_Constant(&attr))) { 58 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 59 return success(); 60 } 61 return failure(); 62 } 63 64 static bool isErrorPropagationPossible(TypeRange operandTypes) { 65 return llvm::any_of(operandTypes, [](Type ty) { 66 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 67 }); 68 } 69 70 static LogicalResult verifySizeOrIndexOp(Operation *op) { 71 assert(op != nullptr && op->getNumResults() == 1); 72 Type resultTy = op->getResultTypes().front(); 73 if (isErrorPropagationPossible(op->getOperandTypes())) { 74 if (!resultTy.isa<SizeType>()) 75 return op->emitOpError() 76 << "if at least one of the operands can hold error values then " 77 "the result must be of type `size` to propagate them"; 78 } 79 return success(); 80 } 81 82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 83 assert(op != nullptr && op->getNumResults() == 1); 84 Type resultTy = op->getResultTypes().front(); 85 if (isErrorPropagationPossible(op->getOperandTypes())) { 86 if (!resultTy.isa<ShapeType>()) 87 return op->emitOpError() 88 << "if at least one of the operands can hold error values then " 89 "the result must be of type `shape` to propagate them"; 90 } 91 return success(); 92 } 93 94 template <typename... Ty> 95 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 96 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 97 } 98 99 template <typename... Ty, typename... ranges> 100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 101 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // InlinerInterface 106 //===----------------------------------------------------------------------===// 107 108 namespace { 109 /// This class defines the interface for inlining shape dialect ops. 110 struct ShapeInlinerInterface : public DialectInlinerInterface { 111 using DialectInlinerInterface::DialectInlinerInterface; 112 113 // Returns true if the given region 'src' can be inlined into the region 114 // 'dest' that is attached to an operation registered to the current dialect. 115 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 116 BlockAndValueMapping &) const final { 117 return true; 118 } 119 120 // Returns true if the given operation 'op', that is registered to this 121 // dialect, can be inlined into the region 'dest' that is attached to an 122 // operation registered to the current dialect. 123 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 124 BlockAndValueMapping &) const final { 125 return true; 126 } 127 }; 128 } // namespace 129 130 void ShapeDialect::initialize() { 131 addOperations< 132 #define GET_OP_LIST 133 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 134 >(); 135 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 136 addInterfaces<ShapeInlinerInterface>(); 137 // Allow unknown operations during prototyping and testing. As the dialect is 138 // still evolving it makes it simple to start with an unregistered ops and 139 // try different variants before actually defining the op. 140 allowUnknownOperations(); 141 } 142 143 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 144 Attribute value, Type type, 145 Location loc) { 146 if (type.isa<ShapeType>() || isExtentTensorType(type)) 147 return builder.create<ConstShapeOp>(loc, type, 148 value.cast<DenseIntElementsAttr>()); 149 if (type.isa<SizeType>()) 150 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 151 if (type.isa<WitnessType>()) 152 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 153 if (arith::ConstantOp::isBuildableWith(value, type)) 154 return builder.create<arith::ConstantOp>(loc, type, value); 155 return nullptr; 156 } 157 158 /// Parse a type registered to this dialect. 159 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 160 StringRef keyword; 161 if (parser.parseKeyword(&keyword)) 162 return Type(); 163 164 if (keyword == "shape") 165 return ShapeType::get(getContext()); 166 if (keyword == "size") 167 return SizeType::get(getContext()); 168 if (keyword == "value_shape") 169 return ValueShapeType::get(getContext()); 170 if (keyword == "witness") 171 return WitnessType::get(getContext()); 172 173 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 174 return Type(); 175 } 176 177 /// Print a type registered to this dialect. 178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 179 TypeSwitch<Type>(type) 180 .Case<ShapeType>([&](Type) { os << "shape"; }) 181 .Case<SizeType>([&](Type) { os << "size"; }) 182 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 183 .Case<WitnessType>([&](Type) { os << "witness"; }) 184 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 185 } 186 187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 188 NamedAttribute attribute) { 189 // Verify shape.lib attribute. 190 if (attribute.getName() == "shape.lib") { 191 if (!op->hasTrait<OpTrait::SymbolTable>()) 192 return op->emitError( 193 "shape.lib attribute may only be on op implementing SymbolTable"); 194 195 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 196 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 197 if (!symbol) 198 return op->emitError("shape function library ") 199 << symbolRef << " not found"; 200 return isa<shape::FunctionLibraryOp>(symbol) 201 ? success() 202 : op->emitError() 203 << symbolRef << " required to be shape function library"; 204 } 205 206 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 207 // Verify all entries are function libraries and mappings in libraries 208 // refer to unique ops. 209 DenseSet<StringAttr> key; 210 for (auto it : arr) { 211 if (!it.isa<SymbolRefAttr>()) 212 return op->emitError( 213 "only SymbolRefAttr allowed in shape.lib attribute array"); 214 215 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 216 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 217 if (!shapeFnLib) 218 return op->emitError() 219 << it << " does not refer to FunctionLibraryOp"; 220 for (auto mapping : shapeFnLib.getMapping()) { 221 if (!key.insert(mapping.getName()).second) { 222 return op->emitError("only one op to shape mapping allowed, found " 223 "multiple for `") 224 << mapping.getName() << "`"; 225 } 226 } 227 } 228 return success(); 229 } 230 231 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 232 "allowed as shape.lib attribute"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AnyOp 239 //===----------------------------------------------------------------------===// 240 241 // TODO: Canonicalization should be implemented for shapes that can be 242 // determined through mixtures of the known dimensions of the inputs. 243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 244 // Only the last operand is checked because AnyOp is commutative. 245 if (operands.back()) 246 return operands.back(); 247 248 return nullptr; 249 } 250 251 //===----------------------------------------------------------------------===// 252 // AssumingOp 253 //===----------------------------------------------------------------------===// 254 255 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 256 result.regions.reserve(1); 257 Region *doRegion = result.addRegion(); 258 259 auto &builder = parser.getBuilder(); 260 OpAsmParser::OperandType cond; 261 if (parser.parseOperand(cond) || 262 parser.resolveOperand(cond, builder.getType<WitnessType>(), 263 result.operands)) 264 return failure(); 265 266 // Parse optional results type list. 267 if (parser.parseOptionalArrowTypeList(result.types)) 268 return failure(); 269 270 // Parse the region and add a terminator if elided. 271 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 272 return failure(); 273 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 274 275 // Parse the optional attribute list. 276 if (parser.parseOptionalAttrDict(result.attributes)) 277 return failure(); 278 return success(); 279 } 280 281 void AssumingOp::print(OpAsmPrinter &p) { 282 bool yieldsResults = !getResults().empty(); 283 284 p << " " << getWitness(); 285 if (yieldsResults) 286 p << " -> (" << getResultTypes() << ")"; 287 p << ' '; 288 p.printRegion(getDoRegion(), 289 /*printEntryBlockArgs=*/false, 290 /*printBlockTerminators=*/yieldsResults); 291 p.printOptionalAttrDict((*this)->getAttrs()); 292 } 293 294 namespace { 295 // Removes AssumingOp with a passing witness and inlines the region. 296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 297 using OpRewritePattern<AssumingOp>::OpRewritePattern; 298 299 LogicalResult matchAndRewrite(AssumingOp op, 300 PatternRewriter &rewriter) const override { 301 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 302 if (!witness || !witness.getPassingAttr()) 303 return failure(); 304 305 AssumingOp::inlineRegionIntoParent(op, rewriter); 306 return success(); 307 } 308 }; 309 310 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 311 using OpRewritePattern<AssumingOp>::OpRewritePattern; 312 313 LogicalResult matchAndRewrite(AssumingOp op, 314 PatternRewriter &rewriter) const override { 315 Block *body = op.getBody(); 316 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 317 318 // Find used values. 319 SmallVector<Value, 4> newYieldOperands; 320 Value opResult, yieldOperand; 321 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 322 std::tie(opResult, yieldOperand) = it; 323 if (!opResult.getUses().empty()) { 324 newYieldOperands.push_back(yieldOperand); 325 } 326 } 327 328 // Rewrite only if redundant results exist. 329 if (newYieldOperands.size() == yieldOp->getNumOperands()) 330 return failure(); 331 332 // Replace yield op in the old assuming op's body and move the entire region 333 // to the new assuming op. 334 rewriter.setInsertionPointToEnd(body); 335 auto newYieldOp = 336 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 337 rewriter.setInsertionPoint(op); 338 auto newOp = rewriter.create<AssumingOp>( 339 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 340 newOp.getDoRegion().takeBody(op.getDoRegion()); 341 342 // Use the new results to replace the previously used ones. 343 SmallVector<Value, 4> replacementValues; 344 auto src = newOp.getResults().begin(); 345 for (auto it : op.getResults()) { 346 if (it.getUses().empty()) 347 replacementValues.push_back(nullptr); 348 else 349 replacementValues.push_back(*src++); 350 } 351 rewriter.replaceOp(op, replacementValues); 352 return success(); 353 } 354 }; 355 } // namespace 356 357 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 358 MLIRContext *context) { 359 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 360 } 361 362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 363 void AssumingOp::getSuccessorRegions( 364 Optional<unsigned> index, ArrayRef<Attribute> operands, 365 SmallVectorImpl<RegionSuccessor> ®ions) { 366 // AssumingOp has unconditional control flow into the region and back to the 367 // parent, so return the correct RegionSuccessor purely based on the index 368 // being None or 0. 369 if (index.hasValue()) { 370 regions.push_back(RegionSuccessor(getResults())); 371 return; 372 } 373 374 regions.push_back(RegionSuccessor(&getDoRegion())); 375 } 376 377 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 378 PatternRewriter &rewriter) { 379 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 380 auto *assumingBlock = op.getBody(); 381 auto initPosition = rewriter.getInsertionPoint(); 382 auto *blockAfterAssuming = 383 rewriter.splitBlock(blockBeforeAssuming, initPosition); 384 385 // Remove the AssumingOp and AssumingYieldOp. 386 auto &yieldOp = assumingBlock->back(); 387 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 388 rewriter.replaceOp(op, yieldOp.getOperands()); 389 rewriter.eraseOp(&yieldOp); 390 391 // Merge blocks together as there was no branching behavior from the 392 // AssumingOp. 393 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 394 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 395 } 396 397 void AssumingOp::build( 398 OpBuilder &builder, OperationState &result, Value witness, 399 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 400 401 result.addOperands(witness); 402 Region *bodyRegion = result.addRegion(); 403 bodyRegion->push_back(new Block); 404 Block &bodyBlock = bodyRegion->front(); 405 406 // Build body. 407 OpBuilder::InsertionGuard guard(builder); 408 builder.setInsertionPointToStart(&bodyBlock); 409 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 410 builder.create<AssumingYieldOp>(result.location, yieldValues); 411 412 SmallVector<Type, 2> assumingTypes; 413 for (Value v : yieldValues) 414 assumingTypes.push_back(v.getType()); 415 result.addTypes(assumingTypes); 416 } 417 418 //===----------------------------------------------------------------------===// 419 // AddOp 420 //===----------------------------------------------------------------------===// 421 422 LogicalResult mlir::shape::AddOp::inferReturnTypes( 423 MLIRContext *context, Optional<Location> location, ValueRange operands, 424 DictionaryAttr attributes, RegionRange regions, 425 SmallVectorImpl<Type> &inferredReturnTypes) { 426 if (operands[0].getType().isa<SizeType>() || 427 operands[1].getType().isa<SizeType>()) 428 inferredReturnTypes.assign({SizeType::get(context)}); 429 else 430 inferredReturnTypes.assign({IndexType::get(context)}); 431 return success(); 432 } 433 434 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 435 // SizeType is compatible with IndexType. 436 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 437 } 438 439 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 440 // add(x, 0) -> x 441 if (matchPattern(getRhs(), m_Zero())) 442 return getLhs(); 443 444 return constFoldBinaryOp<IntegerAttr>( 445 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 446 } 447 448 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 449 450 //===----------------------------------------------------------------------===// 451 // AssumingAllOp 452 //===----------------------------------------------------------------------===// 453 454 namespace { 455 456 // Merge multiple `shape.assuming_all` operations together. 457 // 458 // %0 = shape.assuming_all %w0, %w1 459 // %1 = shape.assuming_all %w2, %0 460 // 461 // to: 462 // 463 // %0 = shape.assuming_all %w0, %w2, %w2 464 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 465 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 466 467 LogicalResult matchAndRewrite(AssumingAllOp op, 468 PatternRewriter &rewriter) const override { 469 SmallVector<Value> operands; 470 471 for (Value operand : op.getInputs()) { 472 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>()) 473 operands.append(assumeAll.operand_begin(), assumeAll->operand_end()); 474 else 475 operands.push_back(operand); 476 } 477 478 // We didn't find any other `assuming_all` ops to merge with. 479 if (operands.size() == op.getNumOperands()) 480 return failure(); 481 482 // Replace with a new `assuming_all` operation with merged constraints. 483 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 484 return success(); 485 } 486 }; 487 488 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 489 // are subsumed by others. 490 // 491 // %0 = shape.cstr_broadcastable %shape0, %shape1 492 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 493 // 494 // %2 = shape.cstr_broadcastable %shape3, %shape4 495 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 496 // 497 // %4 = shape.assuming_all %0, %1, %2, %3 498 // 499 // to: 500 // 501 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 502 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 503 // %2 = shape.assuming_all %0, %1 504 // 505 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 506 // shapes [0, 1] are broadcastable too, and can be removed from the list of 507 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 508 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 509 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 510 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 511 512 LogicalResult matchAndRewrite(AssumingAllOp op, 513 PatternRewriter &rewriter) const override { 514 // Collect all `CstrBroadcastableOp` operands first. 515 SetVector<CstrBroadcastableOp> operands; 516 for (Value operand : op.getInputs()) { 517 // TODO: Apply this optimization if some of the witnesses are not 518 // produced by the `cstr_broadcastable`. 519 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 520 if (!broadcastable) 521 return failure(); 522 523 operands.insert(broadcastable); 524 } 525 526 // Skip trivial `assuming_all` operations. 527 if (operands.size() <= 1) 528 return failure(); 529 530 // Collect shapes checked by `cstr_broadcastable` operands. 531 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 532 for (auto cstr : operands) { 533 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end()); 534 shapes.emplace_back(cstr, std::move(shapesSet)); 535 } 536 537 // Sort by the number of shape operands (larger to smaller). 538 llvm::sort(shapes, [](auto a, auto b) { 539 return a.first.getNumOperands() > b.first.getNumOperands(); 540 }); 541 542 // We start from the `cst_broadcastable` operations with largest number of 543 // shape operands, and remove redundant `cst_broadcastable` operations. We 544 // do this until we find a set of `cst_broadcastable` operations with 545 // non-overlapping constraints. 546 SmallVector<CstrBroadcastableOp> markedForErase; 547 548 for (unsigned i = 0; i < shapes.size(); ++i) { 549 auto isSubset = [&](auto pair) { 550 return llvm::set_is_subset(pair.second, shapes[i].second); 551 }; 552 553 // Keep redundant `cstr_broadcastable` operations to be erased. 554 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 555 for (auto *it0 = it; it0 < shapes.end(); ++it0) 556 markedForErase.push_back(it0->first); 557 shapes.erase(it, shapes.end()); 558 } 559 560 // We didn't find any operands that could be removed. 561 if (markedForErase.empty()) 562 return failure(); 563 564 // Collect non-overlapping `cst_broadcastable` constraints. 565 SmallVector<Value> uniqueConstraints; 566 for (auto &shape : shapes) 567 uniqueConstraints.push_back(shape.first.getResult()); 568 569 // Replace with a new `assuming_all` operation ... 570 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints); 571 572 // ... and maybe erase `cstr_broadcastable` ops without uses. 573 for (auto &op : markedForErase) 574 if (op->use_empty()) 575 rewriter.eraseOp(op); 576 577 return success(); 578 } 579 }; 580 581 struct AssumingAllToCstrEqCanonicalization 582 : public OpRewritePattern<AssumingAllOp> { 583 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 584 585 LogicalResult matchAndRewrite(AssumingAllOp op, 586 PatternRewriter &rewriter) const override { 587 SmallVector<Value, 8> shapes; 588 for (Value w : op.getInputs()) { 589 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 590 if (!cstrEqOp) 591 return failure(); 592 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 593 return llvm::is_contained(shapes, s); 594 }); 595 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 596 return failure(); 597 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 598 } 599 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 600 return success(); 601 } 602 }; 603 604 template <typename OpTy> 605 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 606 using OpRewritePattern<OpTy>::OpRewritePattern; 607 608 LogicalResult matchAndRewrite(OpTy op, 609 PatternRewriter &rewriter) const override { 610 // Find unique operands. 611 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 612 613 // Reduce op to equivalent with unique operands. 614 if (unique.size() < op.getNumOperands()) { 615 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 616 unique.takeVector(), op->getAttrs()); 617 return success(); 618 } 619 620 return failure(); 621 } 622 }; 623 } // namespace 624 625 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 626 MLIRContext *context) { 627 patterns 628 .add<MergeAssumingAllOps, AssumingAllOneOp, 629 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 630 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 631 } 632 633 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 634 // Iterate in reverse to first handle all constant operands. They are 635 // guaranteed to be the tail of the inputs because this is commutative. 636 for (int idx = operands.size() - 1; idx >= 0; idx--) { 637 Attribute a = operands[idx]; 638 // Cannot fold if any inputs are not constant; 639 if (!a) 640 return nullptr; 641 642 // We do not need to keep statically known values after handling them in 643 // this method. 644 getOperation()->eraseOperand(idx); 645 646 // Always false if any input is statically known false 647 if (!a.cast<BoolAttr>().getValue()) 648 return a; 649 } 650 // If this is reached, all inputs were statically known passing. 651 return BoolAttr::get(getContext(), true); 652 } 653 654 LogicalResult AssumingAllOp::verify() { 655 // Ensure that AssumingAllOp contains at least one operand 656 if (getNumOperands() == 0) 657 return emitOpError("no operands specified"); 658 659 return success(); 660 } 661 662 //===----------------------------------------------------------------------===// 663 // BroadcastOp 664 //===----------------------------------------------------------------------===// 665 666 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 667 if (getShapes().size() == 1) { 668 // Otherwise, we need a cast which would be a canonicalization, not folding. 669 if (getShapes().front().getType() != getType()) 670 return nullptr; 671 return getShapes().front(); 672 } 673 674 // TODO: Support folding with more than 2 input shapes 675 if (getShapes().size() > 2) 676 return nullptr; 677 678 if (!operands[0] || !operands[1]) 679 return nullptr; 680 auto lhsShape = llvm::to_vector<6>( 681 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 682 auto rhsShape = llvm::to_vector<6>( 683 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 684 SmallVector<int64_t, 6> resultShape; 685 686 // If the shapes are not compatible, we can't fold it. 687 // TODO: Fold to an "error". 688 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 689 return nullptr; 690 691 Builder builder(getContext()); 692 return builder.getIndexTensorAttr(resultShape); 693 } 694 695 LogicalResult BroadcastOp::verify() { 696 return verifyShapeOrExtentTensorOp(*this); 697 } 698 699 namespace { 700 template <typename OpTy> 701 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 702 using OpRewritePattern<OpTy>::OpRewritePattern; 703 704 LogicalResult matchAndRewrite(OpTy op, 705 PatternRewriter &rewriter) const override { 706 auto isPotentiallyNonEmptyShape = [](Value shape) { 707 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 708 if (extentTensorTy.getDimSize(0) == 0) 709 return false; 710 } 711 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 712 if (constShape.getShape().empty()) 713 return false; 714 } 715 return true; 716 }; 717 auto newOperands = llvm::to_vector<8>( 718 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 719 720 // Reduce op to equivalent without empty shape operands. 721 if (newOperands.size() < op.getNumOperands()) { 722 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 723 op->getAttrs()); 724 return success(); 725 } 726 727 return failure(); 728 } 729 }; 730 731 struct BroadcastForwardSingleOperandPattern 732 : public OpRewritePattern<BroadcastOp> { 733 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 734 735 LogicalResult matchAndRewrite(BroadcastOp op, 736 PatternRewriter &rewriter) const override { 737 if (op.getNumOperands() != 1) 738 return failure(); 739 Value replacement = op.getShapes().front(); 740 741 // Insert cast if needed. 742 if (replacement.getType() != op.getType()) { 743 auto loc = op.getLoc(); 744 if (op.getType().isa<ShapeType>()) { 745 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 746 } else { 747 assert(!op.getType().isa<ShapeType>() && 748 !replacement.getType().isa<ShapeType>() && 749 "expect extent tensor cast"); 750 replacement = 751 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 752 } 753 } 754 755 rewriter.replaceOp(op, replacement); 756 return success(); 757 } 758 }; 759 760 struct BroadcastFoldConstantOperandsPattern 761 : public OpRewritePattern<BroadcastOp> { 762 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 763 764 LogicalResult matchAndRewrite(BroadcastOp op, 765 PatternRewriter &rewriter) const override { 766 SmallVector<int64_t, 8> foldedConstantShape; 767 SmallVector<Value, 8> newShapeOperands; 768 for (Value shape : op.getShapes()) { 769 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 770 SmallVector<int64_t, 8> newFoldedConstantShape; 771 if (OpTrait::util::getBroadcastedShape( 772 foldedConstantShape, 773 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 774 newFoldedConstantShape)) { 775 foldedConstantShape = newFoldedConstantShape; 776 continue; 777 } 778 } 779 newShapeOperands.push_back(shape); 780 } 781 782 // Need at least two constant operands to fold anything. 783 if (op.getNumOperands() - newShapeOperands.size() < 2) 784 return failure(); 785 786 auto foldedConstantOperandsTy = RankedTensorType::get( 787 {static_cast<int64_t>(foldedConstantShape.size())}, 788 rewriter.getIndexType()); 789 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 790 op.getLoc(), foldedConstantOperandsTy, 791 rewriter.getIndexTensorAttr(foldedConstantShape))); 792 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 793 newShapeOperands); 794 return success(); 795 } 796 }; 797 798 template <typename OpTy> 799 struct CanonicalizeCastExtentTensorOperandsPattern 800 : public OpRewritePattern<OpTy> { 801 using OpRewritePattern<OpTy>::OpRewritePattern; 802 803 LogicalResult matchAndRewrite(OpTy op, 804 PatternRewriter &rewriter) const override { 805 // Canonicalize operands. 806 bool anyChange = false; 807 auto canonicalizeOperand = [&](Value operand) { 808 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 809 // Only eliminate the cast if it holds no shape information. 810 bool isInformationLoosingCast = 811 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 812 if (isInformationLoosingCast) { 813 anyChange = true; 814 return castOp.source(); 815 } 816 } 817 return operand; 818 }; 819 auto newOperands = llvm::to_vector<8>( 820 llvm::map_range(op.getOperands(), canonicalizeOperand)); 821 822 // Rewrite op if any change required. 823 if (!anyChange) 824 return failure(); 825 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 826 return success(); 827 } 828 }; 829 830 struct BroadcastConcretizeResultTypePattern 831 : public OpRewritePattern<BroadcastOp> { 832 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 833 834 LogicalResult matchAndRewrite(BroadcastOp op, 835 PatternRewriter &rewriter) const override { 836 // Only concretize dynamic extent tensor result types. 837 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 838 if (!resultTy || !resultTy.isDynamicDim(0)) 839 return failure(); 840 841 // Infer resulting shape rank if possible. 842 int64_t maxRank = 0; 843 for (Value shape : op.getShapes()) { 844 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 845 // Cannot infer resulting shape rank if any operand is dynamically 846 // ranked. 847 if (extentTensorTy.isDynamicDim(0)) 848 return failure(); 849 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 850 } 851 } 852 853 auto newOp = rewriter.create<BroadcastOp>( 854 op.getLoc(), getExtentTensorType(getContext(), maxRank), 855 op.getShapes()); 856 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 857 return success(); 858 } 859 }; 860 } // namespace 861 862 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 863 MLIRContext *context) { 864 patterns.add<BroadcastConcretizeResultTypePattern, 865 BroadcastFoldConstantOperandsPattern, 866 BroadcastForwardSingleOperandPattern, 867 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 868 RemoveDuplicateOperandsPattern<BroadcastOp>, 869 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 870 } 871 872 //===----------------------------------------------------------------------===// 873 // ConcatOp 874 //===----------------------------------------------------------------------===// 875 876 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 877 if (!operands[0] || !operands[1]) 878 return nullptr; 879 auto lhsShape = llvm::to_vector<6>( 880 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 881 auto rhsShape = llvm::to_vector<6>( 882 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 883 SmallVector<int64_t, 6> resultShape; 884 resultShape.append(lhsShape.begin(), lhsShape.end()); 885 resultShape.append(rhsShape.begin(), rhsShape.end()); 886 Builder builder(getContext()); 887 return builder.getIndexTensorAttr(resultShape); 888 } 889 890 //===----------------------------------------------------------------------===// 891 // ConstShapeOp 892 //===----------------------------------------------------------------------===// 893 894 void ConstShapeOp::print(OpAsmPrinter &p) { 895 p << " "; 896 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 897 p << "["; 898 interleaveComma(getShape().getValues<int64_t>(), p); 899 p << "] : "; 900 p.printType(getType()); 901 } 902 903 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 904 if (parser.parseOptionalAttrDict(result.attributes)) 905 return failure(); 906 // We piggy-back on ArrayAttr parsing, though we don't internally store the 907 // shape as an ArrayAttr. 908 // TODO: Implement custom parser and maybe make syntax a bit more concise. 909 Attribute extentsRaw; 910 NamedAttrList dummy; 911 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 912 return failure(); 913 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 914 if (!extentsArray) 915 return failure(); 916 SmallVector<int64_t, 6> ints; 917 for (Attribute extent : extentsArray) { 918 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 919 if (!attr) 920 return failure(); 921 ints.push_back(attr.getInt()); 922 } 923 Builder &builder = parser.getBuilder(); 924 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 925 Type resultTy; 926 if (parser.parseColonType(resultTy)) 927 return failure(); 928 result.types.push_back(resultTy); 929 return success(); 930 } 931 932 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 933 934 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 935 MLIRContext *context) { 936 patterns.add<TensorCastConstShape>(context); 937 } 938 939 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 940 MLIRContext *context, Optional<Location> location, ValueRange operands, 941 DictionaryAttr attributes, RegionRange regions, 942 SmallVectorImpl<Type> &inferredReturnTypes) { 943 Builder b(context); 944 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 945 if (!shape) 946 return emitOptionalError(location, "missing shape attribute"); 947 inferredReturnTypes.assign({RankedTensorType::get( 948 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 949 return success(); 950 } 951 952 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 953 TypeRange r) { 954 if (l.size() != 1 || r.size() != 1) 955 return false; 956 957 Type lhs = l.front(); 958 Type rhs = r.front(); 959 960 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 961 // Shape type is compatible with all other valid return types. 962 return true; 963 return lhs == rhs; 964 } 965 966 //===----------------------------------------------------------------------===// 967 // CstrBroadcastableOp 968 //===----------------------------------------------------------------------===// 969 970 void CstrBroadcastableOp::getCanonicalizationPatterns( 971 RewritePatternSet &patterns, MLIRContext *context) { 972 // Canonicalization patterns have overlap with the considerations during 973 // folding in case additional shape information is inferred at some point that 974 // does not result in folding. 975 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 976 CstrBroadcastableEqOps, 977 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 978 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 979 } 980 981 // Return true if there is exactly one attribute not representing a scalar 982 // broadcast. 983 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 984 bool nonScalarSeen = false; 985 for (Attribute a : attributes) { 986 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 987 if (nonScalarSeen) 988 return false; 989 nonScalarSeen = true; 990 } 991 } 992 return true; 993 } 994 995 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 996 // No broadcasting is needed if all operands but one are scalar. 997 if (hasAtMostSingleNonScalar(operands)) 998 return BoolAttr::get(getContext(), true); 999 1000 if ([&] { 1001 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1002 for (const auto &operand : operands) { 1003 if (!operand) 1004 return false; 1005 extents.push_back(llvm::to_vector<6>( 1006 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 1007 } 1008 return OpTrait::util::staticallyKnownBroadcastable(extents); 1009 }()) 1010 return BoolAttr::get(getContext(), true); 1011 1012 // Lastly, see if folding can be completed based on what constraints are known 1013 // on the input shapes. 1014 if ([&] { 1015 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1016 for (auto shapeValue : getShapes()) { 1017 extents.emplace_back(); 1018 if (failed(getShapeVec(shapeValue, extents.back()))) 1019 return false; 1020 } 1021 return OpTrait::util::staticallyKnownBroadcastable(extents); 1022 }()) 1023 return BoolAttr::get(getContext(), true); 1024 1025 // Because a failing witness result here represents an eventual assertion 1026 // failure, we do not replace it with a constant witness. 1027 return nullptr; 1028 } 1029 1030 LogicalResult CstrBroadcastableOp::verify() { 1031 // Ensure that CstrBroadcastableOp contains at least two operands 1032 if (getNumOperands() < 2) 1033 return emitOpError("required at least 2 input shapes"); 1034 return success(); 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // CstrEqOp 1039 //===----------------------------------------------------------------------===// 1040 1041 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1042 MLIRContext *context) { 1043 // If inputs are equal, return passing witness 1044 patterns.add<CstrEqEqOps>(context); 1045 } 1046 1047 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1048 if (llvm::all_of(operands, 1049 [&](Attribute a) { return a && a == operands[0]; })) 1050 return BoolAttr::get(getContext(), true); 1051 1052 // Because a failing witness result here represents an eventual assertion 1053 // failure, we do not try to replace it with a constant witness. Similarly, we 1054 // cannot if there are any non-const inputs. 1055 return nullptr; 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // ConstSizeOp 1060 //===----------------------------------------------------------------------===// 1061 1062 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1063 int64_t value) { 1064 build(builder, result, builder.getIndexAttr(value)); 1065 } 1066 1067 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1068 1069 void ConstSizeOp::getAsmResultNames( 1070 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1071 SmallString<4> buffer; 1072 llvm::raw_svector_ostream os(buffer); 1073 os << "c" << getValue(); 1074 setNameFn(getResult(), os.str()); 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // ConstWitnessOp 1079 //===----------------------------------------------------------------------===// 1080 1081 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1082 return getPassingAttr(); 1083 } 1084 1085 //===----------------------------------------------------------------------===// 1086 // CstrRequireOp 1087 //===----------------------------------------------------------------------===// 1088 1089 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1090 return operands[0]; 1091 } 1092 1093 //===----------------------------------------------------------------------===// 1094 // DivOp 1095 //===----------------------------------------------------------------------===// 1096 1097 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1098 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1099 if (!lhs) 1100 return nullptr; 1101 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1102 if (!rhs) 1103 return nullptr; 1104 1105 // Division in APInt does not follow floor(lhs, rhs) when the result is 1106 // negative. Rather, APInt rounds toward zero. 1107 APInt quotient, remainder; 1108 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1109 if (quotient.isNegative() && !remainder.isNullValue()) { 1110 quotient -= 1; 1111 } 1112 1113 Type indexTy = IndexType::get(getContext()); 1114 return IntegerAttr::get(indexTy, quotient); 1115 } 1116 1117 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1118 MLIRContext *context, Optional<Location> location, ValueRange operands, 1119 DictionaryAttr attributes, RegionRange regions, 1120 SmallVectorImpl<Type> &inferredReturnTypes) { 1121 if (operands[0].getType().isa<SizeType>() || 1122 operands[1].getType().isa<SizeType>()) 1123 inferredReturnTypes.assign({SizeType::get(context)}); 1124 else 1125 inferredReturnTypes.assign({IndexType::get(context)}); 1126 return success(); 1127 } 1128 1129 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1130 // SizeType is compatible with IndexType. 1131 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1132 } 1133 1134 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1135 1136 //===----------------------------------------------------------------------===// 1137 // ShapeEqOp 1138 //===----------------------------------------------------------------------===// 1139 1140 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1141 bool allSame = true; 1142 if (!operands.empty() && !operands[0]) 1143 return {}; 1144 for (Attribute operand : operands.drop_front(1)) { 1145 if (!operand) 1146 return {}; 1147 allSame = allSame && operand == operands[0]; 1148 } 1149 return BoolAttr::get(getContext(), allSame); 1150 } 1151 1152 //===----------------------------------------------------------------------===// 1153 // IndexToSizeOp 1154 //===----------------------------------------------------------------------===// 1155 1156 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1157 // Constant values of both types, `shape.size` and `index`, are represented as 1158 // `IntegerAttr`s which makes constant folding simple. 1159 if (Attribute arg = operands[0]) 1160 return arg; 1161 return {}; 1162 } 1163 1164 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1165 MLIRContext *context) { 1166 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1167 } 1168 1169 //===----------------------------------------------------------------------===// 1170 // FromExtentsOp 1171 //===----------------------------------------------------------------------===// 1172 1173 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1174 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1175 return nullptr; 1176 SmallVector<int64_t, 6> extents; 1177 for (auto attr : operands) 1178 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1179 Builder builder(getContext()); 1180 return builder.getIndexTensorAttr(extents); 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // FunctionLibraryOp 1185 //===----------------------------------------------------------------------===// 1186 1187 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1188 StringRef name) { 1189 result.attributes.push_back(builder.getNamedAttr( 1190 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1191 } 1192 1193 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1194 auto attr = getMapping() 1195 .get(op->getName().getIdentifier()) 1196 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1197 if (!attr) 1198 return nullptr; 1199 return lookupSymbol<FuncOp>(attr); 1200 } 1201 1202 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1203 OperationState &result) { 1204 // Parse the op name. 1205 StringAttr nameAttr; 1206 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1207 result.attributes)) 1208 return failure(); 1209 1210 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1211 return failure(); 1212 1213 auto *bodyRegion = result.addRegion(); 1214 if (parser.parseRegion(*bodyRegion)) 1215 return failure(); 1216 1217 if (parser.parseKeyword("mapping")) 1218 return failure(); 1219 1220 DictionaryAttr mappingAttr; 1221 if (parser.parseAttribute(mappingAttr, 1222 parser.getBuilder().getType<NoneType>(), "mapping", 1223 result.attributes)) 1224 return failure(); 1225 return success(); 1226 } 1227 1228 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1229 p << ' '; 1230 p.printSymbolName(getName()); 1231 p.printOptionalAttrDictWithKeyword( 1232 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1233 p << ' '; 1234 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1235 /*printBlockTerminators=*/false); 1236 p << " mapping "; 1237 p.printAttributeWithoutType(getMappingAttr()); 1238 } 1239 1240 //===----------------------------------------------------------------------===// 1241 // GetExtentOp 1242 //===----------------------------------------------------------------------===// 1243 1244 Optional<int64_t> GetExtentOp::getConstantDim() { 1245 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1246 return constSizeOp.getValue().getLimitedValue(); 1247 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1248 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1249 return llvm::None; 1250 } 1251 1252 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1253 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1254 if (!elements) 1255 return nullptr; 1256 Optional<int64_t> dim = getConstantDim(); 1257 if (!dim.hasValue()) 1258 return nullptr; 1259 if (dim.getValue() >= elements.getNumElements()) 1260 return nullptr; 1261 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1262 } 1263 1264 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1265 int64_t dim) { 1266 auto loc = result.location; 1267 auto dimAttr = builder.getIndexAttr(dim); 1268 if (shape.getType().isa<ShapeType>()) { 1269 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1270 build(builder, result, builder.getType<SizeType>(), shape, dim); 1271 } else { 1272 Value dim = 1273 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1274 build(builder, result, builder.getIndexType(), shape, dim); 1275 } 1276 } 1277 1278 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1279 MLIRContext *context, Optional<Location> location, ValueRange operands, 1280 DictionaryAttr attributes, RegionRange regions, 1281 SmallVectorImpl<Type> &inferredReturnTypes) { 1282 inferredReturnTypes.assign({IndexType::get(context)}); 1283 return success(); 1284 } 1285 1286 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1287 TypeRange r) { 1288 // SizeType is compatible with IndexType. 1289 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1290 } 1291 1292 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1293 1294 //===----------------------------------------------------------------------===// 1295 // IsBroadcastableOp 1296 //===----------------------------------------------------------------------===// 1297 1298 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1299 MLIRContext *context) { 1300 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1301 } 1302 1303 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1304 // Can always broadcast fewer than two shapes. 1305 if (operands.size() < 2) { 1306 return BoolAttr::get(getContext(), true); 1307 } 1308 1309 return nullptr; 1310 } 1311 1312 //===----------------------------------------------------------------------===// 1313 // MeetOp 1314 //===----------------------------------------------------------------------===// 1315 1316 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1317 MLIRContext *context, Optional<Location> location, ValueRange operands, 1318 DictionaryAttr attributes, RegionRange regions, 1319 SmallVectorImpl<Type> &inferredReturnTypes) { 1320 inferredReturnTypes.assign({operands[0].getType()}); 1321 return success(); 1322 } 1323 1324 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1325 if (l.size() != 1 || r.size() != 1) 1326 return false; 1327 if (l == r) 1328 return true; 1329 1330 Type lhs = l.front(); 1331 Type rhs = r.front(); 1332 1333 if (lhs != rhs) 1334 return false; 1335 1336 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1337 return true; 1338 1339 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1340 return true; 1341 return false; 1342 } 1343 1344 //===----------------------------------------------------------------------===// 1345 // RankOp 1346 //===----------------------------------------------------------------------===// 1347 1348 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1349 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1350 if (!shape) 1351 return {}; 1352 int64_t rank = shape.getNumElements(); 1353 Builder builder(getContext()); 1354 return builder.getIndexAttr(rank); 1355 } 1356 1357 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1358 /// Constant folding fails in cases where only the rank is constant, not the 1359 /// shape itself. 1360 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1361 /// 1362 /// Example: 1363 /// 1364 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1365 /// %rank = shape.rank %shape 1366 /// 1367 /// becomes 1368 /// 1369 /// %rank = shape.const_size 3 1370 1371 namespace { 1372 struct RankShapeOfCanonicalizationPattern 1373 : public OpRewritePattern<shape::RankOp> { 1374 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1375 1376 LogicalResult matchAndRewrite(shape::RankOp op, 1377 PatternRewriter &rewriter) const override { 1378 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1379 if (!shapeOfOp) 1380 return failure(); 1381 auto rankedTensorType = 1382 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1383 if (!rankedTensorType) 1384 return failure(); 1385 int64_t rank = rankedTensorType.getRank(); 1386 if (op.getType().isa<IndexType>()) { 1387 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1388 rank); 1389 } else if (op.getType().isa<shape::SizeType>()) { 1390 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1391 } else { 1392 return failure(); 1393 } 1394 return success(); 1395 } 1396 }; 1397 } // namespace 1398 1399 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1400 MLIRContext *context) { 1401 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1402 } 1403 1404 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1405 MLIRContext *context, Optional<Location> location, ValueRange operands, 1406 DictionaryAttr attributes, RegionRange regions, 1407 SmallVectorImpl<Type> &inferredReturnTypes) { 1408 if (operands[0].getType().isa<ShapeType>()) 1409 inferredReturnTypes.assign({SizeType::get(context)}); 1410 else 1411 inferredReturnTypes.assign({IndexType::get(context)}); 1412 return success(); 1413 } 1414 1415 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1416 // SizeType is compatible with IndexType. 1417 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1418 } 1419 1420 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } 1421 1422 //===----------------------------------------------------------------------===// 1423 // NumElementsOp 1424 //===----------------------------------------------------------------------===// 1425 1426 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1427 1428 // Fold only when argument constant. 1429 Attribute shape = operands[0]; 1430 if (!shape) 1431 return {}; 1432 1433 APInt product(64, 1); 1434 for (auto value : shape.cast<DenseIntElementsAttr>()) 1435 product *= value; 1436 Builder builder(getContext()); 1437 return builder.getIndexAttr(product.getLimitedValue()); 1438 } 1439 1440 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1441 MLIRContext *context, Optional<Location> location, ValueRange operands, 1442 DictionaryAttr attributes, RegionRange regions, 1443 SmallVectorImpl<Type> &inferredReturnTypes) { 1444 if (operands[0].getType().isa<ShapeType>()) 1445 inferredReturnTypes.assign({SizeType::get(context)}); 1446 else 1447 inferredReturnTypes.assign({IndexType::get(context)}); 1448 return success(); 1449 } 1450 1451 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1452 TypeRange r) { 1453 // SizeType is compatible with IndexType. 1454 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1455 } 1456 1457 LogicalResult shape::NumElementsOp::verify() { 1458 return verifySizeOrIndexOp(*this); 1459 } 1460 1461 //===----------------------------------------------------------------------===// 1462 // MaxOp 1463 //===----------------------------------------------------------------------===// 1464 1465 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1466 // If operands are equal, just propagate one. 1467 if (getLhs() == getRhs()) 1468 return getLhs(); 1469 return nullptr; 1470 } 1471 1472 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1473 MLIRContext *context, Optional<Location> location, ValueRange operands, 1474 DictionaryAttr attributes, RegionRange regions, 1475 SmallVectorImpl<Type> &inferredReturnTypes) { 1476 if (operands[0].getType() == operands[1].getType()) 1477 inferredReturnTypes.assign({operands[0].getType()}); 1478 else 1479 inferredReturnTypes.assign({SizeType::get(context)}); 1480 return success(); 1481 } 1482 1483 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1484 if (l.size() != 1 || r.size() != 1) 1485 return false; 1486 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1487 return true; 1488 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1489 return true; 1490 return false; 1491 } 1492 1493 //===----------------------------------------------------------------------===// 1494 // MinOp 1495 //===----------------------------------------------------------------------===// 1496 1497 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1498 // If operands are equal, just propagate one. 1499 if (getLhs() == getRhs()) 1500 return getLhs(); 1501 return nullptr; 1502 } 1503 1504 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1505 MLIRContext *context, Optional<Location> location, ValueRange operands, 1506 DictionaryAttr attributes, RegionRange regions, 1507 SmallVectorImpl<Type> &inferredReturnTypes) { 1508 if (operands[0].getType() == operands[1].getType()) 1509 inferredReturnTypes.assign({operands[0].getType()}); 1510 else 1511 inferredReturnTypes.assign({SizeType::get(context)}); 1512 return success(); 1513 } 1514 1515 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1516 if (l.size() != 1 || r.size() != 1) 1517 return false; 1518 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1519 return true; 1520 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1521 return true; 1522 return false; 1523 } 1524 1525 //===----------------------------------------------------------------------===// 1526 // MulOp 1527 //===----------------------------------------------------------------------===// 1528 1529 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1530 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1531 if (!lhs) 1532 return nullptr; 1533 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1534 if (!rhs) 1535 return nullptr; 1536 APInt folded = lhs.getValue() * rhs.getValue(); 1537 Type indexTy = IndexType::get(getContext()); 1538 return IntegerAttr::get(indexTy, folded); 1539 } 1540 1541 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1542 MLIRContext *context, Optional<Location> location, ValueRange operands, 1543 DictionaryAttr attributes, RegionRange regions, 1544 SmallVectorImpl<Type> &inferredReturnTypes) { 1545 if (operands[0].getType().isa<SizeType>() || 1546 operands[1].getType().isa<SizeType>()) 1547 inferredReturnTypes.assign({SizeType::get(context)}); 1548 else 1549 inferredReturnTypes.assign({IndexType::get(context)}); 1550 return success(); 1551 } 1552 1553 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1554 // SizeType is compatible with IndexType. 1555 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1556 } 1557 1558 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } 1559 1560 //===----------------------------------------------------------------------===// 1561 // ShapeOfOp 1562 //===----------------------------------------------------------------------===// 1563 1564 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1565 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1566 if (!type || !type.hasStaticShape()) 1567 return nullptr; 1568 Builder builder(getContext()); 1569 return builder.getIndexTensorAttr(type.getShape()); 1570 } 1571 1572 namespace { 1573 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1574 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1575 1576 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1577 PatternRewriter &rewriter) const override { 1578 if (!op.getArg().getType().isa<ShapedType>()) 1579 return failure(); 1580 if (op.getType().isa<ShapedType>()) 1581 return failure(); 1582 1583 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1584 op.getArg()); 1585 return success(); 1586 } 1587 }; 1588 1589 // Canonicalize 1590 // ``` 1591 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1592 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1593 // ``` 1594 // to 1595 // ``` 1596 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1597 // ``` 1598 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1599 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1600 1601 LogicalResult matchAndRewrite(tensor::CastOp op, 1602 PatternRewriter &rewriter) const override { 1603 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1604 if (!ty || ty.getRank() != 1) 1605 return failure(); 1606 1607 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1608 if (!shapeOfOp) 1609 return failure(); 1610 1611 // Argument type must be ranked and must not conflict. 1612 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1613 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1614 return failure(); 1615 1616 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1617 return success(); 1618 } 1619 }; 1620 } // namespace 1621 1622 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1623 MLIRContext *context) { 1624 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1625 ExtractFromShapeOfExtentTensor>(context); 1626 } 1627 1628 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1629 MLIRContext *context, Optional<Location> location, ValueRange operands, 1630 DictionaryAttr attributes, RegionRange regions, 1631 SmallVectorImpl<Type> &inferredReturnTypes) { 1632 if (operands[0].getType().isa<ValueShapeType>()) 1633 inferredReturnTypes.assign({ShapeType::get(context)}); 1634 else { 1635 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1636 int64_t rank = 1637 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1638 Type indexTy = IndexType::get(context); 1639 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1640 inferredReturnTypes.assign({extentTensorTy}); 1641 } 1642 return success(); 1643 } 1644 1645 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1646 if (l.size() != 1 || r.size() != 1) 1647 return false; 1648 if (l == r) 1649 return true; 1650 1651 Type lhs = l.front(); 1652 Type rhs = r.front(); 1653 1654 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1655 return false; 1656 1657 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1658 // Shape type is compatible with all other valid return types. 1659 return true; 1660 1661 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1662 return true; 1663 return false; 1664 } 1665 1666 LogicalResult shape::ShapeOfOp::verify() { 1667 return verifyShapeOrExtentTensorOp(*this); 1668 } 1669 1670 //===----------------------------------------------------------------------===// 1671 // SizeToIndexOp 1672 //===----------------------------------------------------------------------===// 1673 1674 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1675 // Constant values of both types, `shape.size` and `index`, are represented as 1676 // `IntegerAttr`s which makes constant folding simple. 1677 if (Attribute arg = operands[0]) 1678 return arg; 1679 return OpFoldResult(); 1680 } 1681 1682 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1683 MLIRContext *context) { 1684 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1685 } 1686 1687 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1688 if (inputs.size() != 1 || outputs.size() != 1) 1689 return false; 1690 return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>(); 1691 } 1692 1693 //===----------------------------------------------------------------------===// 1694 // YieldOp 1695 //===----------------------------------------------------------------------===// 1696 1697 LogicalResult shape::YieldOp::verify() { 1698 auto *parentOp = (*this)->getParentOp(); 1699 auto results = parentOp->getResults(); 1700 auto operands = getOperands(); 1701 1702 if (parentOp->getNumResults() != getNumOperands()) 1703 return emitOpError() << "number of operands does not match number of " 1704 "results of its parent"; 1705 for (auto e : llvm::zip(results, operands)) 1706 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1707 return emitOpError() << "types mismatch between yield op and its parent"; 1708 1709 return success(); 1710 } 1711 1712 //===----------------------------------------------------------------------===// 1713 // SplitAtOp 1714 //===----------------------------------------------------------------------===// 1715 1716 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1717 SmallVectorImpl<OpFoldResult> &results) { 1718 if (!operands[0] || !operands[1]) 1719 return failure(); 1720 auto shapeVec = llvm::to_vector<6>( 1721 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1722 auto shape = llvm::makeArrayRef(shapeVec); 1723 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1724 // Verify that the split point is in the correct range. 1725 // TODO: Constant fold to an "error". 1726 int64_t rank = shape.size(); 1727 if (!(-rank <= splitPoint && splitPoint <= rank)) 1728 return failure(); 1729 if (splitPoint < 0) 1730 splitPoint += shape.size(); 1731 Builder builder(operands[0].getContext()); 1732 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1733 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1734 return success(); 1735 } 1736 1737 //===----------------------------------------------------------------------===// 1738 // ToExtentTensorOp 1739 //===----------------------------------------------------------------------===// 1740 1741 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1742 if (!operands[0]) 1743 return OpFoldResult(); 1744 Builder builder(getContext()); 1745 auto shape = llvm::to_vector<6>( 1746 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1747 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1748 builder.getIndexType()); 1749 return DenseIntElementsAttr::get(type, shape); 1750 } 1751 1752 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1753 if (inputs.size() != 1 || outputs.size() != 1) 1754 return false; 1755 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1756 if (!inputTensor.getElementType().isa<IndexType>() || 1757 inputTensor.getRank() != 1) 1758 return false; 1759 } else if (!inputs[0].isa<ShapeType>()) { 1760 return false; 1761 } 1762 1763 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1764 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1765 } 1766 1767 //===----------------------------------------------------------------------===// 1768 // ReduceOp 1769 //===----------------------------------------------------------------------===// 1770 1771 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1772 ValueRange initVals) { 1773 result.addOperands(shape); 1774 result.addOperands(initVals); 1775 1776 Region *bodyRegion = result.addRegion(); 1777 bodyRegion->push_back(new Block); 1778 Block &bodyBlock = bodyRegion->front(); 1779 bodyBlock.addArgument(builder.getIndexType(), result.location); 1780 1781 Type elementType; 1782 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1783 elementType = tensorType.getElementType(); 1784 else 1785 elementType = SizeType::get(builder.getContext()); 1786 bodyBlock.addArgument(elementType, shape.getLoc()); 1787 1788 for (Value initVal : initVals) { 1789 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1790 result.addTypes(initVal.getType()); 1791 } 1792 } 1793 1794 LogicalResult ReduceOp::verify() { 1795 // Verify block arg types. 1796 Block &block = getRegion().front(); 1797 1798 // The block takes index, extent, and aggregated values as arguments. 1799 auto blockArgsCount = getInitVals().size() + 2; 1800 if (block.getNumArguments() != blockArgsCount) 1801 return emitOpError() << "ReduceOp body is expected to have " 1802 << blockArgsCount << " arguments"; 1803 1804 // The first block argument is the index and must always be of type `index`. 1805 if (!block.getArgument(0).getType().isa<IndexType>()) 1806 return emitOpError( 1807 "argument 0 of ReduceOp body is expected to be of IndexType"); 1808 1809 // The second block argument is the extent and must be of type `size` or 1810 // `index`, depending on whether the reduce operation is applied to a shape or 1811 // to an extent tensor. 1812 Type extentTy = block.getArgument(1).getType(); 1813 if (getShape().getType().isa<ShapeType>()) { 1814 if (!extentTy.isa<SizeType>()) 1815 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1816 "SizeType if the ReduceOp operates on a ShapeType"); 1817 } else { 1818 if (!extentTy.isa<IndexType>()) 1819 return emitOpError( 1820 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1821 "ReduceOp operates on an extent tensor"); 1822 } 1823 1824 for (const auto &type : llvm::enumerate(getInitVals())) 1825 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1826 return emitOpError() << "type mismatch between argument " 1827 << type.index() + 2 1828 << " of ReduceOp body and initial value " 1829 << type.index(); 1830 return success(); 1831 } 1832 1833 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1834 // Parse operands. 1835 SmallVector<OpAsmParser::OperandType, 3> operands; 1836 Type shapeOrExtentTensorType; 1837 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1838 OpAsmParser::Delimiter::Paren) || 1839 parser.parseColonType(shapeOrExtentTensorType) || 1840 parser.parseOptionalArrowTypeList(result.types)) 1841 return failure(); 1842 1843 // Resolve operands. 1844 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1845 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1846 result.operands) || 1847 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1848 result.operands)) 1849 return failure(); 1850 1851 // Parse the body. 1852 Region *body = result.addRegion(); 1853 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1854 return failure(); 1855 1856 // Parse attributes. 1857 if (parser.parseOptionalAttrDict(result.attributes)) 1858 return failure(); 1859 1860 return success(); 1861 } 1862 1863 void ReduceOp::print(OpAsmPrinter &p) { 1864 p << '(' << getShape() << ", " << getInitVals() 1865 << ") : " << getShape().getType(); 1866 p.printOptionalArrowTypeList(getResultTypes()); 1867 p << ' '; 1868 p.printRegion(getRegion()); 1869 p.printOptionalAttrDict((*this)->getAttrs()); 1870 } 1871 1872 #define GET_OP_CLASSES 1873 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1874