1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/SetOperations.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::shape; 31 32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 33 34 namespace { 35 #include "ShapeCanonicalization.inc" 36 } // namespace 37 38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 39 return RankedTensorType::get({rank}, IndexType::get(ctx)); 40 } 41 42 bool shape::isExtentTensorType(Type type) { 43 auto ranked = type.dyn_cast<RankedTensorType>(); 44 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 45 } 46 47 LogicalResult shape::getShapeVec(Value input, 48 SmallVectorImpl<int64_t> &shapeValues) { 49 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 50 auto type = inputOp.getArg().getType().cast<ShapedType>(); 51 if (!type.hasRank()) 52 return failure(); 53 llvm::append_range(shapeValues, type.getShape()); 54 return success(); 55 } 56 DenseIntElementsAttr attr; 57 if (matchPattern(input, 
m_Constant(&attr))) { 58 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 59 return success(); 60 } 61 return failure(); 62 } 63 64 static bool isErrorPropagationPossible(TypeRange operandTypes) { 65 return llvm::any_of(operandTypes, [](Type ty) { 66 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 67 }); 68 } 69 70 static LogicalResult verifySizeOrIndexOp(Operation *op) { 71 assert(op != nullptr && op->getNumResults() == 1); 72 Type resultTy = op->getResultTypes().front(); 73 if (isErrorPropagationPossible(op->getOperandTypes())) { 74 if (!resultTy.isa<SizeType>()) 75 return op->emitOpError() 76 << "if at least one of the operands can hold error values then " 77 "the result must be of type `size` to propagate them"; 78 } 79 return success(); 80 } 81 82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 83 assert(op != nullptr && op->getNumResults() == 1); 84 Type resultTy = op->getResultTypes().front(); 85 if (isErrorPropagationPossible(op->getOperandTypes())) { 86 if (!resultTy.isa<ShapeType>()) 87 return op->emitOpError() 88 << "if at least one of the operands can hold error values then " 89 "the result must be of type `shape` to propagate them"; 90 } 91 return success(); 92 } 93 94 template <typename... Ty> 95 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 96 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 97 } 98 99 template <typename... Ty, typename... ranges> 100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 101 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // InlinerInterface 106 //===----------------------------------------------------------------------===// 107 108 namespace { 109 /// This class defines the interface for inlining shape dialect ops. 
110 struct ShapeInlinerInterface : public DialectInlinerInterface { 111 using DialectInlinerInterface::DialectInlinerInterface; 112 113 // Returns true if the given region 'src' can be inlined into the region 114 // 'dest' that is attached to an operation registered to the current dialect. 115 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 116 BlockAndValueMapping &) const final { 117 return true; 118 } 119 120 // Returns true if the given operation 'op', that is registered to this 121 // dialect, can be inlined into the region 'dest' that is attached to an 122 // operation registered to the current dialect. 123 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 124 BlockAndValueMapping &) const final { 125 return true; 126 } 127 }; 128 } // namespace 129 130 void ShapeDialect::initialize() { 131 addOperations< 132 #define GET_OP_LIST 133 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 134 >(); 135 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 136 addInterfaces<ShapeInlinerInterface>(); 137 // Allow unknown operations during prototyping and testing. As the dialect is 138 // still evolving it makes it simple to start with an unregistered ops and 139 // try different variants before actually defining the op. 
140 allowUnknownOperations(); 141 } 142 143 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 144 Attribute value, Type type, 145 Location loc) { 146 if (type.isa<ShapeType>() || isExtentTensorType(type)) 147 return builder.create<ConstShapeOp>(loc, type, 148 value.cast<DenseIntElementsAttr>()); 149 if (type.isa<SizeType>()) 150 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 151 if (type.isa<WitnessType>()) 152 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 153 if (arith::ConstantOp::isBuildableWith(value, type)) 154 return builder.create<arith::ConstantOp>(loc, type, value); 155 return nullptr; 156 } 157 158 /// Parse a type registered to this dialect. 159 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 160 StringRef keyword; 161 if (parser.parseKeyword(&keyword)) 162 return Type(); 163 164 if (keyword == "shape") 165 return ShapeType::get(getContext()); 166 if (keyword == "size") 167 return SizeType::get(getContext()); 168 if (keyword == "value_shape") 169 return ValueShapeType::get(getContext()); 170 if (keyword == "witness") 171 return WitnessType::get(getContext()); 172 173 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 174 return Type(); 175 } 176 177 /// Print a type registered to this dialect. 178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 179 TypeSwitch<Type>(type) 180 .Case<ShapeType>([&](Type) { os << "shape"; }) 181 .Case<SizeType>([&](Type) { os << "size"; }) 182 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 183 .Case<WitnessType>([&](Type) { os << "witness"; }) 184 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 185 } 186 187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 188 NamedAttribute attribute) { 189 // Verify shape.lib attribute. 
190 if (attribute.getName() == "shape.lib") { 191 if (!op->hasTrait<OpTrait::SymbolTable>()) 192 return op->emitError( 193 "shape.lib attribute may only be on op implementing SymbolTable"); 194 195 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 196 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 197 if (!symbol) 198 return op->emitError("shape function library ") 199 << symbolRef << " not found"; 200 return isa<shape::FunctionLibraryOp>(symbol) 201 ? success() 202 : op->emitError() 203 << symbolRef << " required to be shape function library"; 204 } 205 206 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 207 // Verify all entries are function libraries and mappings in libraries 208 // refer to unique ops. 209 DenseSet<StringAttr> key; 210 for (auto it : arr) { 211 if (!it.isa<SymbolRefAttr>()) 212 return op->emitError( 213 "only SymbolRefAttr allowed in shape.lib attribute array"); 214 215 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 216 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 217 if (!shapeFnLib) 218 return op->emitError() 219 << it << " does not refer to FunctionLibraryOp"; 220 for (auto mapping : shapeFnLib.getMapping()) { 221 if (!key.insert(mapping.getName()).second) { 222 return op->emitError("only one op to shape mapping allowed, found " 223 "multiple for `") 224 << mapping.getName() << "`"; 225 } 226 } 227 } 228 return success(); 229 } 230 231 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 232 "allowed as shape.lib attribute"); 233 } 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // AnyOp 239 //===----------------------------------------------------------------------===// 240 241 // TODO: Canonicalization should be implemented for shapes that can be 242 // determined through mixtures of the known dimensions of the inputs. 
243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 244 // Only the last operand is checked because AnyOp is commutative. 245 if (operands.back()) 246 return operands.back(); 247 248 return nullptr; 249 } 250 251 //===----------------------------------------------------------------------===// 252 // AssumingOp 253 //===----------------------------------------------------------------------===// 254 255 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 256 result.regions.reserve(1); 257 Region *doRegion = result.addRegion(); 258 259 auto &builder = parser.getBuilder(); 260 OpAsmParser::OperandType cond; 261 if (parser.parseOperand(cond) || 262 parser.resolveOperand(cond, builder.getType<WitnessType>(), 263 result.operands)) 264 return failure(); 265 266 // Parse optional results type list. 267 if (parser.parseOptionalArrowTypeList(result.types)) 268 return failure(); 269 270 // Parse the region and add a terminator if elided. 271 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 272 return failure(); 273 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 274 275 // Parse the optional attribute list. 276 if (parser.parseOptionalAttrDict(result.attributes)) 277 return failure(); 278 return success(); 279 } 280 281 void AssumingOp::print(OpAsmPrinter &p) { 282 bool yieldsResults = !getResults().empty(); 283 284 p << " " << getWitness(); 285 if (yieldsResults) 286 p << " -> (" << getResultTypes() << ")"; 287 p << ' '; 288 p.printRegion(getDoRegion(), 289 /*printEntryBlockArgs=*/false, 290 /*printBlockTerminators=*/yieldsResults); 291 p.printOptionalAttrDict((*this)->getAttrs()); 292 } 293 294 namespace { 295 // Removes AssumingOp with a passing witness and inlines the region. 
296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 297 using OpRewritePattern<AssumingOp>::OpRewritePattern; 298 299 LogicalResult matchAndRewrite(AssumingOp op, 300 PatternRewriter &rewriter) const override { 301 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 302 if (!witness || !witness.getPassingAttr()) 303 return failure(); 304 305 AssumingOp::inlineRegionIntoParent(op, rewriter); 306 return success(); 307 } 308 }; 309 310 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 311 using OpRewritePattern<AssumingOp>::OpRewritePattern; 312 313 LogicalResult matchAndRewrite(AssumingOp op, 314 PatternRewriter &rewriter) const override { 315 Block *body = op.getBody(); 316 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 317 318 // Find used values. 319 SmallVector<Value, 4> newYieldOperands; 320 Value opResult, yieldOperand; 321 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 322 std::tie(opResult, yieldOperand) = it; 323 if (!opResult.getUses().empty()) { 324 newYieldOperands.push_back(yieldOperand); 325 } 326 } 327 328 // Rewrite only if redundant results exist. 329 if (newYieldOperands.size() == yieldOp->getNumOperands()) 330 return failure(); 331 332 // Replace yield op in the old assuming op's body and move the entire region 333 // to the new assuming op. 334 rewriter.setInsertionPointToEnd(body); 335 auto newYieldOp = 336 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 337 rewriter.setInsertionPoint(op); 338 auto newOp = rewriter.create<AssumingOp>( 339 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 340 newOp.getDoRegion().takeBody(op.getDoRegion()); 341 342 // Use the new results to replace the previously used ones. 
343 SmallVector<Value, 4> replacementValues; 344 auto src = newOp.getResults().begin(); 345 for (auto it : op.getResults()) { 346 if (it.getUses().empty()) 347 replacementValues.push_back(nullptr); 348 else 349 replacementValues.push_back(*src++); 350 } 351 rewriter.replaceOp(op, replacementValues); 352 return success(); 353 } 354 }; 355 } // namespace 356 357 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 358 MLIRContext *context) { 359 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 360 } 361 362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 363 void AssumingOp::getSuccessorRegions( 364 Optional<unsigned> index, ArrayRef<Attribute> operands, 365 SmallVectorImpl<RegionSuccessor> ®ions) { 366 // AssumingOp has unconditional control flow into the region and back to the 367 // parent, so return the correct RegionSuccessor purely based on the index 368 // being None or 0. 369 if (index.hasValue()) { 370 regions.push_back(RegionSuccessor(getResults())); 371 return; 372 } 373 374 regions.push_back(RegionSuccessor(&getDoRegion())); 375 } 376 377 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 378 PatternRewriter &rewriter) { 379 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 380 auto *assumingBlock = op.getBody(); 381 auto initPosition = rewriter.getInsertionPoint(); 382 auto *blockAfterAssuming = 383 rewriter.splitBlock(blockBeforeAssuming, initPosition); 384 385 // Remove the AssumingOp and AssumingYieldOp. 386 auto &yieldOp = assumingBlock->back(); 387 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 388 rewriter.replaceOp(op, yieldOp.getOperands()); 389 rewriter.eraseOp(&yieldOp); 390 391 // Merge blocks together as there was no branching behavior from the 392 // AssumingOp. 
393 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 394 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 395 } 396 397 void AssumingOp::build( 398 OpBuilder &builder, OperationState &result, Value witness, 399 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 400 401 result.addOperands(witness); 402 Region *bodyRegion = result.addRegion(); 403 bodyRegion->push_back(new Block); 404 Block &bodyBlock = bodyRegion->front(); 405 406 // Build body. 407 OpBuilder::InsertionGuard guard(builder); 408 builder.setInsertionPointToStart(&bodyBlock); 409 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 410 builder.create<AssumingYieldOp>(result.location, yieldValues); 411 412 SmallVector<Type, 2> assumingTypes; 413 for (Value v : yieldValues) 414 assumingTypes.push_back(v.getType()); 415 result.addTypes(assumingTypes); 416 } 417 418 LogicalResult AssumingOp::verify() { 419 return RegionBranchOpInterface::verifyTypes(*this); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // AddOp 424 //===----------------------------------------------------------------------===// 425 426 LogicalResult mlir::shape::AddOp::inferReturnTypes( 427 MLIRContext *context, Optional<Location> location, ValueRange operands, 428 DictionaryAttr attributes, RegionRange regions, 429 SmallVectorImpl<Type> &inferredReturnTypes) { 430 if (operands[0].getType().isa<SizeType>() || 431 operands[1].getType().isa<SizeType>()) 432 inferredReturnTypes.assign({SizeType::get(context)}); 433 else 434 inferredReturnTypes.assign({IndexType::get(context)}); 435 return success(); 436 } 437 438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 439 // SizeType is compatible with IndexType. 
440 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 441 } 442 443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 444 // add(x, 0) -> x 445 if (matchPattern(getRhs(), m_Zero())) 446 return getLhs(); 447 448 return constFoldBinaryOp<IntegerAttr>( 449 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 450 } 451 452 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 453 454 //===----------------------------------------------------------------------===// 455 // AssumingAllOp 456 //===----------------------------------------------------------------------===// 457 458 namespace { 459 460 // Merge multiple `shape.assuming_all` operations together. 461 // 462 // %0 = shape.assuming_all %w0, %w1 463 // %1 = shape.assuming_all %w2, %0 464 // 465 // to: 466 // 467 // %0 = shape.assuming_all %w0, %w2, %w2 468 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 469 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 470 471 LogicalResult matchAndRewrite(AssumingAllOp op, 472 PatternRewriter &rewriter) const override { 473 SmallVector<Value> operands; 474 475 for (Value operand : op.getInputs()) { 476 if (auto assume_all = operand.getDefiningOp<AssumingAllOp>()) 477 operands.append(assume_all.operand_begin(), assume_all->operand_end()); 478 else 479 operands.push_back(operand); 480 } 481 482 // We didn't find any other `assuming_all` ops to merge with. 483 if (operands.size() == op.getNumOperands()) 484 return failure(); 485 486 // Replace with a new `assuming_all` operation with merged constraints. 487 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 488 return success(); 489 } 490 }; 491 492 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 493 // are subsumed by others. 
494 // 495 // %0 = shape.cstr_broadcastable %shape0, %shape1 496 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 497 // 498 // %2 = shape.cstr_broadcastable %shape3, %shape4 499 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 500 // 501 // %4 = shape.assuming_all %0, %1, %2, %3 502 // 503 // to: 504 // 505 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 506 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 507 // %2 = shape.assuming_all %0, %1 508 // 509 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 510 // shapes [0, 1] are broadcastable too, and can be removed from the list of 511 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 512 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 513 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 514 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 515 516 LogicalResult matchAndRewrite(AssumingAllOp op, 517 PatternRewriter &rewriter) const override { 518 // Collect all `CstrBroadcastableOp` operands first. 519 SetVector<CstrBroadcastableOp> operands; 520 for (Value operand : op.getInputs()) { 521 // TODO: Apply this optimization if some of the witnesses are not 522 // produced by the `cstr_broadcastable`. 523 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 524 if (!broadcastable) 525 return failure(); 526 527 operands.insert(broadcastable); 528 } 529 530 // Skip trivial `assuming_all` operations. 531 if (operands.size() <= 1) 532 return failure(); 533 534 // Collect shapes checked by `cstr_broadcastable` operands. 535 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 536 for (auto cstr : operands) { 537 DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end()); 538 shapes.emplace_back(cstr, std::move(shapes_set)); 539 } 540 541 // Sort by the number of shape operands (larger to smaller). 
542 llvm::sort(shapes, [](auto a, auto b) { 543 return a.first.getNumOperands() > b.first.getNumOperands(); 544 }); 545 546 // We start from the `cst_broadcastable` operations with largest number of 547 // shape operands, and remove redundant `cst_broadcastable` operations. We 548 // do this until we find a set of `cst_broadcastable` operations with 549 // non-overlapping constraints. 550 SmallVector<CstrBroadcastableOp> marked_for_erase; 551 552 for (unsigned i = 0; i < shapes.size(); ++i) { 553 auto isSubset = [&](auto pair) { 554 return llvm::set_is_subset(pair.second, shapes[i].second); 555 }; 556 557 // Keep redundant `cstr_broadcastable` operations to be erased. 558 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 559 for (auto *it0 = it; it0 < shapes.end(); ++it0) 560 marked_for_erase.push_back(it0->first); 561 shapes.erase(it, shapes.end()); 562 } 563 564 // We didn't find any operands that could be removed. 565 if (marked_for_erase.empty()) 566 return failure(); 567 568 // Collect non-overlapping `cst_broadcastable` constraints. 569 SmallVector<Value> unique_constraints; 570 for (auto &shape : shapes) 571 unique_constraints.push_back(shape.first.getResult()); 572 573 // Replace with a new `assuming_all` operation ... 574 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints); 575 576 // ... and maybe erase `cstr_broadcastable` ops without uses. 
577 for (auto &op : marked_for_erase) 578 if (op->use_empty()) 579 rewriter.eraseOp(op); 580 581 return success(); 582 } 583 }; 584 585 struct AssumingAllToCstrEqCanonicalization 586 : public OpRewritePattern<AssumingAllOp> { 587 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 588 589 LogicalResult matchAndRewrite(AssumingAllOp op, 590 PatternRewriter &rewriter) const override { 591 SmallVector<Value, 8> shapes; 592 for (Value w : op.getInputs()) { 593 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 594 if (!cstrEqOp) 595 return failure(); 596 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 597 return llvm::is_contained(shapes, s); 598 }); 599 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 600 return failure(); 601 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 602 } 603 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 604 return success(); 605 } 606 }; 607 608 template <typename OpTy> 609 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 610 using OpRewritePattern<OpTy>::OpRewritePattern; 611 612 LogicalResult matchAndRewrite(OpTy op, 613 PatternRewriter &rewriter) const override { 614 // Find unique operands. 615 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 616 617 // Reduce op to equivalent with unique operands. 
618 if (unique.size() < op.getNumOperands()) { 619 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 620 unique.takeVector(), op->getAttrs()); 621 return success(); 622 } 623 624 return failure(); 625 } 626 }; 627 } // namespace 628 629 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 630 MLIRContext *context) { 631 patterns 632 .add<MergeAssumingAllOps, AssumingAllOneOp, 633 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 634 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 635 } 636 637 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 638 // Iterate in reverse to first handle all constant operands. They are 639 // guaranteed to be the tail of the inputs because this is commutative. 640 for (int idx = operands.size() - 1; idx >= 0; idx--) { 641 Attribute a = operands[idx]; 642 // Cannot fold if any inputs are not constant; 643 if (!a) 644 return nullptr; 645 646 // We do not need to keep statically known values after handling them in 647 // this method. 648 getOperation()->eraseOperand(idx); 649 650 // Always false if any input is statically known false 651 if (!a.cast<BoolAttr>().getValue()) 652 return a; 653 } 654 // If this is reached, all inputs were statically known passing. 
655 return BoolAttr::get(getContext(), true); 656 } 657 658 LogicalResult AssumingAllOp::verify() { 659 // Ensure that AssumingAllOp contains at least one operand 660 if (getNumOperands() == 0) 661 return emitOpError("no operands specified"); 662 663 return success(); 664 } 665 666 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 667 ValueRange inputs) { 668 build(b, state, b.getType<WitnessType>(), inputs); 669 } 670 671 //===----------------------------------------------------------------------===// 672 // BroadcastOp 673 //===----------------------------------------------------------------------===// 674 675 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 676 if (getShapes().size() == 1) { 677 // Otherwise, we need a cast which would be a canonicalization, not folding. 678 if (getShapes().front().getType() != getType()) 679 return nullptr; 680 return getShapes().front(); 681 } 682 683 // TODO: Support folding with more than 2 input shapes 684 if (getShapes().size() > 2) 685 return nullptr; 686 687 if (!operands[0] || !operands[1]) 688 return nullptr; 689 auto lhsShape = llvm::to_vector<6>( 690 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 691 auto rhsShape = llvm::to_vector<6>( 692 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 693 SmallVector<int64_t, 6> resultShape; 694 695 // If the shapes are not compatible, we can't fold it. 696 // TODO: Fold to an "error". 
697 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 698 return nullptr; 699 700 Builder builder(getContext()); 701 return builder.getIndexTensorAttr(resultShape); 702 } 703 704 LogicalResult BroadcastOp::verify() { 705 return verifyShapeOrExtentTensorOp(*this); 706 } 707 708 namespace { 709 template <typename OpTy> 710 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 711 using OpRewritePattern<OpTy>::OpRewritePattern; 712 713 LogicalResult matchAndRewrite(OpTy op, 714 PatternRewriter &rewriter) const override { 715 auto isPotentiallyNonEmptyShape = [](Value shape) { 716 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 717 if (extentTensorTy.getDimSize(0) == 0) 718 return false; 719 } 720 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 721 if (constShape.getShape().empty()) 722 return false; 723 } 724 return true; 725 }; 726 auto newOperands = llvm::to_vector<8>( 727 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 728 729 // Reduce op to equivalent without empty shape operands. 730 if (newOperands.size() < op.getNumOperands()) { 731 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 732 op->getAttrs()); 733 return success(); 734 } 735 736 return failure(); 737 } 738 }; 739 740 struct BroadcastForwardSingleOperandPattern 741 : public OpRewritePattern<BroadcastOp> { 742 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 743 744 LogicalResult matchAndRewrite(BroadcastOp op, 745 PatternRewriter &rewriter) const override { 746 if (op.getNumOperands() != 1) 747 return failure(); 748 Value replacement = op.getShapes().front(); 749 750 // Insert cast if needed. 
751 if (replacement.getType() != op.getType()) { 752 auto loc = op.getLoc(); 753 if (op.getType().isa<ShapeType>()) { 754 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 755 } else { 756 assert(!op.getType().isa<ShapeType>() && 757 !replacement.getType().isa<ShapeType>() && 758 "expect extent tensor cast"); 759 replacement = 760 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 761 } 762 } 763 764 rewriter.replaceOp(op, replacement); 765 return success(); 766 } 767 }; 768 769 struct BroadcastFoldConstantOperandsPattern 770 : public OpRewritePattern<BroadcastOp> { 771 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 772 773 LogicalResult matchAndRewrite(BroadcastOp op, 774 PatternRewriter &rewriter) const override { 775 SmallVector<int64_t, 8> foldedConstantShape; 776 SmallVector<Value, 8> newShapeOperands; 777 for (Value shape : op.getShapes()) { 778 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 779 SmallVector<int64_t, 8> newFoldedConstantShape; 780 if (OpTrait::util::getBroadcastedShape( 781 foldedConstantShape, 782 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 783 newFoldedConstantShape)) { 784 foldedConstantShape = newFoldedConstantShape; 785 continue; 786 } 787 } 788 newShapeOperands.push_back(shape); 789 } 790 791 // Need at least two constant operands to fold anything. 
792 if (op.getNumOperands() - newShapeOperands.size() < 2) 793 return failure(); 794 795 auto foldedConstantOperandsTy = RankedTensorType::get( 796 {static_cast<int64_t>(foldedConstantShape.size())}, 797 rewriter.getIndexType()); 798 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 799 op.getLoc(), foldedConstantOperandsTy, 800 rewriter.getIndexTensorAttr(foldedConstantShape))); 801 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 802 newShapeOperands); 803 return success(); 804 } 805 }; 806 807 template <typename OpTy> 808 struct CanonicalizeCastExtentTensorOperandsPattern 809 : public OpRewritePattern<OpTy> { 810 using OpRewritePattern<OpTy>::OpRewritePattern; 811 812 LogicalResult matchAndRewrite(OpTy op, 813 PatternRewriter &rewriter) const override { 814 // Canonicalize operands. 815 bool anyChange = false; 816 auto canonicalizeOperand = [&](Value operand) { 817 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 818 // Only eliminate the cast if it holds no shape information. 819 bool isInformationLoosingCast = 820 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 821 if (isInformationLoosingCast) { 822 anyChange = true; 823 return castOp.source(); 824 } 825 } 826 return operand; 827 }; 828 auto newOperands = llvm::to_vector<8>( 829 llvm::map_range(op.getOperands(), canonicalizeOperand)); 830 831 // Rewrite op if any change required. 832 if (!anyChange) 833 return failure(); 834 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 835 return success(); 836 } 837 }; 838 839 struct BroadcastConcretizeResultTypePattern 840 : public OpRewritePattern<BroadcastOp> { 841 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 842 843 LogicalResult matchAndRewrite(BroadcastOp op, 844 PatternRewriter &rewriter) const override { 845 // Only concretize dynamic extent tensor result types. 
846 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 847 if (!resultTy || !resultTy.isDynamicDim(0)) 848 return failure(); 849 850 // Infer resulting shape rank if possible. 851 int64_t maxRank = 0; 852 for (Value shape : op.getShapes()) { 853 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 854 // Cannot infer resulting shape rank if any operand is dynamically 855 // ranked. 856 if (extentTensorTy.isDynamicDim(0)) 857 return failure(); 858 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 859 } 860 } 861 862 auto newOp = rewriter.create<BroadcastOp>( 863 op.getLoc(), getExtentTensorType(getContext(), maxRank), 864 op.getShapes()); 865 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 866 return success(); 867 } 868 }; 869 } // namespace 870 871 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 872 MLIRContext *context) { 873 patterns.add<BroadcastConcretizeResultTypePattern, 874 BroadcastFoldConstantOperandsPattern, 875 BroadcastForwardSingleOperandPattern, 876 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 877 RemoveDuplicateOperandsPattern<BroadcastOp>, 878 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 879 } 880 881 //===----------------------------------------------------------------------===// 882 // ConcatOp 883 //===----------------------------------------------------------------------===// 884 885 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 886 if (!operands[0] || !operands[1]) 887 return nullptr; 888 auto lhsShape = llvm::to_vector<6>( 889 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 890 auto rhsShape = llvm::to_vector<6>( 891 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 892 SmallVector<int64_t, 6> resultShape; 893 resultShape.append(lhsShape.begin(), lhsShape.end()); 894 resultShape.append(rhsShape.begin(), rhsShape.end()); 895 Builder builder(getContext()); 896 return 
builder.getIndexTensorAttr(resultShape); 897 } 898 899 //===----------------------------------------------------------------------===// 900 // ConstShapeOp 901 //===----------------------------------------------------------------------===// 902 903 void ConstShapeOp::print(OpAsmPrinter &p) { 904 p << " "; 905 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 906 p << "["; 907 interleaveComma(getShape().getValues<int64_t>(), p); 908 p << "] : "; 909 p.printType(getType()); 910 } 911 912 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 913 if (parser.parseOptionalAttrDict(result.attributes)) 914 return failure(); 915 // We piggy-back on ArrayAttr parsing, though we don't internally store the 916 // shape as an ArrayAttr. 917 // TODO: Implement custom parser and maybe make syntax a bit more concise. 918 Attribute extentsRaw; 919 NamedAttrList dummy; 920 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 921 return failure(); 922 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 923 if (!extentsArray) 924 return failure(); 925 SmallVector<int64_t, 6> ints; 926 for (Attribute extent : extentsArray) { 927 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 928 if (!attr) 929 return failure(); 930 ints.push_back(attr.getInt()); 931 } 932 Builder &builder = parser.getBuilder(); 933 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 934 Type resultTy; 935 if (parser.parseColonType(resultTy)) 936 return failure(); 937 result.types.push_back(resultTy); 938 return success(); 939 } 940 941 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 942 943 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 944 MLIRContext *context) { 945 patterns.add<TensorCastConstShape>(context); 946 } 947 948 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 949 MLIRContext *context, Optional<Location> location, ValueRange operands, 950 DictionaryAttr attributes, 
RegionRange regions, 951 SmallVectorImpl<Type> &inferredReturnTypes) { 952 Builder b(context); 953 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 954 if (!shape) 955 return emitOptionalError(location, "missing shape attribute"); 956 inferredReturnTypes.assign({RankedTensorType::get( 957 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 958 return success(); 959 } 960 961 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 962 TypeRange r) { 963 if (l.size() != 1 || r.size() != 1) 964 return false; 965 966 Type lhs = l.front(); 967 Type rhs = r.front(); 968 969 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 970 // Shape type is compatible with all other valid return types. 971 return true; 972 return lhs == rhs; 973 } 974 975 //===----------------------------------------------------------------------===// 976 // CstrBroadcastableOp 977 //===----------------------------------------------------------------------===// 978 979 void CstrBroadcastableOp::getCanonicalizationPatterns( 980 RewritePatternSet &patterns, MLIRContext *context) { 981 // Canonicalization patterns have overlap with the considerations during 982 // folding in case additional shape information is inferred at some point that 983 // does not result in folding. 984 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 985 CstrBroadcastableEqOps, 986 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 987 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 988 } 989 990 // Return true if there is exactly one attribute not representing a scalar 991 // broadcast. 
992 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 993 bool nonScalarSeen = false; 994 for (Attribute a : attributes) { 995 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 996 if (nonScalarSeen) 997 return false; 998 nonScalarSeen = true; 999 } 1000 } 1001 return true; 1002 } 1003 1004 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1005 // No broadcasting is needed if all operands but one are scalar. 1006 if (hasAtMostSingleNonScalar(operands)) 1007 return BoolAttr::get(getContext(), true); 1008 1009 if ([&] { 1010 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1011 for (const auto &operand : operands) { 1012 if (!operand) 1013 return false; 1014 extents.push_back(llvm::to_vector<6>( 1015 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 1016 } 1017 return OpTrait::util::staticallyKnownBroadcastable(extents); 1018 }()) 1019 return BoolAttr::get(getContext(), true); 1020 1021 // Lastly, see if folding can be completed based on what constraints are known 1022 // on the input shapes. 1023 if ([&] { 1024 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1025 for (auto shapeValue : getShapes()) { 1026 extents.emplace_back(); 1027 if (failed(getShapeVec(shapeValue, extents.back()))) 1028 return false; 1029 } 1030 return OpTrait::util::staticallyKnownBroadcastable(extents); 1031 }()) 1032 return BoolAttr::get(getContext(), true); 1033 1034 // Because a failing witness result here represents an eventual assertion 1035 // failure, we do not replace it with a constant witness. 
1036 return nullptr; 1037 } 1038 1039 LogicalResult CstrBroadcastableOp::verify() { 1040 // Ensure that CstrBroadcastableOp contains at least two operands 1041 if (getNumOperands() < 2) 1042 return emitOpError("required at least 2 input shapes"); 1043 return success(); 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // CstrEqOp 1048 //===----------------------------------------------------------------------===// 1049 1050 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1051 MLIRContext *context) { 1052 // If inputs are equal, return passing witness 1053 patterns.add<CstrEqEqOps>(context); 1054 } 1055 1056 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1057 if (llvm::all_of(operands, 1058 [&](Attribute a) { return a && a == operands[0]; })) 1059 return BoolAttr::get(getContext(), true); 1060 1061 // Because a failing witness result here represents an eventual assertion 1062 // failure, we do not try to replace it with a constant witness. Similarly, we 1063 // cannot if there are any non-const inputs. 
1064 return nullptr; 1065 } 1066 1067 //===----------------------------------------------------------------------===// 1068 // ConstSizeOp 1069 //===----------------------------------------------------------------------===// 1070 1071 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1072 int64_t value) { 1073 build(builder, result, builder.getIndexAttr(value)); 1074 } 1075 1076 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1077 1078 void ConstSizeOp::getAsmResultNames( 1079 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1080 SmallString<4> buffer; 1081 llvm::raw_svector_ostream os(buffer); 1082 os << "c" << getValue(); 1083 setNameFn(getResult(), os.str()); 1084 } 1085 1086 //===----------------------------------------------------------------------===// 1087 // ConstWitnessOp 1088 //===----------------------------------------------------------------------===// 1089 1090 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1091 return getPassingAttr(); 1092 } 1093 1094 //===----------------------------------------------------------------------===// 1095 // CstrRequireOp 1096 //===----------------------------------------------------------------------===// 1097 1098 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1099 return operands[0]; 1100 } 1101 1102 //===----------------------------------------------------------------------===// 1103 // DivOp 1104 //===----------------------------------------------------------------------===// 1105 1106 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1107 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1108 if (!lhs) 1109 return nullptr; 1110 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1111 if (!rhs) 1112 return nullptr; 1113 1114 // Division in APInt does not follow floor(lhs, rhs) when the result is 1115 // negative. Rather, APInt rounds toward zero. 
1116 APInt quotient, remainder; 1117 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1118 if (quotient.isNegative() && !remainder.isNullValue()) { 1119 quotient -= 1; 1120 } 1121 1122 Type indexTy = IndexType::get(getContext()); 1123 return IntegerAttr::get(indexTy, quotient); 1124 } 1125 1126 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1127 MLIRContext *context, Optional<Location> location, ValueRange operands, 1128 DictionaryAttr attributes, RegionRange regions, 1129 SmallVectorImpl<Type> &inferredReturnTypes) { 1130 if (operands[0].getType().isa<SizeType>() || 1131 operands[1].getType().isa<SizeType>()) 1132 inferredReturnTypes.assign({SizeType::get(context)}); 1133 else 1134 inferredReturnTypes.assign({IndexType::get(context)}); 1135 return success(); 1136 } 1137 1138 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1139 // SizeType is compatible with IndexType. 1140 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1141 } 1142 1143 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1144 1145 //===----------------------------------------------------------------------===// 1146 // ShapeEqOp 1147 //===----------------------------------------------------------------------===// 1148 1149 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1150 bool allSame = true; 1151 if (!operands.empty() && !operands[0]) 1152 return {}; 1153 for (Attribute operand : operands.drop_front(1)) { 1154 if (!operand) 1155 return {}; 1156 allSame = allSame && operand == operands[0]; 1157 } 1158 return BoolAttr::get(getContext(), allSame); 1159 } 1160 1161 //===----------------------------------------------------------------------===// 1162 // IndexToSizeOp 1163 //===----------------------------------------------------------------------===// 1164 1165 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1166 // Constant values of both types, `shape.size` and `index`, are represented as 
1167 // `IntegerAttr`s which makes constant folding simple. 1168 if (Attribute arg = operands[0]) 1169 return arg; 1170 return {}; 1171 } 1172 1173 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1174 MLIRContext *context) { 1175 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1176 } 1177 1178 //===----------------------------------------------------------------------===// 1179 // FromExtentsOp 1180 //===----------------------------------------------------------------------===// 1181 1182 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1183 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1184 return nullptr; 1185 SmallVector<int64_t, 6> extents; 1186 for (auto attr : operands) 1187 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1188 Builder builder(getContext()); 1189 return builder.getIndexTensorAttr(extents); 1190 } 1191 1192 //===----------------------------------------------------------------------===// 1193 // FunctionLibraryOp 1194 //===----------------------------------------------------------------------===// 1195 1196 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1197 StringRef name) { 1198 result.attributes.push_back(builder.getNamedAttr( 1199 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1200 } 1201 1202 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1203 auto attr = getMapping() 1204 .get(op->getName().getIdentifier()) 1205 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1206 if (!attr) 1207 return nullptr; 1208 return lookupSymbol<FuncOp>(attr); 1209 } 1210 1211 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1212 OperationState &result) { 1213 // Parse the op name. 
1214 StringAttr nameAttr; 1215 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1216 result.attributes)) 1217 return failure(); 1218 1219 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1220 return failure(); 1221 1222 auto *bodyRegion = result.addRegion(); 1223 if (parser.parseRegion(*bodyRegion)) 1224 return failure(); 1225 1226 if (parser.parseKeyword("mapping")) 1227 return failure(); 1228 1229 DictionaryAttr mappingAttr; 1230 if (parser.parseAttribute(mappingAttr, 1231 parser.getBuilder().getType<NoneType>(), "mapping", 1232 result.attributes)) 1233 return failure(); 1234 return success(); 1235 } 1236 1237 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1238 p << ' '; 1239 p.printSymbolName(getName()); 1240 p.printOptionalAttrDictWithKeyword( 1241 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1242 p << ' '; 1243 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1244 /*printBlockTerminators=*/false); 1245 p << " mapping "; 1246 p.printAttributeWithoutType(getMappingAttr()); 1247 } 1248 1249 //===----------------------------------------------------------------------===// 1250 // GetExtentOp 1251 //===----------------------------------------------------------------------===// 1252 1253 Optional<int64_t> GetExtentOp::getConstantDim() { 1254 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1255 return constSizeOp.getValue().getLimitedValue(); 1256 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1257 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1258 return llvm::None; 1259 } 1260 1261 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1262 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1263 if (!elements) 1264 return nullptr; 1265 Optional<int64_t> dim = getConstantDim(); 1266 if (!dim.hasValue()) 1267 return nullptr; 1268 if (dim.getValue() >= elements.getNumElements()) 1269 return nullptr; 1270 return 
elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1271 } 1272 1273 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1274 int64_t dim) { 1275 auto loc = result.location; 1276 auto dimAttr = builder.getIndexAttr(dim); 1277 if (shape.getType().isa<ShapeType>()) { 1278 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1279 build(builder, result, builder.getType<SizeType>(), shape, dim); 1280 } else { 1281 Value dim = 1282 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1283 build(builder, result, builder.getIndexType(), shape, dim); 1284 } 1285 } 1286 1287 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1288 MLIRContext *context, Optional<Location> location, ValueRange operands, 1289 DictionaryAttr attributes, RegionRange regions, 1290 SmallVectorImpl<Type> &inferredReturnTypes) { 1291 inferredReturnTypes.assign({IndexType::get(context)}); 1292 return success(); 1293 } 1294 1295 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1296 TypeRange r) { 1297 // SizeType is compatible with IndexType. 1298 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1299 } 1300 1301 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1302 1303 //===----------------------------------------------------------------------===// 1304 // IsBroadcastableOp 1305 //===----------------------------------------------------------------------===// 1306 1307 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1308 MLIRContext *context) { 1309 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1310 } 1311 1312 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1313 // Can always broadcast fewer than two shapes. 
1314 if (operands.size() < 2) { 1315 return BoolAttr::get(getContext(), true); 1316 } 1317 1318 return nullptr; 1319 } 1320 1321 //===----------------------------------------------------------------------===// 1322 // MeetOp 1323 //===----------------------------------------------------------------------===// 1324 1325 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1326 MLIRContext *context, Optional<Location> location, ValueRange operands, 1327 DictionaryAttr attributes, RegionRange regions, 1328 SmallVectorImpl<Type> &inferredReturnTypes) { 1329 inferredReturnTypes.assign({operands[0].getType()}); 1330 return success(); 1331 } 1332 1333 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1334 if (l.size() != 1 || r.size() != 1) 1335 return false; 1336 if (l == r) 1337 return true; 1338 1339 Type lhs = l.front(); 1340 Type rhs = r.front(); 1341 1342 if (lhs != rhs) 1343 return false; 1344 1345 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1346 return true; 1347 1348 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1349 return true; 1350 return false; 1351 } 1352 1353 //===----------------------------------------------------------------------===// 1354 // RankOp 1355 //===----------------------------------------------------------------------===// 1356 1357 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1358 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1359 if (!shape) 1360 return {}; 1361 int64_t rank = shape.getNumElements(); 1362 Builder builder(getContext()); 1363 return builder.getIndexAttr(rank); 1364 } 1365 1366 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1367 /// Constant folding fails in cases where only the rank is constant, not the 1368 /// shape itself. 1369 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank with a constant matching the result type:
    // `arith.constant` for `index`, `shape.const_size` for `!shape.size`.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

/// Registers the canonicalization patterns of `shape.rank`.
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

/// Infers `!shape.size` for a `!shape.shape` operand and `index` otherwise.
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

/// Folds to the product of all extents when the shape operand is constant.
/// The product is accumulated in a 64-bit APInt, so it wraps on overflow.
/// NOTE(review): APInt multiplication requires equal bit widths, so this
/// assumes the extents are 64-bit (index) integers.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

/// Infers `!shape.size` for a `!shape.shape` operand and `index` otherwise.
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

/// Folds `max(x, x)` to `x`; no constant folding is attempted.
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

/// Infers the operand type when both operands share it, `!shape.size`
/// otherwise.
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

/// Only shape/shape and size/size pairs of single types are compatible.
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

/// Folds `min(x, x)` to `x`; no constant folding is attempted.
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

/// Infers the operand type when both operands share it, `!shape.size`
/// otherwise.
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

/// Only shape/shape and size/size pairs of single types are compatible.
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

/// Folds the product of two constant operands into an index attribute
/// (APInt multiplication, which wraps on overflow).
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

/// Infers `!shape.size` if either operand is a size, `index` otherwise.
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

/// Folds to a constant index tensor when the operand has a static shape.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

namespace {
/// Rebuilds a `shape_of` of a shaped value whose result is not a shaped type
/// (e.g. `!shape.shape`) from its argument alone, presumably so the builder
/// uses the inferred extent-tensor result type (see `inferReturnTypes`).
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

/// Registers the canonicalization patterns of `shape.shape_of`.
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}

/// Infers `!shape.shape` for a `!shape.value_shape` operand. Otherwise the
/// operand must be shaped and the result is `tensor<R x index>`, where R is
/// the operand's rank, or dynamic if the operand is unranked.
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

/// Single result types are compatible if identical, if either is
/// `!shape.shape`, or if both are shaped types with compatible shapes.
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return OpFoldResult();
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

/// Allows casting a single `index` or `!shape.size` value to `index`.
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

/// Checks that the yielded values match the parent op's results in number and
/// type.
LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===// 1722 // SplitAtOp 1723 //===----------------------------------------------------------------------===// 1724 1725 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1726 SmallVectorImpl<OpFoldResult> &results) { 1727 if (!operands[0] || !operands[1]) 1728 return failure(); 1729 auto shapeVec = llvm::to_vector<6>( 1730 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1731 auto shape = llvm::makeArrayRef(shapeVec); 1732 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1733 // Verify that the split point is in the correct range. 1734 // TODO: Constant fold to an "error". 1735 int64_t rank = shape.size(); 1736 if (!(-rank <= splitPoint && splitPoint <= rank)) 1737 return failure(); 1738 if (splitPoint < 0) 1739 splitPoint += shape.size(); 1740 Builder builder(operands[0].getContext()); 1741 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1742 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1743 return success(); 1744 } 1745 1746 //===----------------------------------------------------------------------===// 1747 // ToExtentTensorOp 1748 //===----------------------------------------------------------------------===// 1749 1750 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1751 if (!operands[0]) 1752 return OpFoldResult(); 1753 Builder builder(getContext()); 1754 auto shape = llvm::to_vector<6>( 1755 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1756 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1757 builder.getIndexType()); 1758 return DenseIntElementsAttr::get(type, shape); 1759 } 1760 1761 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1762 if (inputs.size() != 1 || outputs.size() != 1) 1763 return false; 1764 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1765 if 
(!inputTensor.getElementType().isa<IndexType>() || 1766 inputTensor.getRank() != 1) 1767 return false; 1768 } else if (!inputs[0].isa<ShapeType>()) { 1769 return false; 1770 } 1771 1772 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1773 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1774 } 1775 1776 //===----------------------------------------------------------------------===// 1777 // ReduceOp 1778 //===----------------------------------------------------------------------===// 1779 1780 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1781 ValueRange initVals) { 1782 result.addOperands(shape); 1783 result.addOperands(initVals); 1784 1785 Region *bodyRegion = result.addRegion(); 1786 bodyRegion->push_back(new Block); 1787 Block &bodyBlock = bodyRegion->front(); 1788 bodyBlock.addArgument(builder.getIndexType(), result.location); 1789 1790 Type elementType; 1791 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1792 elementType = tensorType.getElementType(); 1793 else 1794 elementType = SizeType::get(builder.getContext()); 1795 bodyBlock.addArgument(elementType, shape.getLoc()); 1796 1797 for (Value initVal : initVals) { 1798 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1799 result.addTypes(initVal.getType()); 1800 } 1801 } 1802 1803 LogicalResult ReduceOp::verify() { 1804 // Verify block arg types. 1805 Block &block = getRegion().front(); 1806 1807 // The block takes index, extent, and aggregated values as arguments. 1808 auto blockArgsCount = getInitVals().size() + 2; 1809 if (block.getNumArguments() != blockArgsCount) 1810 return emitOpError() << "ReduceOp body is expected to have " 1811 << blockArgsCount << " arguments"; 1812 1813 // The first block argument is the index and must always be of type `index`. 
1814 if (!block.getArgument(0).getType().isa<IndexType>()) 1815 return emitOpError( 1816 "argument 0 of ReduceOp body is expected to be of IndexType"); 1817 1818 // The second block argument is the extent and must be of type `size` or 1819 // `index`, depending on whether the reduce operation is applied to a shape or 1820 // to an extent tensor. 1821 Type extentTy = block.getArgument(1).getType(); 1822 if (getShape().getType().isa<ShapeType>()) { 1823 if (!extentTy.isa<SizeType>()) 1824 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1825 "SizeType if the ReduceOp operates on a ShapeType"); 1826 } else { 1827 if (!extentTy.isa<IndexType>()) 1828 return emitOpError( 1829 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1830 "ReduceOp operates on an extent tensor"); 1831 } 1832 1833 for (const auto &type : llvm::enumerate(getInitVals())) 1834 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1835 return emitOpError() << "type mismatch between argument " 1836 << type.index() + 2 1837 << " of ReduceOp body and initial value " 1838 << type.index(); 1839 return success(); 1840 } 1841 1842 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1843 // Parse operands. 1844 SmallVector<OpAsmParser::OperandType, 3> operands; 1845 Type shapeOrExtentTensorType; 1846 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1847 OpAsmParser::Delimiter::Paren) || 1848 parser.parseColonType(shapeOrExtentTensorType) || 1849 parser.parseOptionalArrowTypeList(result.types)) 1850 return failure(); 1851 1852 // Resolve operands. 1853 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1854 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1855 result.operands) || 1856 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1857 result.operands)) 1858 return failure(); 1859 1860 // Parse the body. 
1861 Region *body = result.addRegion(); 1862 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1863 return failure(); 1864 1865 // Parse attributes. 1866 if (parser.parseOptionalAttrDict(result.attributes)) 1867 return failure(); 1868 1869 return success(); 1870 } 1871 1872 void ReduceOp::print(OpAsmPrinter &p) { 1873 p << '(' << getShape() << ", " << getInitVals() 1874 << ") : " << getShape().getType(); 1875 p.printOptionalArrowTypeList(getResultTypes()); 1876 p << ' '; 1877 p.printRegion(getRegion()); 1878 p.printOptionalAttrDict((*this)->getAttrs()); 1879 } 1880 1881 #define GET_OP_CLASSES 1882 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1883