1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/IR/TypeUtilities.h" 24 #include "mlir/Transforms/InliningUtils.h" 25 #include "llvm/ADT/SetOperations.h" 26 #include "llvm/ADT/SmallString.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace mlir; 31 using namespace mlir::shape; 32 33 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 34 35 namespace { 36 #include "ShapeCanonicalization.inc" 37 } // namespace 38 39 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 40 return RankedTensorType::get({rank}, IndexType::get(ctx)); 41 } 42 43 bool shape::isExtentTensorType(Type type) { 44 auto ranked = type.dyn_cast<RankedTensorType>(); 45 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 46 } 47 48 LogicalResult shape::getShapeVec(Value input, 49 SmallVectorImpl<int64_t> &shapeValues) { 50 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 51 auto type = inputOp.getArg().getType().cast<ShapedType>(); 52 if (!type.hasRank()) 53 return failure(); 54 llvm::append_range(shapeValues, type.getShape()); 55 return success(); 56 } 57 
DenseIntElementsAttr attr; 58 if (matchPattern(input, m_Constant(&attr))) { 59 llvm::append_range(shapeValues, attr.getValues<int64_t>()); 60 return success(); 61 } 62 return failure(); 63 } 64 65 static bool isErrorPropagationPossible(TypeRange operandTypes) { 66 return llvm::any_of(operandTypes, [](Type ty) { 67 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 68 }); 69 } 70 71 static LogicalResult verifySizeOrIndexOp(Operation *op) { 72 assert(op != nullptr && op->getNumResults() == 1); 73 Type resultTy = op->getResultTypes().front(); 74 if (isErrorPropagationPossible(op->getOperandTypes())) { 75 if (!resultTy.isa<SizeType>()) 76 return op->emitOpError() 77 << "if at least one of the operands can hold error values then " 78 "the result must be of type `size` to propagate them"; 79 } 80 return success(); 81 } 82 83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 84 assert(op != nullptr && op->getNumResults() == 1); 85 Type resultTy = op->getResultTypes().front(); 86 if (isErrorPropagationPossible(op->getOperandTypes())) { 87 if (!resultTy.isa<ShapeType>()) 88 return op->emitOpError() 89 << "if at least one of the operands can hold error values then " 90 "the result must be of type `shape` to propagate them"; 91 } 92 return success(); 93 } 94 95 template <typename... Ty> 96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 97 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 98 } 99 100 template <typename... Ty, typename... ranges> 101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // InlinerInterface 107 //===----------------------------------------------------------------------===// 108 109 namespace { 110 /// This class defines the interface for inlining shape dialect ops. 
111 struct ShapeInlinerInterface : public DialectInlinerInterface { 112 using DialectInlinerInterface::DialectInlinerInterface; 113 114 // Returns true if the given region 'src' can be inlined into the region 115 // 'dest' that is attached to an operation registered to the current dialect. 116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 117 BlockAndValueMapping &) const final { 118 return true; 119 } 120 121 // Returns true if the given operation 'op', that is registered to this 122 // dialect, can be inlined into the region 'dest' that is attached to an 123 // operation registered to the current dialect. 124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 125 BlockAndValueMapping &) const final { 126 return true; 127 } 128 }; 129 } // namespace 130 131 void ShapeDialect::initialize() { 132 addOperations< 133 #define GET_OP_LIST 134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 135 >(); 136 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 137 addInterfaces<ShapeInlinerInterface>(); 138 // Allow unknown operations during prototyping and testing. As the dialect is 139 // still evolving it makes it simple to start with an unregistered ops and 140 // try different variants before actually defining the op. 
141 allowUnknownOperations(); 142 } 143 144 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 145 Attribute value, Type type, 146 Location loc) { 147 if (type.isa<ShapeType>() || isExtentTensorType(type)) 148 return builder.create<ConstShapeOp>(loc, type, 149 value.cast<DenseIntElementsAttr>()); 150 if (type.isa<SizeType>()) 151 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 152 if (type.isa<WitnessType>()) 153 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 154 if (arith::ConstantOp::isBuildableWith(value, type)) 155 return builder.create<arith::ConstantOp>(loc, type, value); 156 return nullptr; 157 } 158 159 /// Parse a type registered to this dialect. 160 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 161 StringRef keyword; 162 if (parser.parseKeyword(&keyword)) 163 return Type(); 164 165 if (keyword == "shape") 166 return ShapeType::get(getContext()); 167 if (keyword == "size") 168 return SizeType::get(getContext()); 169 if (keyword == "value_shape") 170 return ValueShapeType::get(getContext()); 171 if (keyword == "witness") 172 return WitnessType::get(getContext()); 173 174 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 175 return Type(); 176 } 177 178 /// Print a type registered to this dialect. 179 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 180 TypeSwitch<Type>(type) 181 .Case<ShapeType>([&](Type) { os << "shape"; }) 182 .Case<SizeType>([&](Type) { os << "size"; }) 183 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 184 .Case<WitnessType>([&](Type) { os << "witness"; }) 185 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 186 } 187 188 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 189 NamedAttribute attribute) { 190 // Verify shape.lib attribute. 
191 if (attribute.getName() == "shape.lib") { 192 if (!op->hasTrait<OpTrait::SymbolTable>()) 193 return op->emitError( 194 "shape.lib attribute may only be on op implementing SymbolTable"); 195 196 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 197 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 198 if (!symbol) 199 return op->emitError("shape function library ") 200 << symbolRef << " not found"; 201 return isa<shape::FunctionLibraryOp>(symbol) 202 ? success() 203 : op->emitError() 204 << symbolRef << " required to be shape function library"; 205 } 206 207 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 208 // Verify all entries are function libraries and mappings in libraries 209 // refer to unique ops. 210 DenseSet<StringAttr> key; 211 for (auto it : arr) { 212 if (!it.isa<SymbolRefAttr>()) 213 return op->emitError( 214 "only SymbolRefAttr allowed in shape.lib attribute array"); 215 216 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 217 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 218 if (!shapeFnLib) 219 return op->emitError() 220 << it << " does not refer to FunctionLibraryOp"; 221 for (auto mapping : shapeFnLib.getMapping()) { 222 if (!key.insert(mapping.getName()).second) { 223 return op->emitError("only one op to shape mapping allowed, found " 224 "multiple for `") 225 << mapping.getName() << "`"; 226 } 227 } 228 } 229 return success(); 230 } 231 232 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 233 "allowed as shape.lib attribute"); 234 } 235 return success(); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // AnyOp 240 //===----------------------------------------------------------------------===// 241 242 // TODO: Canonicalization should be implemented for shapes that can be 243 // determined through mixtures of the known dimensions of the inputs. 
244 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 245 // Only the last operand is checked because AnyOp is commutative. 246 if (operands.back()) 247 return operands.back(); 248 249 return nullptr; 250 } 251 252 //===----------------------------------------------------------------------===// 253 // AssumingOp 254 //===----------------------------------------------------------------------===// 255 256 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 257 result.regions.reserve(1); 258 Region *doRegion = result.addRegion(); 259 260 auto &builder = parser.getBuilder(); 261 OpAsmParser::UnresolvedOperand cond; 262 if (parser.parseOperand(cond) || 263 parser.resolveOperand(cond, builder.getType<WitnessType>(), 264 result.operands)) 265 return failure(); 266 267 // Parse optional results type list. 268 if (parser.parseOptionalArrowTypeList(result.types)) 269 return failure(); 270 271 // Parse the region and add a terminator if elided. 272 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 273 return failure(); 274 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 275 276 // Parse the optional attribute list. 277 if (parser.parseOptionalAttrDict(result.attributes)) 278 return failure(); 279 return success(); 280 } 281 282 void AssumingOp::print(OpAsmPrinter &p) { 283 bool yieldsResults = !getResults().empty(); 284 285 p << " " << getWitness(); 286 if (yieldsResults) 287 p << " -> (" << getResultTypes() << ")"; 288 p << ' '; 289 p.printRegion(getDoRegion(), 290 /*printEntryBlockArgs=*/false, 291 /*printBlockTerminators=*/yieldsResults); 292 p.printOptionalAttrDict((*this)->getAttrs()); 293 } 294 295 namespace { 296 // Removes AssumingOp with a passing witness and inlines the region. 
297 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 298 using OpRewritePattern<AssumingOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(AssumingOp op, 301 PatternRewriter &rewriter) const override { 302 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 303 if (!witness || !witness.getPassingAttr()) 304 return failure(); 305 306 AssumingOp::inlineRegionIntoParent(op, rewriter); 307 return success(); 308 } 309 }; 310 311 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 312 using OpRewritePattern<AssumingOp>::OpRewritePattern; 313 314 LogicalResult matchAndRewrite(AssumingOp op, 315 PatternRewriter &rewriter) const override { 316 Block *body = op.getBody(); 317 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 318 319 // Find used values. 320 SmallVector<Value, 4> newYieldOperands; 321 Value opResult, yieldOperand; 322 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 323 std::tie(opResult, yieldOperand) = it; 324 if (!opResult.getUses().empty()) { 325 newYieldOperands.push_back(yieldOperand); 326 } 327 } 328 329 // Rewrite only if redundant results exist. 330 if (newYieldOperands.size() == yieldOp->getNumOperands()) 331 return failure(); 332 333 // Replace yield op in the old assuming op's body and move the entire region 334 // to the new assuming op. 335 rewriter.setInsertionPointToEnd(body); 336 auto newYieldOp = 337 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 338 rewriter.setInsertionPoint(op); 339 auto newOp = rewriter.create<AssumingOp>( 340 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 341 newOp.getDoRegion().takeBody(op.getDoRegion()); 342 343 // Use the new results to replace the previously used ones. 
344 SmallVector<Value, 4> replacementValues; 345 auto src = newOp.getResults().begin(); 346 for (auto it : op.getResults()) { 347 if (it.getUses().empty()) 348 replacementValues.push_back(nullptr); 349 else 350 replacementValues.push_back(*src++); 351 } 352 rewriter.replaceOp(op, replacementValues); 353 return success(); 354 } 355 }; 356 } // namespace 357 358 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 359 MLIRContext *context) { 360 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 361 } 362 363 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 364 void AssumingOp::getSuccessorRegions( 365 Optional<unsigned> index, ArrayRef<Attribute> operands, 366 SmallVectorImpl<RegionSuccessor> ®ions) { 367 // AssumingOp has unconditional control flow into the region and back to the 368 // parent, so return the correct RegionSuccessor purely based on the index 369 // being None or 0. 370 if (index.hasValue()) { 371 regions.push_back(RegionSuccessor(getResults())); 372 return; 373 } 374 375 regions.push_back(RegionSuccessor(&getDoRegion())); 376 } 377 378 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 379 PatternRewriter &rewriter) { 380 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 381 auto *assumingBlock = op.getBody(); 382 auto initPosition = rewriter.getInsertionPoint(); 383 auto *blockAfterAssuming = 384 rewriter.splitBlock(blockBeforeAssuming, initPosition); 385 386 // Remove the AssumingOp and AssumingYieldOp. 387 auto &yieldOp = assumingBlock->back(); 388 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 389 rewriter.replaceOp(op, yieldOp.getOperands()); 390 rewriter.eraseOp(&yieldOp); 391 392 // Merge blocks together as there was no branching behavior from the 393 // AssumingOp. 
394 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 395 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 396 } 397 398 void AssumingOp::build( 399 OpBuilder &builder, OperationState &result, Value witness, 400 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 401 402 result.addOperands(witness); 403 Region *bodyRegion = result.addRegion(); 404 bodyRegion->push_back(new Block); 405 Block &bodyBlock = bodyRegion->front(); 406 407 // Build body. 408 OpBuilder::InsertionGuard guard(builder); 409 builder.setInsertionPointToStart(&bodyBlock); 410 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 411 builder.create<AssumingYieldOp>(result.location, yieldValues); 412 413 SmallVector<Type, 2> assumingTypes; 414 for (Value v : yieldValues) 415 assumingTypes.push_back(v.getType()); 416 result.addTypes(assumingTypes); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // AddOp 421 //===----------------------------------------------------------------------===// 422 423 LogicalResult mlir::shape::AddOp::inferReturnTypes( 424 MLIRContext *context, Optional<Location> location, ValueRange operands, 425 DictionaryAttr attributes, RegionRange regions, 426 SmallVectorImpl<Type> &inferredReturnTypes) { 427 if (operands[0].getType().isa<SizeType>() || 428 operands[1].getType().isa<SizeType>()) 429 inferredReturnTypes.assign({SizeType::get(context)}); 430 else 431 inferredReturnTypes.assign({IndexType::get(context)}); 432 return success(); 433 } 434 435 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 436 // SizeType is compatible with IndexType. 
437 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 438 } 439 440 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 441 // add(x, 0) -> x 442 if (matchPattern(getRhs(), m_Zero())) 443 return getLhs(); 444 445 return constFoldBinaryOp<IntegerAttr>( 446 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 447 } 448 449 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 450 451 //===----------------------------------------------------------------------===// 452 // AssumingAllOp 453 //===----------------------------------------------------------------------===// 454 455 namespace { 456 457 // Merge multiple `shape.assuming_all` operations together. 458 // 459 // %0 = shape.assuming_all %w0, %w1 460 // %1 = shape.assuming_all %w2, %0 461 // 462 // to: 463 // 464 // %0 = shape.assuming_all %w0, %w2, %w2 465 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 466 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 467 468 LogicalResult matchAndRewrite(AssumingAllOp op, 469 PatternRewriter &rewriter) const override { 470 SmallVector<Value> operands; 471 472 for (Value operand : op.getInputs()) { 473 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>()) 474 operands.append(assumeAll.operand_begin(), assumeAll->operand_end()); 475 else 476 operands.push_back(operand); 477 } 478 479 // We didn't find any other `assuming_all` ops to merge with. 480 if (operands.size() == op.getNumOperands()) 481 return failure(); 482 483 // Replace with a new `assuming_all` operation with merged constraints. 484 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 485 return success(); 486 } 487 }; 488 489 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 490 // are subsumed by others. 
491 // 492 // %0 = shape.cstr_broadcastable %shape0, %shape1 493 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 494 // 495 // %2 = shape.cstr_broadcastable %shape3, %shape4 496 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 497 // 498 // %4 = shape.assuming_all %0, %1, %2, %3 499 // 500 // to: 501 // 502 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 503 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 504 // %2 = shape.assuming_all %0, %1 505 // 506 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 507 // shapes [0, 1] are broadcastable too, and can be removed from the list of 508 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 509 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 510 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 511 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 512 513 LogicalResult matchAndRewrite(AssumingAllOp op, 514 PatternRewriter &rewriter) const override { 515 // Collect all `CstrBroadcastableOp` operands first. 516 SetVector<CstrBroadcastableOp> operands; 517 for (Value operand : op.getInputs()) { 518 // TODO: Apply this optimization if some of the witnesses are not 519 // produced by the `cstr_broadcastable`. 520 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 521 if (!broadcastable) 522 return failure(); 523 524 operands.insert(broadcastable); 525 } 526 527 // Skip trivial `assuming_all` operations. 528 if (operands.size() <= 1) 529 return failure(); 530 531 // Collect shapes checked by `cstr_broadcastable` operands. 532 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 533 for (auto cstr : operands) { 534 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end()); 535 shapes.emplace_back(cstr, std::move(shapesSet)); 536 } 537 538 // Sort by the number of shape operands (larger to smaller). 
539 llvm::sort(shapes, [](auto a, auto b) { 540 return a.first.getNumOperands() > b.first.getNumOperands(); 541 }); 542 543 // We start from the `cst_broadcastable` operations with largest number of 544 // shape operands, and remove redundant `cst_broadcastable` operations. We 545 // do this until we find a set of `cst_broadcastable` operations with 546 // non-overlapping constraints. 547 SmallVector<CstrBroadcastableOp> markedForErase; 548 549 for (unsigned i = 0; i < shapes.size(); ++i) { 550 auto isSubset = [&](auto pair) { 551 return llvm::set_is_subset(pair.second, shapes[i].second); 552 }; 553 554 // Keep redundant `cstr_broadcastable` operations to be erased. 555 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 556 for (auto *it0 = it; it0 < shapes.end(); ++it0) 557 markedForErase.push_back(it0->first); 558 shapes.erase(it, shapes.end()); 559 } 560 561 // We didn't find any operands that could be removed. 562 if (markedForErase.empty()) 563 return failure(); 564 565 // Collect non-overlapping `cst_broadcastable` constraints. 566 SmallVector<Value> uniqueConstraints; 567 for (auto &shape : shapes) 568 uniqueConstraints.push_back(shape.first.getResult()); 569 570 // Replace with a new `assuming_all` operation ... 571 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints); 572 573 // ... and maybe erase `cstr_broadcastable` ops without uses. 
574 for (auto &op : markedForErase) 575 if (op->use_empty()) 576 rewriter.eraseOp(op); 577 578 return success(); 579 } 580 }; 581 582 struct AssumingAllToCstrEqCanonicalization 583 : public OpRewritePattern<AssumingAllOp> { 584 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 585 586 LogicalResult matchAndRewrite(AssumingAllOp op, 587 PatternRewriter &rewriter) const override { 588 SmallVector<Value, 8> shapes; 589 for (Value w : op.getInputs()) { 590 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 591 if (!cstrEqOp) 592 return failure(); 593 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 594 return llvm::is_contained(shapes, s); 595 }); 596 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 597 return failure(); 598 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 599 } 600 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 601 return success(); 602 } 603 }; 604 605 template <typename OpTy> 606 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 607 using OpRewritePattern<OpTy>::OpRewritePattern; 608 609 LogicalResult matchAndRewrite(OpTy op, 610 PatternRewriter &rewriter) const override { 611 // Find unique operands. 612 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 613 614 // Reduce op to equivalent with unique operands. 
615 if (unique.size() < op.getNumOperands()) { 616 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 617 unique.takeVector(), op->getAttrs()); 618 return success(); 619 } 620 621 return failure(); 622 } 623 }; 624 } // namespace 625 626 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 627 MLIRContext *context) { 628 patterns 629 .add<MergeAssumingAllOps, AssumingAllOneOp, 630 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 631 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 632 } 633 634 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 635 // Iterate in reverse to first handle all constant operands. They are 636 // guaranteed to be the tail of the inputs because this is commutative. 637 for (int idx = operands.size() - 1; idx >= 0; idx--) { 638 Attribute a = operands[idx]; 639 // Cannot fold if any inputs are not constant; 640 if (!a) 641 return nullptr; 642 643 // We do not need to keep statically known values after handling them in 644 // this method. 645 getOperation()->eraseOperand(idx); 646 647 // Always false if any input is statically known false 648 if (!a.cast<BoolAttr>().getValue()) 649 return a; 650 } 651 // If this is reached, all inputs were statically known passing. 652 return BoolAttr::get(getContext(), true); 653 } 654 655 LogicalResult AssumingAllOp::verify() { 656 // Ensure that AssumingAllOp contains at least one operand 657 if (getNumOperands() == 0) 658 return emitOpError("no operands specified"); 659 660 return success(); 661 } 662 663 //===----------------------------------------------------------------------===// 664 // BroadcastOp 665 //===----------------------------------------------------------------------===// 666 667 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 668 if (getShapes().size() == 1) { 669 // Otherwise, we need a cast which would be a canonicalization, not folding. 
670 if (getShapes().front().getType() != getType()) 671 return nullptr; 672 return getShapes().front(); 673 } 674 675 // TODO: Support folding with more than 2 input shapes 676 if (getShapes().size() > 2) 677 return nullptr; 678 679 if (!operands[0] || !operands[1]) 680 return nullptr; 681 auto lhsShape = llvm::to_vector<6>( 682 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 683 auto rhsShape = llvm::to_vector<6>( 684 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 685 SmallVector<int64_t, 6> resultShape; 686 687 // If the shapes are not compatible, we can't fold it. 688 // TODO: Fold to an "error". 689 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 690 return nullptr; 691 692 Builder builder(getContext()); 693 return builder.getIndexTensorAttr(resultShape); 694 } 695 696 LogicalResult BroadcastOp::verify() { 697 return verifyShapeOrExtentTensorOp(*this); 698 } 699 700 namespace { 701 template <typename OpTy> 702 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 703 using OpRewritePattern<OpTy>::OpRewritePattern; 704 705 LogicalResult matchAndRewrite(OpTy op, 706 PatternRewriter &rewriter) const override { 707 auto isPotentiallyNonEmptyShape = [](Value shape) { 708 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 709 if (extentTensorTy.getDimSize(0) == 0) 710 return false; 711 } 712 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 713 if (constShape.getShape().empty()) 714 return false; 715 } 716 return true; 717 }; 718 auto newOperands = llvm::to_vector<8>( 719 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 720 721 // Reduce op to equivalent without empty shape operands. 
722 if (newOperands.size() < op.getNumOperands()) { 723 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 724 op->getAttrs()); 725 return success(); 726 } 727 728 return failure(); 729 } 730 }; 731 732 struct BroadcastForwardSingleOperandPattern 733 : public OpRewritePattern<BroadcastOp> { 734 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 735 736 LogicalResult matchAndRewrite(BroadcastOp op, 737 PatternRewriter &rewriter) const override { 738 if (op.getNumOperands() != 1) 739 return failure(); 740 Value replacement = op.getShapes().front(); 741 742 // Insert cast if needed. 743 if (replacement.getType() != op.getType()) { 744 auto loc = op.getLoc(); 745 if (op.getType().isa<ShapeType>()) { 746 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 747 } else { 748 assert(!op.getType().isa<ShapeType>() && 749 !replacement.getType().isa<ShapeType>() && 750 "expect extent tensor cast"); 751 replacement = 752 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 753 } 754 } 755 756 rewriter.replaceOp(op, replacement); 757 return success(); 758 } 759 }; 760 761 struct BroadcastFoldConstantOperandsPattern 762 : public OpRewritePattern<BroadcastOp> { 763 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 764 765 LogicalResult matchAndRewrite(BroadcastOp op, 766 PatternRewriter &rewriter) const override { 767 SmallVector<int64_t, 8> foldedConstantShape; 768 SmallVector<Value, 8> newShapeOperands; 769 for (Value shape : op.getShapes()) { 770 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 771 SmallVector<int64_t, 8> newFoldedConstantShape; 772 if (OpTrait::util::getBroadcastedShape( 773 foldedConstantShape, 774 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 775 newFoldedConstantShape)) { 776 foldedConstantShape = newFoldedConstantShape; 777 continue; 778 } 779 } 780 newShapeOperands.push_back(shape); 781 } 782 783 // Need at least two constant operands to fold anything. 
784 if (op.getNumOperands() - newShapeOperands.size() < 2) 785 return failure(); 786 787 auto foldedConstantOperandsTy = RankedTensorType::get( 788 {static_cast<int64_t>(foldedConstantShape.size())}, 789 rewriter.getIndexType()); 790 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 791 op.getLoc(), foldedConstantOperandsTy, 792 rewriter.getIndexTensorAttr(foldedConstantShape))); 793 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 794 newShapeOperands); 795 return success(); 796 } 797 }; 798 799 template <typename OpTy> 800 struct CanonicalizeCastExtentTensorOperandsPattern 801 : public OpRewritePattern<OpTy> { 802 using OpRewritePattern<OpTy>::OpRewritePattern; 803 804 LogicalResult matchAndRewrite(OpTy op, 805 PatternRewriter &rewriter) const override { 806 // Canonicalize operands. 807 bool anyChange = false; 808 auto canonicalizeOperand = [&](Value operand) { 809 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 810 // Only eliminate the cast if it holds no shape information. 811 bool isInformationLoosingCast = 812 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 813 if (isInformationLoosingCast) { 814 anyChange = true; 815 return castOp.source(); 816 } 817 } 818 return operand; 819 }; 820 auto newOperands = llvm::to_vector<8>( 821 llvm::map_range(op.getOperands(), canonicalizeOperand)); 822 823 // Rewrite op if any change required. 824 if (!anyChange) 825 return failure(); 826 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 827 return success(); 828 } 829 }; 830 831 struct BroadcastConcretizeResultTypePattern 832 : public OpRewritePattern<BroadcastOp> { 833 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 834 835 LogicalResult matchAndRewrite(BroadcastOp op, 836 PatternRewriter &rewriter) const override { 837 // Only concretize dynamic extent tensor result types. 
838 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 839 if (!resultTy || !resultTy.isDynamicDim(0)) 840 return failure(); 841 842 // Infer resulting shape rank if possible. 843 int64_t maxRank = 0; 844 for (Value shape : op.getShapes()) { 845 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 846 // Cannot infer resulting shape rank if any operand is dynamically 847 // ranked. 848 if (extentTensorTy.isDynamicDim(0)) 849 return failure(); 850 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 851 } 852 } 853 854 auto newOp = rewriter.create<BroadcastOp>( 855 op.getLoc(), getExtentTensorType(getContext(), maxRank), 856 op.getShapes()); 857 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 858 return success(); 859 } 860 }; 861 } // namespace 862 863 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 864 MLIRContext *context) { 865 patterns.add<BroadcastConcretizeResultTypePattern, 866 BroadcastFoldConstantOperandsPattern, 867 BroadcastForwardSingleOperandPattern, 868 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 869 RemoveDuplicateOperandsPattern<BroadcastOp>, 870 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // ConcatOp 875 //===----------------------------------------------------------------------===// 876 877 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 878 if (!operands[0] || !operands[1]) 879 return nullptr; 880 auto lhsShape = llvm::to_vector<6>( 881 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 882 auto rhsShape = llvm::to_vector<6>( 883 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 884 SmallVector<int64_t, 6> resultShape; 885 resultShape.append(lhsShape.begin(), lhsShape.end()); 886 resultShape.append(rhsShape.begin(), rhsShape.end()); 887 Builder builder(getContext()); 888 return 
builder.getIndexTensorAttr(resultShape); 889 } 890 891 //===----------------------------------------------------------------------===// 892 // ConstShapeOp 893 //===----------------------------------------------------------------------===// 894 895 void ConstShapeOp::print(OpAsmPrinter &p) { 896 p << " "; 897 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 898 p << "["; 899 interleaveComma(getShape().getValues<int64_t>(), p); 900 p << "] : "; 901 p.printType(getType()); 902 } 903 904 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 905 if (parser.parseOptionalAttrDict(result.attributes)) 906 return failure(); 907 // We piggy-back on ArrayAttr parsing, though we don't internally store the 908 // shape as an ArrayAttr. 909 // TODO: Implement custom parser and maybe make syntax a bit more concise. 910 Attribute extentsRaw; 911 NamedAttrList dummy; 912 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 913 return failure(); 914 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 915 if (!extentsArray) 916 return failure(); 917 SmallVector<int64_t, 6> ints; 918 for (Attribute extent : extentsArray) { 919 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 920 if (!attr) 921 return failure(); 922 ints.push_back(attr.getInt()); 923 } 924 Builder &builder = parser.getBuilder(); 925 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 926 Type resultTy; 927 if (parser.parseColonType(resultTy)) 928 return failure(); 929 result.types.push_back(resultTy); 930 return success(); 931 } 932 933 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 934 935 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 936 MLIRContext *context) { 937 patterns.add<TensorCastConstShape>(context); 938 } 939 940 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 941 MLIRContext *context, Optional<Location> location, ValueRange operands, 942 DictionaryAttr attributes, 
RegionRange regions, 943 SmallVectorImpl<Type> &inferredReturnTypes) { 944 Builder b(context); 945 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 946 if (!shape) 947 return emitOptionalError(location, "missing shape attribute"); 948 inferredReturnTypes.assign({RankedTensorType::get( 949 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 950 return success(); 951 } 952 953 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 954 TypeRange r) { 955 if (l.size() != 1 || r.size() != 1) 956 return false; 957 958 Type lhs = l.front(); 959 Type rhs = r.front(); 960 961 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 962 // Shape type is compatible with all other valid return types. 963 return true; 964 return lhs == rhs; 965 } 966 967 //===----------------------------------------------------------------------===// 968 // CstrBroadcastableOp 969 //===----------------------------------------------------------------------===// 970 971 void CstrBroadcastableOp::getCanonicalizationPatterns( 972 RewritePatternSet &patterns, MLIRContext *context) { 973 // Canonicalization patterns have overlap with the considerations during 974 // folding in case additional shape information is inferred at some point that 975 // does not result in folding. 976 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 977 CstrBroadcastableEqOps, 978 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 979 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 980 } 981 982 // Return true if there is exactly one attribute not representing a scalar 983 // broadcast. 
984 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 985 bool nonScalarSeen = false; 986 for (Attribute a : attributes) { 987 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 988 if (nonScalarSeen) 989 return false; 990 nonScalarSeen = true; 991 } 992 } 993 return true; 994 } 995 996 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 997 // No broadcasting is needed if all operands but one are scalar. 998 if (hasAtMostSingleNonScalar(operands)) 999 return BoolAttr::get(getContext(), true); 1000 1001 if ([&] { 1002 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1003 for (const auto &operand : operands) { 1004 if (!operand) 1005 return false; 1006 extents.push_back(llvm::to_vector<6>( 1007 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 1008 } 1009 return OpTrait::util::staticallyKnownBroadcastable(extents); 1010 }()) 1011 return BoolAttr::get(getContext(), true); 1012 1013 // Lastly, see if folding can be completed based on what constraints are known 1014 // on the input shapes. 1015 if ([&] { 1016 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1017 for (auto shapeValue : getShapes()) { 1018 extents.emplace_back(); 1019 if (failed(getShapeVec(shapeValue, extents.back()))) 1020 return false; 1021 } 1022 return OpTrait::util::staticallyKnownBroadcastable(extents); 1023 }()) 1024 return BoolAttr::get(getContext(), true); 1025 1026 // Because a failing witness result here represents an eventual assertion 1027 // failure, we do not replace it with a constant witness. 
1028 return nullptr; 1029 } 1030 1031 LogicalResult CstrBroadcastableOp::verify() { 1032 // Ensure that CstrBroadcastableOp contains at least two operands 1033 if (getNumOperands() < 2) 1034 return emitOpError("required at least 2 input shapes"); 1035 return success(); 1036 } 1037 1038 //===----------------------------------------------------------------------===// 1039 // CstrEqOp 1040 //===----------------------------------------------------------------------===// 1041 1042 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1043 MLIRContext *context) { 1044 // If inputs are equal, return passing witness 1045 patterns.add<CstrEqEqOps>(context); 1046 } 1047 1048 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1049 if (llvm::all_of(operands, 1050 [&](Attribute a) { return a && a == operands[0]; })) 1051 return BoolAttr::get(getContext(), true); 1052 1053 // Because a failing witness result here represents an eventual assertion 1054 // failure, we do not try to replace it with a constant witness. Similarly, we 1055 // cannot if there are any non-const inputs. 
1056 return nullptr; 1057 } 1058 1059 //===----------------------------------------------------------------------===// 1060 // ConstSizeOp 1061 //===----------------------------------------------------------------------===// 1062 1063 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1064 int64_t value) { 1065 build(builder, result, builder.getIndexAttr(value)); 1066 } 1067 1068 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1069 1070 void ConstSizeOp::getAsmResultNames( 1071 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1072 SmallString<4> buffer; 1073 llvm::raw_svector_ostream os(buffer); 1074 os << "c" << getValue(); 1075 setNameFn(getResult(), os.str()); 1076 } 1077 1078 //===----------------------------------------------------------------------===// 1079 // ConstWitnessOp 1080 //===----------------------------------------------------------------------===// 1081 1082 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1083 return getPassingAttr(); 1084 } 1085 1086 //===----------------------------------------------------------------------===// 1087 // CstrRequireOp 1088 //===----------------------------------------------------------------------===// 1089 1090 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1091 return operands[0]; 1092 } 1093 1094 //===----------------------------------------------------------------------===// 1095 // DivOp 1096 //===----------------------------------------------------------------------===// 1097 1098 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1099 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1100 if (!lhs) 1101 return nullptr; 1102 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1103 if (!rhs) 1104 return nullptr; 1105 1106 // Division in APInt does not follow floor(lhs, rhs) when the result is 1107 // negative. Rather, APInt rounds toward zero. 
1108 APInt quotient, remainder; 1109 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1110 if (quotient.isNegative() && !remainder.isNullValue()) { 1111 quotient -= 1; 1112 } 1113 1114 Type indexTy = IndexType::get(getContext()); 1115 return IntegerAttr::get(indexTy, quotient); 1116 } 1117 1118 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1119 MLIRContext *context, Optional<Location> location, ValueRange operands, 1120 DictionaryAttr attributes, RegionRange regions, 1121 SmallVectorImpl<Type> &inferredReturnTypes) { 1122 if (operands[0].getType().isa<SizeType>() || 1123 operands[1].getType().isa<SizeType>()) 1124 inferredReturnTypes.assign({SizeType::get(context)}); 1125 else 1126 inferredReturnTypes.assign({IndexType::get(context)}); 1127 return success(); 1128 } 1129 1130 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1131 // SizeType is compatible with IndexType. 1132 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1133 } 1134 1135 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1136 1137 //===----------------------------------------------------------------------===// 1138 // ShapeEqOp 1139 //===----------------------------------------------------------------------===// 1140 1141 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1142 bool allSame = true; 1143 if (!operands.empty() && !operands[0]) 1144 return {}; 1145 for (Attribute operand : operands.drop_front(1)) { 1146 if (!operand) 1147 return {}; 1148 allSame = allSame && operand == operands[0]; 1149 } 1150 return BoolAttr::get(getContext(), allSame); 1151 } 1152 1153 //===----------------------------------------------------------------------===// 1154 // IndexToSizeOp 1155 //===----------------------------------------------------------------------===// 1156 1157 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1158 // Constant values of both types, `shape.size` and `index`, are represented as 
1159 // `IntegerAttr`s which makes constant folding simple. 1160 if (Attribute arg = operands[0]) 1161 return arg; 1162 return {}; 1163 } 1164 1165 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1166 MLIRContext *context) { 1167 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1168 } 1169 1170 //===----------------------------------------------------------------------===// 1171 // FromExtentsOp 1172 //===----------------------------------------------------------------------===// 1173 1174 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1175 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1176 return nullptr; 1177 SmallVector<int64_t, 6> extents; 1178 for (auto attr : operands) 1179 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1180 Builder builder(getContext()); 1181 return builder.getIndexTensorAttr(extents); 1182 } 1183 1184 //===----------------------------------------------------------------------===// 1185 // FunctionLibraryOp 1186 //===----------------------------------------------------------------------===// 1187 1188 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1189 StringRef name) { 1190 result.attributes.push_back(builder.getNamedAttr( 1191 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1192 } 1193 1194 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1195 auto attr = getMapping() 1196 .get(op->getName().getIdentifier()) 1197 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1198 if (!attr) 1199 return nullptr; 1200 return lookupSymbol<FuncOp>(attr); 1201 } 1202 1203 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1204 OperationState &result) { 1205 // Parse the op name. 
1206 StringAttr nameAttr; 1207 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1208 result.attributes)) 1209 return failure(); 1210 1211 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1212 return failure(); 1213 1214 auto *bodyRegion = result.addRegion(); 1215 if (parser.parseRegion(*bodyRegion)) 1216 return failure(); 1217 1218 if (parser.parseKeyword("mapping")) 1219 return failure(); 1220 1221 DictionaryAttr mappingAttr; 1222 if (parser.parseAttribute(mappingAttr, 1223 parser.getBuilder().getType<NoneType>(), "mapping", 1224 result.attributes)) 1225 return failure(); 1226 return success(); 1227 } 1228 1229 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1230 p << ' '; 1231 p.printSymbolName(getName()); 1232 p.printOptionalAttrDictWithKeyword( 1233 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1234 p << ' '; 1235 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1236 /*printBlockTerminators=*/false); 1237 p << " mapping "; 1238 p.printAttributeWithoutType(getMappingAttr()); 1239 } 1240 1241 //===----------------------------------------------------------------------===// 1242 // FuncOp 1243 //===----------------------------------------------------------------------===// 1244 1245 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 1246 auto buildFuncType = 1247 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 1248 function_interface_impl::VariadicFlag, 1249 std::string &) { return builder.getFunctionType(argTypes, results); }; 1250 1251 return function_interface_impl::parseFunctionOp( 1252 parser, result, /*allowVariadic=*/false, buildFuncType); 1253 } 1254 1255 void FuncOp::print(OpAsmPrinter &p) { 1256 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 1257 } 1258 1259 //===----------------------------------------------------------------------===// 1260 // GetExtentOp 1261 
//===----------------------------------------------------------------------===// 1262 1263 Optional<int64_t> GetExtentOp::getConstantDim() { 1264 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1265 return constSizeOp.getValue().getLimitedValue(); 1266 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1267 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1268 return llvm::None; 1269 } 1270 1271 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1272 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1273 if (!elements) 1274 return nullptr; 1275 Optional<int64_t> dim = getConstantDim(); 1276 if (!dim.hasValue()) 1277 return nullptr; 1278 if (dim.getValue() >= elements.getNumElements()) 1279 return nullptr; 1280 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1281 } 1282 1283 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1284 int64_t dim) { 1285 auto loc = result.location; 1286 auto dimAttr = builder.getIndexAttr(dim); 1287 if (shape.getType().isa<ShapeType>()) { 1288 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1289 build(builder, result, builder.getType<SizeType>(), shape, dim); 1290 } else { 1291 Value dim = 1292 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1293 build(builder, result, builder.getIndexType(), shape, dim); 1294 } 1295 } 1296 1297 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1298 MLIRContext *context, Optional<Location> location, ValueRange operands, 1299 DictionaryAttr attributes, RegionRange regions, 1300 SmallVectorImpl<Type> &inferredReturnTypes) { 1301 inferredReturnTypes.assign({IndexType::get(context)}); 1302 return success(); 1303 } 1304 1305 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1306 TypeRange r) { 1307 // SizeType is compatible with IndexType. 
1308 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1309 } 1310 1311 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1312 1313 //===----------------------------------------------------------------------===// 1314 // IsBroadcastableOp 1315 //===----------------------------------------------------------------------===// 1316 1317 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1318 MLIRContext *context) { 1319 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1320 } 1321 1322 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1323 // Can always broadcast fewer than two shapes. 1324 if (operands.size() < 2) { 1325 return BoolAttr::get(getContext(), true); 1326 } 1327 1328 return nullptr; 1329 } 1330 1331 //===----------------------------------------------------------------------===// 1332 // MeetOp 1333 //===----------------------------------------------------------------------===// 1334 1335 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1336 MLIRContext *context, Optional<Location> location, ValueRange operands, 1337 DictionaryAttr attributes, RegionRange regions, 1338 SmallVectorImpl<Type> &inferredReturnTypes) { 1339 inferredReturnTypes.assign({operands[0].getType()}); 1340 return success(); 1341 } 1342 1343 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1344 if (l.size() != 1 || r.size() != 1) 1345 return false; 1346 if (l == r) 1347 return true; 1348 1349 Type lhs = l.front(); 1350 Type rhs = r.front(); 1351 1352 if (lhs != rhs) 1353 return false; 1354 1355 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1356 return true; 1357 1358 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1359 return true; 1360 return false; 1361 } 1362 1363 //===----------------------------------------------------------------------===// 1364 // RankOp 1365 //===----------------------------------------------------------------------===// 
/// Folds `shape.rank` of a constant shape to its number of extents.
OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank with the constant op that matches the result type.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

/// Infers `!shape.size` for a `!shape.shape` operand and `index` otherwise.
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

/// Folds to the product of all extents of a constant shape (1 for an empty
/// shape). The product is computed in a 64-bit APInt.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

/// Infers `!shape.size` for a `!shape.shape` operand and `index` otherwise.
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

/// Infers the operand type when both operands share it, else `!shape.size`.
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

/// Accepts only shape/shape or size/size pairs.
// NOTE(review): if operands could ever be `index`, the type forwarded by
// inferReturnTypes would be rejected here; confirm against the op's operand
// constraints.
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

/// Infers the operand type when both operands share it, else `!shape.size`.
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

/// Accepts only shape/shape or size/size pairs (same contract as MaxOp).
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

/// Folds the product of two constants into an `index` IntegerAttr; the APInt
/// multiplication wraps on overflow.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

/// Infers `!shape.size` if either operand is a `!shape.size`, else `index`.
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

/// Folds to a constant index tensor when the operand is a statically shaped
/// `ShapedType`.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}

namespace {
// Rebuilds a `shape.shape_of` whose argument is a `ShapedType` but whose
// result is not, by passing only the argument so the result type is chosen by
// the builder. NOTE(review): presumably this is the type-inferring builder,
// which per inferReturnTypes below yields an extent tensor; confirm.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}

/// Infers `!shape.shape` for a `!shape.value_shape` operand. Otherwise infers
/// `tensor<Rxindex>`, where R is the operand's rank, or dynamic if unranked.
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

/// Identical types are compatible. `!shape.shape` is compatible with any
/// shape or shaped type. Two shaped types are compatible if
/// `verifyCompatibleShapes` accepts them.
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return OpFoldResult();
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

/// Allows casting one `index` or `!shape.size` value to one `index` value.
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

/// Checks that the yielded operands match the parent op's results in number
/// and, pairwise, in type.
LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  return success();
}
//===----------------------------------------------------------------------===// 1732 // SplitAtOp 1733 //===----------------------------------------------------------------------===// 1734 1735 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1736 SmallVectorImpl<OpFoldResult> &results) { 1737 if (!operands[0] || !operands[1]) 1738 return failure(); 1739 auto shapeVec = llvm::to_vector<6>( 1740 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1741 auto shape = llvm::makeArrayRef(shapeVec); 1742 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1743 // Verify that the split point is in the correct range. 1744 // TODO: Constant fold to an "error". 1745 int64_t rank = shape.size(); 1746 if (!(-rank <= splitPoint && splitPoint <= rank)) 1747 return failure(); 1748 if (splitPoint < 0) 1749 splitPoint += shape.size(); 1750 Builder builder(operands[0].getContext()); 1751 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1752 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1753 return success(); 1754 } 1755 1756 //===----------------------------------------------------------------------===// 1757 // ToExtentTensorOp 1758 //===----------------------------------------------------------------------===// 1759 1760 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1761 if (!operands[0]) 1762 return OpFoldResult(); 1763 Builder builder(getContext()); 1764 auto shape = llvm::to_vector<6>( 1765 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1766 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1767 builder.getIndexType()); 1768 return DenseIntElementsAttr::get(type, shape); 1769 } 1770 1771 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1772 if (inputs.size() != 1 || outputs.size() != 1) 1773 return false; 1774 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1775 if 
(!inputTensor.getElementType().isa<IndexType>() || 1776 inputTensor.getRank() != 1) 1777 return false; 1778 } else if (!inputs[0].isa<ShapeType>()) { 1779 return false; 1780 } 1781 1782 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1783 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1784 } 1785 1786 //===----------------------------------------------------------------------===// 1787 // ReduceOp 1788 //===----------------------------------------------------------------------===// 1789 1790 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1791 ValueRange initVals) { 1792 result.addOperands(shape); 1793 result.addOperands(initVals); 1794 1795 Region *bodyRegion = result.addRegion(); 1796 bodyRegion->push_back(new Block); 1797 Block &bodyBlock = bodyRegion->front(); 1798 bodyBlock.addArgument(builder.getIndexType(), result.location); 1799 1800 Type elementType; 1801 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1802 elementType = tensorType.getElementType(); 1803 else 1804 elementType = SizeType::get(builder.getContext()); 1805 bodyBlock.addArgument(elementType, shape.getLoc()); 1806 1807 for (Value initVal : initVals) { 1808 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1809 result.addTypes(initVal.getType()); 1810 } 1811 } 1812 1813 LogicalResult ReduceOp::verify() { 1814 // Verify block arg types. 1815 Block &block = getRegion().front(); 1816 1817 // The block takes index, extent, and aggregated values as arguments. 1818 auto blockArgsCount = getInitVals().size() + 2; 1819 if (block.getNumArguments() != blockArgsCount) 1820 return emitOpError() << "ReduceOp body is expected to have " 1821 << blockArgsCount << " arguments"; 1822 1823 // The first block argument is the index and must always be of type `index`. 
1824 if (!block.getArgument(0).getType().isa<IndexType>()) 1825 return emitOpError( 1826 "argument 0 of ReduceOp body is expected to be of IndexType"); 1827 1828 // The second block argument is the extent and must be of type `size` or 1829 // `index`, depending on whether the reduce operation is applied to a shape or 1830 // to an extent tensor. 1831 Type extentTy = block.getArgument(1).getType(); 1832 if (getShape().getType().isa<ShapeType>()) { 1833 if (!extentTy.isa<SizeType>()) 1834 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1835 "SizeType if the ReduceOp operates on a ShapeType"); 1836 } else { 1837 if (!extentTy.isa<IndexType>()) 1838 return emitOpError( 1839 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1840 "ReduceOp operates on an extent tensor"); 1841 } 1842 1843 for (const auto &type : llvm::enumerate(getInitVals())) 1844 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1845 return emitOpError() << "type mismatch between argument " 1846 << type.index() + 2 1847 << " of ReduceOp body and initial value " 1848 << type.index(); 1849 return success(); 1850 } 1851 1852 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1853 // Parse operands. 1854 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; 1855 Type shapeOrExtentTensorType; 1856 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1857 OpAsmParser::Delimiter::Paren) || 1858 parser.parseColonType(shapeOrExtentTensorType) || 1859 parser.parseOptionalArrowTypeList(result.types)) 1860 return failure(); 1861 1862 // Resolve operands. 1863 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1864 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1865 result.operands) || 1866 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1867 result.operands)) 1868 return failure(); 1869 1870 // Parse the body. 
1871 Region *body = result.addRegion(); 1872 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1873 return failure(); 1874 1875 // Parse attributes. 1876 if (parser.parseOptionalAttrDict(result.attributes)) 1877 return failure(); 1878 1879 return success(); 1880 } 1881 1882 void ReduceOp::print(OpAsmPrinter &p) { 1883 p << '(' << getShape() << ", " << getInitVals() 1884 << ") : " << getShape().getType(); 1885 p.printOptionalArrowTypeList(getResultTypes()); 1886 p << ' '; 1887 p.printRegion(getRegion()); 1888 p.printOptionalAttrDict((*this)->getAttrs()); 1889 } 1890 1891 #define GET_OP_CLASSES 1892 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1893