1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/CommonFolders.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Traits.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/SetOperations.h" 25 #include "llvm/ADT/SmallString.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace mlir; 30 using namespace mlir::shape; 31 32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc" 33 34 namespace { 35 #include "ShapeCanonicalization.inc" 36 } // namespace 37 38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { 39 return RankedTensorType::get({rank}, IndexType::get(ctx)); 40 } 41 42 bool shape::isExtentTensorType(Type type) { 43 auto ranked = type.dyn_cast<RankedTensorType>(); 44 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex(); 45 } 46 47 LogicalResult shape::getShapeVec(Value input, 48 SmallVectorImpl<int64_t> &shapeValues) { 49 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) { 50 auto type = inputOp.getArg().getType().dyn_cast<ShapedType>(); 51 if (!type.hasRank()) 52 return failure(); 53 shapeValues = llvm::to_vector<6>(type.getShape()); 54 return success(); 55 } 56 if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { 57 shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); 58 return success(); 59 } 60 if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { 61 shapeValues = llvm::to_vector<6>( 62 inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>()); 63 return success(); 64 } 65 return failure(); 66 } 67 68 static bool isErrorPropagationPossible(TypeRange operandTypes) { 69 return llvm::any_of(operandTypes, [](Type ty) { 70 return ty.isa<SizeType, ShapeType, ValueShapeType>(); 71 }); 72 } 73 74 static LogicalResult verifySizeOrIndexOp(Operation *op) { 75 assert(op != nullptr && op->getNumResults() == 1); 76 Type resultTy = op->getResultTypes().front(); 77 if (isErrorPropagationPossible(op->getOperandTypes())) { 78 if (!resultTy.isa<SizeType>()) 79 return op->emitOpError() 80 << "if at least one of the operands can hold error values then " 81 "the result must be of type `size` to propagate them"; 82 } 83 return success(); 84 } 85 86 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { 87 assert(op != nullptr && op->getNumResults() == 1); 88 Type resultTy = op->getResultTypes().front(); 89 if (isErrorPropagationPossible(op->getOperandTypes())) { 90 if (!resultTy.isa<ShapeType>()) 91 return op->emitOpError() 92 << "if at least one of the operands can hold error values then " 93 "the result must be of type `shape` to propagate them"; 94 } 95 return success(); 96 } 97 98 template <typename... Ty> 99 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { 100 return typeRange.size() == 1 && typeRange.front().isa<Ty...>(); 101 } 102 103 template <typename... Ty, typename... ranges> 104 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { 105 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...); 106 } 107 108 //===----------------------------------------------------------------------===// 109 // InlinerInterface 110 //===----------------------------------------------------------------------===// 111 112 namespace { 113 /// This class defines the interface for inlining shape dialect ops. 114 struct ShapeInlinerInterface : public DialectInlinerInterface { 115 using DialectInlinerInterface::DialectInlinerInterface; 116 117 // Returns true if the given region 'src' can be inlined into the region 118 // 'dest' that is attached to an operation registered to the current dialect. 119 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 120 BlockAndValueMapping &) const final { 121 return true; 122 } 123 124 // Returns true if the given operation 'op', that is registered to this 125 // dialect, can be inlined into the region 'dest' that is attached to an 126 // operation registered to the current dialect. 127 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, 128 BlockAndValueMapping &) const final { 129 return true; 130 } 131 }; 132 } // namespace 133 134 void ShapeDialect::initialize() { 135 addOperations< 136 #define GET_OP_LIST 137 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 138 >(); 139 addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>(); 140 addInterfaces<ShapeInlinerInterface>(); 141 // Allow unknown operations during prototyping and testing. As the dialect is 142 // still evolving it makes it simple to start with an unregistered ops and 143 // try different variants before actually defining the op. 144 allowUnknownOperations(); 145 } 146 147 Operation *ShapeDialect::materializeConstant(OpBuilder &builder, 148 Attribute value, Type type, 149 Location loc) { 150 if (type.isa<ShapeType>() || isExtentTensorType(type)) 151 return builder.create<ConstShapeOp>(loc, type, 152 value.cast<DenseIntElementsAttr>()); 153 if (type.isa<SizeType>()) 154 return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); 155 if (type.isa<WitnessType>()) 156 return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); 157 if (arith::ConstantOp::isBuildableWith(value, type)) 158 return builder.create<arith::ConstantOp>(loc, type, value); 159 return nullptr; 160 } 161 162 /// Parse a type registered to this dialect. 163 Type ShapeDialect::parseType(DialectAsmParser &parser) const { 164 StringRef keyword; 165 if (parser.parseKeyword(&keyword)) 166 return Type(); 167 168 if (keyword == "shape") 169 return ShapeType::get(getContext()); 170 if (keyword == "size") 171 return SizeType::get(getContext()); 172 if (keyword == "value_shape") 173 return ValueShapeType::get(getContext()); 174 if (keyword == "witness") 175 return WitnessType::get(getContext()); 176 177 parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; 178 return Type(); 179 } 180 181 /// Print a type registered to this dialect. 182 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { 183 TypeSwitch<Type>(type) 184 .Case<ShapeType>([&](Type) { os << "shape"; }) 185 .Case<SizeType>([&](Type) { os << "size"; }) 186 .Case<ValueShapeType>([&](Type) { os << "value_shape"; }) 187 .Case<WitnessType>([&](Type) { os << "witness"; }) 188 .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); 189 } 190 191 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, 192 NamedAttribute attribute) { 193 // Verify shape.lib attribute. 194 if (attribute.getName() == "shape.lib") { 195 if (!op->hasTrait<OpTrait::SymbolTable>()) 196 return op->emitError( 197 "shape.lib attribute may only be on op implementing SymbolTable"); 198 199 if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) { 200 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); 201 if (!symbol) 202 return op->emitError("shape function library ") 203 << symbolRef << " not found"; 204 return isa<shape::FunctionLibraryOp>(symbol) 205 ? success() 206 : op->emitError() 207 << symbolRef << " required to be shape function library"; 208 } 209 210 if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) { 211 // Verify all entries are function libraries and mappings in libraries 212 // refer to unique ops. 213 DenseSet<StringAttr> key; 214 for (auto it : arr) { 215 if (!it.isa<SymbolRefAttr>()) 216 return op->emitError( 217 "only SymbolRefAttr allowed in shape.lib attribute array"); 218 219 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>( 220 SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>())); 221 if (!shapeFnLib) 222 return op->emitError() 223 << it << " does not refer to FunctionLibraryOp"; 224 for (auto mapping : shapeFnLib.getMapping()) { 225 if (!key.insert(mapping.getName()).second) { 226 return op->emitError("only one op to shape mapping allowed, found " 227 "multiple for `") 228 << mapping.getName() << "`"; 229 } 230 } 231 } 232 return success(); 233 } 234 235 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs " 236 "allowed as shape.lib attribute"); 237 } 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // AnyOp 243 //===----------------------------------------------------------------------===// 244 245 // TODO: Canonicalization should be implemented for shapes that can be 246 // determined through mixtures of the known dimensions of the inputs. 247 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { 248 // Only the last operand is checked because AnyOp is commutative. 249 if (operands.back()) 250 return operands.back(); 251 252 return nullptr; 253 } 254 255 //===----------------------------------------------------------------------===// 256 // AssumingOp 257 //===----------------------------------------------------------------------===// 258 259 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { 260 result.regions.reserve(1); 261 Region *doRegion = result.addRegion(); 262 263 auto &builder = parser.getBuilder(); 264 OpAsmParser::OperandType cond; 265 if (parser.parseOperand(cond) || 266 parser.resolveOperand(cond, builder.getType<WitnessType>(), 267 result.operands)) 268 return failure(); 269 270 // Parse optional results type list. 271 if (parser.parseOptionalArrowTypeList(result.types)) 272 return failure(); 273 274 // Parse the region and add a terminator if elided. 275 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{})) 276 return failure(); 277 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location); 278 279 // Parse the optional attribute list. 280 if (parser.parseOptionalAttrDict(result.attributes)) 281 return failure(); 282 return success(); 283 } 284 285 void AssumingOp::print(OpAsmPrinter &p) { 286 bool yieldsResults = !getResults().empty(); 287 288 p << " " << getWitness(); 289 if (yieldsResults) 290 p << " -> (" << getResultTypes() << ")"; 291 p << ' '; 292 p.printRegion(getDoRegion(), 293 /*printEntryBlockArgs=*/false, 294 /*printBlockTerminators=*/yieldsResults); 295 p.printOptionalAttrDict((*this)->getAttrs()); 296 } 297 298 namespace { 299 // Removes AssumingOp with a passing witness and inlines the region. 300 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { 301 using OpRewritePattern<AssumingOp>::OpRewritePattern; 302 303 LogicalResult matchAndRewrite(AssumingOp op, 304 PatternRewriter &rewriter) const override { 305 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>(); 306 if (!witness || !witness.getPassingAttr()) 307 return failure(); 308 309 AssumingOp::inlineRegionIntoParent(op, rewriter); 310 return success(); 311 } 312 }; 313 314 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { 315 using OpRewritePattern<AssumingOp>::OpRewritePattern; 316 317 LogicalResult matchAndRewrite(AssumingOp op, 318 PatternRewriter &rewriter) const override { 319 Block *body = op.getBody(); 320 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator()); 321 322 // Find used values. 323 SmallVector<Value, 4> newYieldOperands; 324 Value opResult, yieldOperand; 325 for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) { 326 std::tie(opResult, yieldOperand) = it; 327 if (!opResult.getUses().empty()) { 328 newYieldOperands.push_back(yieldOperand); 329 } 330 } 331 332 // Rewrite only if redundant results exist. 333 if (newYieldOperands.size() == yieldOp->getNumOperands()) 334 return failure(); 335 336 // Replace yield op in the old assuming op's body and move the entire region 337 // to the new assuming op. 338 rewriter.setInsertionPointToEnd(body); 339 auto newYieldOp = 340 rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); 341 rewriter.setInsertionPoint(op); 342 auto newOp = rewriter.create<AssumingOp>( 343 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); 344 newOp.getDoRegion().takeBody(op.getDoRegion()); 345 346 // Use the new results to replace the previously used ones. 347 SmallVector<Value, 4> replacementValues; 348 auto src = newOp.getResults().begin(); 349 for (auto it : op.getResults()) { 350 if (it.getUses().empty()) 351 replacementValues.push_back(nullptr); 352 else 353 replacementValues.push_back(*src++); 354 } 355 rewriter.replaceOp(op, replacementValues); 356 return success(); 357 } 358 }; 359 } // namespace 360 361 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 362 MLIRContext *context) { 363 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context); 364 } 365 366 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td 367 void AssumingOp::getSuccessorRegions( 368 Optional<unsigned> index, ArrayRef<Attribute> operands, 369 SmallVectorImpl<RegionSuccessor> ®ions) { 370 // AssumingOp has unconditional control flow into the region and back to the 371 // parent, so return the correct RegionSuccessor purely based on the index 372 // being None or 0. 373 if (index.hasValue()) { 374 regions.push_back(RegionSuccessor(getResults())); 375 return; 376 } 377 378 regions.push_back(RegionSuccessor(&getDoRegion())); 379 } 380 381 void AssumingOp::inlineRegionIntoParent(AssumingOp &op, 382 PatternRewriter &rewriter) { 383 auto *blockBeforeAssuming = rewriter.getInsertionBlock(); 384 auto *assumingBlock = op.getBody(); 385 auto initPosition = rewriter.getInsertionPoint(); 386 auto *blockAfterAssuming = 387 rewriter.splitBlock(blockBeforeAssuming, initPosition); 388 389 // Remove the AssumingOp and AssumingYieldOp. 390 auto &yieldOp = assumingBlock->back(); 391 rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming); 392 rewriter.replaceOp(op, yieldOp.getOperands()); 393 rewriter.eraseOp(&yieldOp); 394 395 // Merge blocks together as there was no branching behavior from the 396 // AssumingOp. 397 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming); 398 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming); 399 } 400 401 void AssumingOp::build( 402 OpBuilder &builder, OperationState &result, Value witness, 403 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { 404 405 result.addOperands(witness); 406 Region *bodyRegion = result.addRegion(); 407 bodyRegion->push_back(new Block); 408 Block &bodyBlock = bodyRegion->front(); 409 410 // Build body. 411 OpBuilder::InsertionGuard guard(builder); 412 builder.setInsertionPointToStart(&bodyBlock); 413 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); 414 builder.create<AssumingYieldOp>(result.location, yieldValues); 415 416 SmallVector<Type, 2> assumingTypes; 417 for (Value v : yieldValues) 418 assumingTypes.push_back(v.getType()); 419 result.addTypes(assumingTypes); 420 } 421 422 LogicalResult AssumingOp::verify() { 423 return RegionBranchOpInterface::verifyTypes(*this); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // AddOp 428 //===----------------------------------------------------------------------===// 429 430 LogicalResult mlir::shape::AddOp::inferReturnTypes( 431 MLIRContext *context, Optional<Location> location, ValueRange operands, 432 DictionaryAttr attributes, RegionRange regions, 433 SmallVectorImpl<Type> &inferredReturnTypes) { 434 if (operands[0].getType().isa<SizeType>() || 435 operands[1].getType().isa<SizeType>()) 436 inferredReturnTypes.assign({SizeType::get(context)}); 437 else 438 inferredReturnTypes.assign({IndexType::get(context)}); 439 return success(); 440 } 441 442 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 443 // SizeType is compatible with IndexType. 444 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 445 } 446 447 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) { 448 // add(x, 0) -> x 449 if (matchPattern(getRhs(), m_Zero())) 450 return getLhs(); 451 452 return constFoldBinaryOp<IntegerAttr>( 453 operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); 454 } 455 456 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); } 457 458 //===----------------------------------------------------------------------===// 459 // AssumingAllOp 460 //===----------------------------------------------------------------------===// 461 462 namespace { 463 464 // Merge multiple `shape.assuming_all` operations together. 465 // 466 // %0 = shape.assuming_all %w0, %w1 467 // %1 = shape.assuming_all %w2, %0 468 // 469 // to: 470 // 471 // %0 = shape.assuming_all %w0, %w2, %w2 472 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { 473 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 474 475 LogicalResult matchAndRewrite(AssumingAllOp op, 476 PatternRewriter &rewriter) const override { 477 SmallVector<Value> operands; 478 479 for (Value operand : op.getInputs()) { 480 if (auto assume_all = operand.getDefiningOp<AssumingAllOp>()) 481 operands.append(assume_all.operand_begin(), assume_all->operand_end()); 482 else 483 operands.push_back(operand); 484 } 485 486 // We didn't find any other `assuming_all` ops to merge with. 487 if (operands.size() == op.getNumOperands()) 488 return failure(); 489 490 // Replace with a new `assuming_all` operation with merged constraints. 491 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands); 492 return success(); 493 } 494 }; 495 496 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that 497 // are subsumed by others. 498 // 499 // %0 = shape.cstr_broadcastable %shape0, %shape1 500 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2 501 // 502 // %2 = shape.cstr_broadcastable %shape3, %shape4 503 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5 504 // 505 // %4 = shape.assuming_all %0, %1, %2, %3 506 // 507 // to: 508 // 509 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2 510 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5 511 // %2 = shape.assuming_all %0, %1 512 // 513 // In this example if shapes [0, 1, 2] are broadcastable, then it means that 514 // shapes [0, 1] are broadcastable too, and can be removed from the list of 515 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't 516 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]). 517 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { 518 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 519 520 LogicalResult matchAndRewrite(AssumingAllOp op, 521 PatternRewriter &rewriter) const override { 522 // Collect all `CstrBroadcastableOp` operands first. 523 SetVector<CstrBroadcastableOp> operands; 524 for (Value operand : op.getInputs()) { 525 // TODO: Apply this optimization if some of the witnesses are not 526 // produced by the `cstr_broadcastable`. 527 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>(); 528 if (!broadcastable) 529 return failure(); 530 531 operands.insert(broadcastable); 532 } 533 534 // Skip trivial `assuming_all` operations. 535 if (operands.size() <= 1) 536 return failure(); 537 538 // Collect shapes checked by `cstr_broadcastable` operands. 539 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes; 540 for (auto cstr : operands) { 541 DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end()); 542 shapes.emplace_back(cstr, std::move(shapes_set)); 543 } 544 545 // Sort by the number of shape operands (larger to smaller). 546 llvm::sort(shapes, [](auto a, auto b) { 547 return a.first.getNumOperands() > b.first.getNumOperands(); 548 }); 549 550 // We start from the `cst_broadcastable` operations with largest number of 551 // shape operands, and remove redundant `cst_broadcastable` operations. We 552 // do this until we find a set of `cst_broadcastable` operations with 553 // non-overlapping constraints. 554 SmallVector<CstrBroadcastableOp> marked_for_erase; 555 556 for (unsigned i = 0; i < shapes.size(); ++i) { 557 auto isSubset = [&](auto pair) { 558 return llvm::set_is_subset(pair.second, shapes[i].second); 559 }; 560 561 // Keep redundant `cstr_broadcastable` operations to be erased. 562 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset); 563 for (auto *it0 = it; it0 < shapes.end(); ++it0) 564 marked_for_erase.push_back(it0->first); 565 shapes.erase(it, shapes.end()); 566 } 567 568 // We didn't find any operands that could be removed. 569 if (marked_for_erase.empty()) 570 return failure(); 571 572 // Collect non-overlapping `cst_broadcastable` constraints. 573 SmallVector<Value> unique_constraints; 574 for (auto &shape : shapes) 575 unique_constraints.push_back(shape.first.getResult()); 576 577 // Replace with a new `assuming_all` operation ... 578 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints); 579 580 // ... and maybe erase `cstr_broadcastable` ops without uses. 581 for (auto &op : marked_for_erase) 582 if (op->use_empty()) 583 rewriter.eraseOp(op); 584 585 return success(); 586 } 587 }; 588 589 struct AssumingAllToCstrEqCanonicalization 590 : public OpRewritePattern<AssumingAllOp> { 591 using OpRewritePattern<AssumingAllOp>::OpRewritePattern; 592 593 LogicalResult matchAndRewrite(AssumingAllOp op, 594 PatternRewriter &rewriter) const override { 595 SmallVector<Value, 8> shapes; 596 for (Value w : op.getInputs()) { 597 auto cstrEqOp = w.getDefiningOp<CstrEqOp>(); 598 if (!cstrEqOp) 599 return failure(); 600 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) { 601 return llvm::is_contained(shapes, s); 602 }); 603 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes) 604 return failure(); 605 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end()); 606 } 607 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes); 608 return success(); 609 } 610 }; 611 612 template <typename OpTy> 613 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { 614 using OpRewritePattern<OpTy>::OpRewritePattern; 615 616 LogicalResult matchAndRewrite(OpTy op, 617 PatternRewriter &rewriter) const override { 618 // Find unique operands. 619 SetVector<Value> unique(op.operand_begin(), op.operand_end()); 620 621 // Reduce op to equivalent with unique operands. 622 if (unique.size() < op.getNumOperands()) { 623 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), 624 unique.takeVector(), op->getAttrs()); 625 return success(); 626 } 627 628 return failure(); 629 } 630 }; 631 } // namespace 632 633 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 634 MLIRContext *context) { 635 patterns 636 .add<MergeAssumingAllOps, AssumingAllOneOp, 637 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization, 638 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context); 639 } 640 641 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) { 642 // Iterate in reverse to first handle all constant operands. They are 643 // guaranteed to be the tail of the inputs because this is commutative. 644 for (int idx = operands.size() - 1; idx >= 0; idx--) { 645 Attribute a = operands[idx]; 646 // Cannot fold if any inputs are not constant; 647 if (!a) 648 return nullptr; 649 650 // We do not need to keep statically known values after handling them in 651 // this method. 652 getOperation()->eraseOperand(idx); 653 654 // Always false if any input is statically known false 655 if (!a.cast<BoolAttr>().getValue()) 656 return a; 657 } 658 // If this is reached, all inputs were statically known passing. 659 return BoolAttr::get(getContext(), true); 660 } 661 662 LogicalResult AssumingAllOp::verify() { 663 // Ensure that AssumingAllOp contains at least one operand 664 if (getNumOperands() == 0) 665 return emitOpError("no operands specified"); 666 667 return success(); 668 } 669 670 void AssumingAllOp::build(OpBuilder &b, OperationState &state, 671 ValueRange inputs) { 672 build(b, state, b.getType<WitnessType>(), inputs); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // BroadcastOp 677 //===----------------------------------------------------------------------===// 678 679 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { 680 if (getShapes().size() == 1) { 681 // Otherwise, we need a cast which would be a canonicalization, not folding. 682 if (getShapes().front().getType() != getType()) 683 return nullptr; 684 return getShapes().front(); 685 } 686 687 // TODO: Support folding with more than 2 input shapes 688 if (getShapes().size() > 2) 689 return nullptr; 690 691 if (!operands[0] || !operands[1]) 692 return nullptr; 693 auto lhsShape = llvm::to_vector<6>( 694 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 695 auto rhsShape = llvm::to_vector<6>( 696 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 697 SmallVector<int64_t, 6> resultShape; 698 699 // If the shapes are not compatible, we can't fold it. 700 // TODO: Fold to an "error". 701 if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) 702 return nullptr; 703 704 Builder builder(getContext()); 705 return builder.getIndexTensorAttr(resultShape); 706 } 707 708 LogicalResult BroadcastOp::verify() { 709 return verifyShapeOrExtentTensorOp(*this); 710 } 711 712 namespace { 713 template <typename OpTy> 714 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { 715 using OpRewritePattern<OpTy>::OpRewritePattern; 716 717 LogicalResult matchAndRewrite(OpTy op, 718 PatternRewriter &rewriter) const override { 719 auto isPotentiallyNonEmptyShape = [](Value shape) { 720 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 721 if (extentTensorTy.getDimSize(0) == 0) 722 return false; 723 } 724 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 725 if (constShape.getShape().empty()) 726 return false; 727 } 728 return true; 729 }; 730 auto newOperands = llvm::to_vector<8>( 731 llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape)); 732 733 // Reduce op to equivalent without empty shape operands. 734 if (newOperands.size() < op.getNumOperands()) { 735 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands, 736 op->getAttrs()); 737 return success(); 738 } 739 740 return failure(); 741 } 742 }; 743 744 struct BroadcastForwardSingleOperandPattern 745 : public OpRewritePattern<BroadcastOp> { 746 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 747 748 LogicalResult matchAndRewrite(BroadcastOp op, 749 PatternRewriter &rewriter) const override { 750 if (op.getNumOperands() != 1) 751 return failure(); 752 Value replacement = op.getShapes().front(); 753 754 // Insert cast if needed. 755 if (replacement.getType() != op.getType()) { 756 auto loc = op.getLoc(); 757 if (op.getType().isa<ShapeType>()) { 758 replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); 759 } else { 760 assert(!op.getType().isa<ShapeType>() && 761 !replacement.getType().isa<ShapeType>() && 762 "expect extent tensor cast"); 763 replacement = 764 rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); 765 } 766 } 767 768 rewriter.replaceOp(op, replacement); 769 return success(); 770 } 771 }; 772 773 struct BroadcastFoldConstantOperandsPattern 774 : public OpRewritePattern<BroadcastOp> { 775 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 776 777 LogicalResult matchAndRewrite(BroadcastOp op, 778 PatternRewriter &rewriter) const override { 779 SmallVector<int64_t, 8> foldedConstantShape; 780 SmallVector<Value, 8> newShapeOperands; 781 for (Value shape : op.getShapes()) { 782 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) { 783 SmallVector<int64_t, 8> newFoldedConstantShape; 784 if (OpTrait::util::getBroadcastedShape( 785 foldedConstantShape, 786 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()), 787 newFoldedConstantShape)) { 788 foldedConstantShape = newFoldedConstantShape; 789 continue; 790 } 791 } 792 newShapeOperands.push_back(shape); 793 } 794 795 // Need at least two constant operands to fold anything. 796 if (op.getNumOperands() - newShapeOperands.size() < 2) 797 return failure(); 798 799 auto foldedConstantOperandsTy = RankedTensorType::get( 800 {static_cast<int64_t>(foldedConstantShape.size())}, 801 rewriter.getIndexType()); 802 newShapeOperands.push_back(rewriter.create<ConstShapeOp>( 803 op.getLoc(), foldedConstantOperandsTy, 804 rewriter.getIndexTensorAttr(foldedConstantShape))); 805 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), 806 newShapeOperands); 807 return success(); 808 } 809 }; 810 811 template <typename OpTy> 812 struct CanonicalizeCastExtentTensorOperandsPattern 813 : public OpRewritePattern<OpTy> { 814 using OpRewritePattern<OpTy>::OpRewritePattern; 815 816 LogicalResult matchAndRewrite(OpTy op, 817 PatternRewriter &rewriter) const override { 818 // Canonicalize operands. 819 bool anyChange = false; 820 auto canonicalizeOperand = [&](Value operand) { 821 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) { 822 // Only eliminate the cast if it holds no shape information. 823 bool isInformationLoosingCast = 824 castOp.getType().cast<RankedTensorType>().isDynamicDim(0); 825 if (isInformationLoosingCast) { 826 anyChange = true; 827 return castOp.source(); 828 } 829 } 830 return operand; 831 }; 832 auto newOperands = llvm::to_vector<8>( 833 llvm::map_range(op.getOperands(), canonicalizeOperand)); 834 835 // Rewrite op if any change required. 836 if (!anyChange) 837 return failure(); 838 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands); 839 return success(); 840 } 841 }; 842 843 struct BroadcastConcretizeResultTypePattern 844 : public OpRewritePattern<BroadcastOp> { 845 using OpRewritePattern<BroadcastOp>::OpRewritePattern; 846 847 LogicalResult matchAndRewrite(BroadcastOp op, 848 PatternRewriter &rewriter) const override { 849 // Only concretize dynamic extent tensor result types. 850 auto resultTy = op.getType().dyn_cast<RankedTensorType>(); 851 if (!resultTy || !resultTy.isDynamicDim(0)) 852 return failure(); 853 854 // Infer resulting shape rank if possible. 855 int64_t maxRank = 0; 856 for (Value shape : op.getShapes()) { 857 if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) { 858 // Cannot infer resulting shape rank if any operand is dynamically 859 // ranked. 860 if (extentTensorTy.isDynamicDim(0)) 861 return failure(); 862 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); 863 } 864 } 865 866 auto newOp = rewriter.create<BroadcastOp>( 867 op.getLoc(), getExtentTensorType(getContext(), maxRank), 868 op.getShapes()); 869 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 870 return success(); 871 } 872 }; 873 } // namespace 874 875 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 876 MLIRContext *context) { 877 patterns.add<BroadcastConcretizeResultTypePattern, 878 BroadcastFoldConstantOperandsPattern, 879 BroadcastForwardSingleOperandPattern, 880 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>, 881 RemoveDuplicateOperandsPattern<BroadcastOp>, 882 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context); 883 } 884 885 //===----------------------------------------------------------------------===// 886 // ConcatOp 887 //===----------------------------------------------------------------------===// 888 889 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) { 890 if (!operands[0] || !operands[1]) 891 return nullptr; 892 auto lhsShape = llvm::to_vector<6>( 893 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 894 auto rhsShape = llvm::to_vector<6>( 895 operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>()); 896 SmallVector<int64_t, 6> resultShape; 897 resultShape.append(lhsShape.begin(), lhsShape.end()); 898 resultShape.append(rhsShape.begin(), rhsShape.end()); 899 Builder builder(getContext()); 900 return builder.getIndexTensorAttr(resultShape); 901 } 902 903 //===----------------------------------------------------------------------===// 904 // ConstShapeOp 905 //===----------------------------------------------------------------------===// 906 907 void ConstShapeOp::print(OpAsmPrinter &p) { 908 p << " "; 909 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); 910 p << "["; 911 interleaveComma(getShape().getValues<int64_t>(), p); 912 p << "] : "; 913 p.printType(getType()); 914 } 915 916 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { 917 if (parser.parseOptionalAttrDict(result.attributes)) 918 return failure(); 919 // We piggy-back on ArrayAttr parsing, though we don't internally store the 920 // shape as an ArrayAttr. 921 // TODO: Implement custom parser and maybe make syntax a bit more concise. 922 Attribute extentsRaw; 923 NamedAttrList dummy; 924 if (parser.parseAttribute(extentsRaw, "dummy", dummy)) 925 return failure(); 926 auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>(); 927 if (!extentsArray) 928 return failure(); 929 SmallVector<int64_t, 6> ints; 930 for (Attribute extent : extentsArray) { 931 IntegerAttr attr = extent.dyn_cast<IntegerAttr>(); 932 if (!attr) 933 return failure(); 934 ints.push_back(attr.getInt()); 935 } 936 Builder &builder = parser.getBuilder(); 937 result.addAttribute("shape", builder.getIndexTensorAttr(ints)); 938 Type resultTy; 939 if (parser.parseColonType(resultTy)) 940 return failure(); 941 result.types.push_back(resultTy); 942 return success(); 943 } 944 945 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); } 946 947 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 948 MLIRContext *context) { 949 patterns.add<TensorCastConstShape>(context); 950 } 951 952 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( 953 MLIRContext *context, Optional<Location> location, ValueRange operands, 954 DictionaryAttr attributes, RegionRange regions, 955 SmallVectorImpl<Type> &inferredReturnTypes) { 956 Builder b(context); 957 auto shape = attributes.getAs<DenseIntElementsAttr>("shape"); 958 if (!shape) 959 return emitOptionalError(location, "missing shape attribute"); 960 inferredReturnTypes.assign({RankedTensorType::get( 961 {static_cast<int64_t>(shape.size())}, b.getIndexType())}); 962 return success(); 963 } 964 965 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l, 966 TypeRange r) { 967 if (l.size() != 1 || r.size() != 1) 968 return false; 969 970 Type lhs = l.front(); 971 Type rhs = r.front(); 972 973 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 974 // Shape type is compatible with all other valid return types. 975 return true; 976 return lhs == rhs; 977 } 978 979 //===----------------------------------------------------------------------===// 980 // CstrBroadcastableOp 981 //===----------------------------------------------------------------------===// 982 983 void CstrBroadcastableOp::getCanonicalizationPatterns( 984 RewritePatternSet &patterns, MLIRContext *context) { 985 // Canonicalization patterns have overlap with the considerations during 986 // folding in case additional shape information is inferred at some point that 987 // does not result in folding. 988 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>, 989 CstrBroadcastableEqOps, 990 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>, 991 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context); 992 } 993 994 // Return true if there is exactly one attribute not representing a scalar 995 // broadcast. 996 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { 997 bool nonScalarSeen = false; 998 for (Attribute a : attributes) { 999 if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) { 1000 if (nonScalarSeen) 1001 return false; 1002 nonScalarSeen = true; 1003 } 1004 } 1005 return true; 1006 } 1007 1008 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1009 // No broadcasting is needed if all operands but one are scalar. 1010 if (hasAtMostSingleNonScalar(operands)) 1011 return BoolAttr::get(getContext(), true); 1012 1013 if ([&] { 1014 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1015 for (const auto &operand : operands) { 1016 if (!operand) 1017 return false; 1018 extents.push_back(llvm::to_vector<6>( 1019 operand.cast<DenseIntElementsAttr>().getValues<int64_t>())); 1020 } 1021 return OpTrait::util::staticallyKnownBroadcastable(extents); 1022 }()) 1023 return BoolAttr::get(getContext(), true); 1024 1025 // Lastly, see if folding can be completed based on what constraints are known 1026 // on the input shapes. 1027 if ([&] { 1028 SmallVector<SmallVector<int64_t, 6>, 6> extents; 1029 for (auto shapeValue : getShapes()) { 1030 extents.emplace_back(); 1031 if (failed(getShapeVec(shapeValue, extents.back()))) 1032 return false; 1033 } 1034 return OpTrait::util::staticallyKnownBroadcastable(extents); 1035 }()) 1036 return BoolAttr::get(getContext(), true); 1037 1038 // Because a failing witness result here represents an eventual assertion 1039 // failure, we do not replace it with a constant witness. 1040 return nullptr; 1041 } 1042 1043 LogicalResult CstrBroadcastableOp::verify() { 1044 // Ensure that CstrBroadcastableOp contains at least two operands 1045 if (getNumOperands() < 2) 1046 return emitOpError("required at least 2 input shapes"); 1047 return success(); 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // CstrEqOp 1052 //===----------------------------------------------------------------------===// 1053 1054 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1055 MLIRContext *context) { 1056 // If inputs are equal, return passing witness 1057 patterns.add<CstrEqEqOps>(context); 1058 } 1059 1060 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) { 1061 if (llvm::all_of(operands, 1062 [&](Attribute a) { return a && a == operands[0]; })) 1063 return BoolAttr::get(getContext(), true); 1064 1065 // Because a failing witness result here represents an eventual assertion 1066 // failure, we do not try to replace it with a constant witness. Similarly, we 1067 // cannot if there are any non-const inputs. 1068 return nullptr; 1069 } 1070 1071 //===----------------------------------------------------------------------===// 1072 // ConstSizeOp 1073 //===----------------------------------------------------------------------===// 1074 1075 void ConstSizeOp::build(OpBuilder &builder, OperationState &result, 1076 int64_t value) { 1077 build(builder, result, builder.getIndexAttr(value)); 1078 } 1079 1080 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); } 1081 1082 void ConstSizeOp::getAsmResultNames( 1083 llvm::function_ref<void(Value, StringRef)> setNameFn) { 1084 SmallString<4> buffer; 1085 llvm::raw_svector_ostream os(buffer); 1086 os << "c" << getValue(); 1087 setNameFn(getResult(), os.str()); 1088 } 1089 1090 //===----------------------------------------------------------------------===// 1091 // ConstWitnessOp 1092 //===----------------------------------------------------------------------===// 1093 1094 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { 1095 return getPassingAttr(); 1096 } 1097 1098 //===----------------------------------------------------------------------===// 1099 // CstrRequireOp 1100 //===----------------------------------------------------------------------===// 1101 1102 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) { 1103 return operands[0]; 1104 } 1105 1106 //===----------------------------------------------------------------------===// 1107 // DivOp 1108 //===----------------------------------------------------------------------===// 1109 1110 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) { 1111 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1112 if (!lhs) 1113 return nullptr; 1114 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1115 if (!rhs) 1116 return nullptr; 1117 1118 // Division in APInt does not follow floor(lhs, rhs) when the result is 1119 // negative. Rather, APInt rounds toward zero. 1120 APInt quotient, remainder; 1121 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder); 1122 if (quotient.isNegative() && !remainder.isNullValue()) { 1123 quotient -= 1; 1124 } 1125 1126 Type indexTy = IndexType::get(getContext()); 1127 return IntegerAttr::get(indexTy, quotient); 1128 } 1129 1130 LogicalResult mlir::shape::DivOp::inferReturnTypes( 1131 MLIRContext *context, Optional<Location> location, ValueRange operands, 1132 DictionaryAttr attributes, RegionRange regions, 1133 SmallVectorImpl<Type> &inferredReturnTypes) { 1134 if (operands[0].getType().isa<SizeType>() || 1135 operands[1].getType().isa<SizeType>()) 1136 inferredReturnTypes.assign({SizeType::get(context)}); 1137 else 1138 inferredReturnTypes.assign({IndexType::get(context)}); 1139 return success(); 1140 } 1141 1142 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1143 // SizeType is compatible with IndexType. 1144 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1145 } 1146 1147 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); } 1148 1149 //===----------------------------------------------------------------------===// 1150 // ShapeEqOp 1151 //===----------------------------------------------------------------------===// 1152 1153 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) { 1154 bool allSame = true; 1155 if (!operands.empty() && !operands[0]) 1156 return {}; 1157 for (Attribute operand : operands.drop_front(1)) { 1158 if (!operand) 1159 return {}; 1160 allSame = allSame && operand == operands[0]; 1161 } 1162 return BoolAttr::get(getContext(), allSame); 1163 } 1164 1165 //===----------------------------------------------------------------------===// 1166 // IndexToSizeOp 1167 //===----------------------------------------------------------------------===// 1168 1169 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) { 1170 // Constant values of both types, `shape.size` and `index`, are represented as 1171 // `IntegerAttr`s which makes constant folding simple. 1172 if (Attribute arg = operands[0]) 1173 return arg; 1174 return {}; 1175 } 1176 1177 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1178 MLIRContext *context) { 1179 patterns.add<SizeToIndexToSizeCanonicalization>(context); 1180 } 1181 1182 //===----------------------------------------------------------------------===// 1183 // FromExtentsOp 1184 //===----------------------------------------------------------------------===// 1185 1186 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) { 1187 if (llvm::any_of(operands, [](Attribute a) { return !a; })) 1188 return nullptr; 1189 SmallVector<int64_t, 6> extents; 1190 for (auto attr : operands) 1191 extents.push_back(attr.cast<IntegerAttr>().getInt()); 1192 Builder builder(getContext()); 1193 return builder.getIndexTensorAttr(extents); 1194 } 1195 1196 //===----------------------------------------------------------------------===// 1197 // FunctionLibraryOp 1198 //===----------------------------------------------------------------------===// 1199 1200 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result, 1201 StringRef name) { 1202 result.attributes.push_back(builder.getNamedAttr( 1203 ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1204 } 1205 1206 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { 1207 auto attr = getMapping() 1208 .get(op->getName().getIdentifier()) 1209 .dyn_cast_or_null<FlatSymbolRefAttr>(); 1210 if (!attr) 1211 return nullptr; 1212 return lookupSymbol<FuncOp>(attr); 1213 } 1214 1215 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, 1216 OperationState &result) { 1217 // Parse the op name. 1218 StringAttr nameAttr; 1219 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 1220 result.attributes)) 1221 return failure(); 1222 1223 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1224 return failure(); 1225 1226 auto *bodyRegion = result.addRegion(); 1227 if (parser.parseRegion(*bodyRegion)) 1228 return failure(); 1229 1230 if (parser.parseKeyword("mapping")) 1231 return failure(); 1232 1233 DictionaryAttr mappingAttr; 1234 if (parser.parseAttribute(mappingAttr, 1235 parser.getBuilder().getType<NoneType>(), "mapping", 1236 result.attributes)) 1237 return failure(); 1238 return success(); 1239 } 1240 1241 void FunctionLibraryOp::print(OpAsmPrinter &p) { 1242 p << ' '; 1243 p.printSymbolName(getName()); 1244 p.printOptionalAttrDictWithKeyword( 1245 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); 1246 p << ' '; 1247 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1248 /*printBlockTerminators=*/false); 1249 p << " mapping "; 1250 p.printAttributeWithoutType(getMappingAttr()); 1251 } 1252 1253 //===----------------------------------------------------------------------===// 1254 // GetExtentOp 1255 //===----------------------------------------------------------------------===// 1256 1257 Optional<int64_t> GetExtentOp::getConstantDim() { 1258 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>()) 1259 return constSizeOp.getValue().getLimitedValue(); 1260 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>()) 1261 return constantOp.getValue().cast<IntegerAttr>().getInt(); 1262 return llvm::None; 1263 } 1264 1265 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) { 1266 auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1267 if (!elements) 1268 return nullptr; 1269 Optional<int64_t> dim = getConstantDim(); 1270 if (!dim.hasValue()) 1271 return nullptr; 1272 if (dim.getValue() >= elements.getNumElements()) 1273 return nullptr; 1274 return elements.getValues<Attribute>()[(uint64_t)dim.getValue()]; 1275 } 1276 1277 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, 1278 int64_t dim) { 1279 auto loc = result.location; 1280 auto dimAttr = builder.getIndexAttr(dim); 1281 if (shape.getType().isa<ShapeType>()) { 1282 Value dim = builder.create<ConstSizeOp>(loc, dimAttr); 1283 build(builder, result, builder.getType<SizeType>(), shape, dim); 1284 } else { 1285 Value dim = 1286 builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); 1287 build(builder, result, builder.getIndexType(), shape, dim); 1288 } 1289 } 1290 1291 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( 1292 MLIRContext *context, Optional<Location> location, ValueRange operands, 1293 DictionaryAttr attributes, RegionRange regions, 1294 SmallVectorImpl<Type> &inferredReturnTypes) { 1295 inferredReturnTypes.assign({IndexType::get(context)}); 1296 return success(); 1297 } 1298 1299 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, 1300 TypeRange r) { 1301 // SizeType is compatible with IndexType. 1302 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1303 } 1304 1305 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); } 1306 1307 //===----------------------------------------------------------------------===// 1308 // IsBroadcastableOp 1309 //===----------------------------------------------------------------------===// 1310 1311 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1312 MLIRContext *context) { 1313 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context); 1314 } 1315 1316 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) { 1317 // Can always broadcast fewer than two shapes. 1318 if (operands.size() < 2) { 1319 return BoolAttr::get(getContext(), true); 1320 } 1321 1322 return nullptr; 1323 } 1324 1325 //===----------------------------------------------------------------------===// 1326 // MeetOp 1327 //===----------------------------------------------------------------------===// 1328 1329 LogicalResult mlir::shape::MeetOp::inferReturnTypes( 1330 MLIRContext *context, Optional<Location> location, ValueRange operands, 1331 DictionaryAttr attributes, RegionRange regions, 1332 SmallVectorImpl<Type> &inferredReturnTypes) { 1333 inferredReturnTypes.assign({operands[0].getType()}); 1334 return success(); 1335 } 1336 1337 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1338 if (l.size() != 1 || r.size() != 1) 1339 return false; 1340 if (l == r) 1341 return true; 1342 1343 Type lhs = l.front(); 1344 Type rhs = r.front(); 1345 1346 if (lhs != rhs) 1347 return false; 1348 1349 if (lhs.isa<SizeType>() || lhs.isa<ShapeType>()) 1350 return true; 1351 1352 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1353 return true; 1354 return false; 1355 } 1356 1357 //===----------------------------------------------------------------------===// 1358 // RankOp 1359 //===----------------------------------------------------------------------===// 1360 1361 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) { 1362 auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>(); 1363 if (!shape) 1364 return {}; 1365 int64_t rank = shape.getNumElements(); 1366 Builder builder(getContext()); 1367 return builder.getIndexAttr(rank); 1368 } 1369 1370 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time. 1371 /// Constant folding fails in cases where only the rank is constant, not the 1372 /// shape itself. 1373 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. 1374 /// 1375 /// Example: 1376 /// 1377 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> 1378 /// %rank = shape.rank %shape 1379 /// 1380 /// becomes 1381 /// 1382 /// %rank = shape.const_size 3 1383 1384 namespace { 1385 struct RankShapeOfCanonicalizationPattern 1386 : public OpRewritePattern<shape::RankOp> { 1387 using OpRewritePattern<shape::RankOp>::OpRewritePattern; 1388 1389 LogicalResult matchAndRewrite(shape::RankOp op, 1390 PatternRewriter &rewriter) const override { 1391 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>(); 1392 if (!shapeOfOp) 1393 return failure(); 1394 auto rankedTensorType = 1395 shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1396 if (!rankedTensorType) 1397 return failure(); 1398 int64_t rank = rankedTensorType.getRank(); 1399 if (op.getType().isa<IndexType>()) { 1400 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(), 1401 rank); 1402 } else if (op.getType().isa<shape::SizeType>()) { 1403 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank); 1404 } else { 1405 return failure(); 1406 } 1407 return success(); 1408 } 1409 }; 1410 } // namespace 1411 1412 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1413 MLIRContext *context) { 1414 patterns.add<RankShapeOfCanonicalizationPattern>(context); 1415 } 1416 1417 LogicalResult mlir::shape::RankOp::inferReturnTypes( 1418 MLIRContext *context, Optional<Location> location, ValueRange operands, 1419 DictionaryAttr attributes, RegionRange regions, 1420 SmallVectorImpl<Type> &inferredReturnTypes) { 1421 if (operands[0].getType().isa<ShapeType>()) 1422 inferredReturnTypes.assign({SizeType::get(context)}); 1423 else 1424 inferredReturnTypes.assign({IndexType::get(context)}); 1425 return success(); 1426 } 1427 1428 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1429 // SizeType is compatible with IndexType. 1430 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1431 } 1432 1433 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); } 1434 1435 //===----------------------------------------------------------------------===// 1436 // NumElementsOp 1437 //===----------------------------------------------------------------------===// 1438 1439 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) { 1440 1441 // Fold only when argument constant. 1442 Attribute shape = operands[0]; 1443 if (!shape) 1444 return {}; 1445 1446 APInt product(64, 1); 1447 for (auto value : shape.cast<DenseIntElementsAttr>()) 1448 product *= value; 1449 Builder builder(getContext()); 1450 return builder.getIndexAttr(product.getLimitedValue()); 1451 } 1452 1453 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( 1454 MLIRContext *context, Optional<Location> location, ValueRange operands, 1455 DictionaryAttr attributes, RegionRange regions, 1456 SmallVectorImpl<Type> &inferredReturnTypes) { 1457 if (operands[0].getType().isa<ShapeType>()) 1458 inferredReturnTypes.assign({SizeType::get(context)}); 1459 else 1460 inferredReturnTypes.assign({IndexType::get(context)}); 1461 return success(); 1462 } 1463 1464 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, 1465 TypeRange r) { 1466 // SizeType is compatible with IndexType. 1467 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1468 } 1469 1470 LogicalResult shape::NumElementsOp::verify() { 1471 return verifySizeOrIndexOp(*this); 1472 } 1473 1474 //===----------------------------------------------------------------------===// 1475 // MaxOp 1476 //===----------------------------------------------------------------------===// 1477 1478 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1479 // If operands are equal, just propagate one. 1480 if (getLhs() == getRhs()) 1481 return getLhs(); 1482 return nullptr; 1483 } 1484 1485 LogicalResult mlir::shape::MaxOp::inferReturnTypes( 1486 MLIRContext *context, Optional<Location> location, ValueRange operands, 1487 DictionaryAttr attributes, RegionRange regions, 1488 SmallVectorImpl<Type> &inferredReturnTypes) { 1489 if (operands[0].getType() == operands[1].getType()) 1490 inferredReturnTypes.assign({operands[0].getType()}); 1491 else 1492 inferredReturnTypes.assign({SizeType::get(context)}); 1493 return success(); 1494 } 1495 1496 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1497 if (l.size() != 1 || r.size() != 1) 1498 return false; 1499 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1500 return true; 1501 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1502 return true; 1503 return false; 1504 } 1505 1506 //===----------------------------------------------------------------------===// 1507 // MinOp 1508 //===----------------------------------------------------------------------===// 1509 1510 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) { 1511 // If operands are equal, just propagate one. 1512 if (getLhs() == getRhs()) 1513 return getLhs(); 1514 return nullptr; 1515 } 1516 1517 LogicalResult mlir::shape::MinOp::inferReturnTypes( 1518 MLIRContext *context, Optional<Location> location, ValueRange operands, 1519 DictionaryAttr attributes, RegionRange regions, 1520 SmallVectorImpl<Type> &inferredReturnTypes) { 1521 if (operands[0].getType() == operands[1].getType()) 1522 inferredReturnTypes.assign({operands[0].getType()}); 1523 else 1524 inferredReturnTypes.assign({SizeType::get(context)}); 1525 return success(); 1526 } 1527 1528 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1529 if (l.size() != 1 || r.size() != 1) 1530 return false; 1531 if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>()) 1532 return true; 1533 if (l.front().isa<SizeType>() && r.front().isa<SizeType>()) 1534 return true; 1535 return false; 1536 } 1537 1538 //===----------------------------------------------------------------------===// 1539 // MulOp 1540 //===----------------------------------------------------------------------===// 1541 1542 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) { 1543 auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>(); 1544 if (!lhs) 1545 return nullptr; 1546 auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>(); 1547 if (!rhs) 1548 return nullptr; 1549 APInt folded = lhs.getValue() * rhs.getValue(); 1550 Type indexTy = IndexType::get(getContext()); 1551 return IntegerAttr::get(indexTy, folded); 1552 } 1553 1554 LogicalResult mlir::shape::MulOp::inferReturnTypes( 1555 MLIRContext *context, Optional<Location> location, ValueRange operands, 1556 DictionaryAttr attributes, RegionRange regions, 1557 SmallVectorImpl<Type> &inferredReturnTypes) { 1558 if (operands[0].getType().isa<SizeType>() || 1559 operands[1].getType().isa<SizeType>()) 1560 inferredReturnTypes.assign({SizeType::get(context)}); 1561 else 1562 inferredReturnTypes.assign({IndexType::get(context)}); 1563 return success(); 1564 } 1565 1566 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1567 // SizeType is compatible with IndexType. 1568 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r); 1569 } 1570 1571 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); } 1572 1573 //===----------------------------------------------------------------------===// 1574 // ShapeOfOp 1575 //===----------------------------------------------------------------------===// 1576 1577 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) { 1578 auto type = getOperand().getType().dyn_cast<ShapedType>(); 1579 if (!type || !type.hasStaticShape()) 1580 return nullptr; 1581 Builder builder(getContext()); 1582 return builder.getIndexTensorAttr(type.getShape()); 1583 } 1584 1585 namespace { 1586 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { 1587 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; 1588 1589 LogicalResult matchAndRewrite(shape::ShapeOfOp op, 1590 PatternRewriter &rewriter) const override { 1591 if (!op.getArg().getType().isa<ShapedType>()) 1592 return failure(); 1593 if (op.getType().isa<ShapedType>()) 1594 return failure(); 1595 1596 rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), 1597 op.getArg()); 1598 return success(); 1599 } 1600 }; 1601 1602 // Canonicalize 1603 // ``` 1604 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex> 1605 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex> 1606 // ``` 1607 // to 1608 // ``` 1609 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex> 1610 // ``` 1611 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { 1612 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1613 1614 LogicalResult matchAndRewrite(tensor::CastOp op, 1615 PatternRewriter &rewriter) const override { 1616 auto ty = op.getType().dyn_cast<RankedTensorType>(); 1617 if (!ty || ty.getRank() != 1) 1618 return failure(); 1619 1620 auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>(); 1621 if (!shapeOfOp) 1622 return failure(); 1623 1624 // Argument type must be ranked and must not conflict. 1625 auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>(); 1626 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank())) 1627 return failure(); 1628 1629 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg()); 1630 return success(); 1631 } 1632 }; 1633 } // namespace 1634 1635 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1636 MLIRContext *context) { 1637 patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor, 1638 ExtractFromShapeOfExtentTensor>(context); 1639 } 1640 1641 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( 1642 MLIRContext *context, Optional<Location> location, ValueRange operands, 1643 DictionaryAttr attributes, RegionRange regions, 1644 SmallVectorImpl<Type> &inferredReturnTypes) { 1645 if (operands[0].getType().isa<ValueShapeType>()) 1646 inferredReturnTypes.assign({ShapeType::get(context)}); 1647 else { 1648 auto shapedTy = operands[0].getType().cast<ShapedType>(); 1649 int64_t rank = 1650 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; 1651 Type indexTy = IndexType::get(context); 1652 Type extentTensorTy = RankedTensorType::get({rank}, indexTy); 1653 inferredReturnTypes.assign({extentTensorTy}); 1654 } 1655 return success(); 1656 } 1657 1658 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1659 if (l.size() != 1 || r.size() != 1) 1660 return false; 1661 if (l == r) 1662 return true; 1663 1664 Type lhs = l.front(); 1665 Type rhs = r.front(); 1666 1667 if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>()) 1668 return false; 1669 1670 if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>()) 1671 // Shape type is compatible with all other valid return types. 1672 return true; 1673 1674 if (succeeded(verifyCompatibleShapes({lhs, rhs}))) 1675 return true; 1676 return false; 1677 } 1678 1679 LogicalResult shape::ShapeOfOp::verify() { 1680 return verifyShapeOrExtentTensorOp(*this); 1681 } 1682 1683 //===----------------------------------------------------------------------===// 1684 // SizeToIndexOp 1685 //===----------------------------------------------------------------------===// 1686 1687 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) { 1688 // Constant values of both types, `shape.size` and `index`, are represented as 1689 // `IntegerAttr`s which makes constant folding simple. 1690 if (Attribute arg = operands[0]) 1691 return arg; 1692 return OpFoldResult(); 1693 } 1694 1695 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, 1696 MLIRContext *context) { 1697 patterns.add<IndexToSizeToIndexCanonicalization>(context); 1698 } 1699 1700 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1701 if (inputs.size() != 1 || outputs.size() != 1) 1702 return false; 1703 return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>(); 1704 } 1705 1706 //===----------------------------------------------------------------------===// 1707 // YieldOp 1708 //===----------------------------------------------------------------------===// 1709 1710 LogicalResult shape::YieldOp::verify() { 1711 auto *parentOp = (*this)->getParentOp(); 1712 auto results = parentOp->getResults(); 1713 auto operands = getOperands(); 1714 1715 if (parentOp->getNumResults() != getNumOperands()) 1716 return emitOpError() << "number of operands does not match number of " 1717 "results of its parent"; 1718 for (auto e : llvm::zip(results, operands)) 1719 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 1720 return emitOpError() << "types mismatch between yield op and its parent"; 1721 1722 return success(); 1723 } 1724 1725 //===----------------------------------------------------------------------===// 1726 // SplitAtOp 1727 //===----------------------------------------------------------------------===// 1728 1729 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands, 1730 SmallVectorImpl<OpFoldResult> &results) { 1731 if (!operands[0] || !operands[1]) 1732 return failure(); 1733 auto shapeVec = llvm::to_vector<6>( 1734 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1735 auto shape = llvm::makeArrayRef(shapeVec); 1736 auto splitPoint = operands[1].cast<IntegerAttr>().getInt(); 1737 // Verify that the split point is in the correct range. 1738 // TODO: Constant fold to an "error". 1739 int64_t rank = shape.size(); 1740 if (!(-rank <= splitPoint && splitPoint <= rank)) 1741 return failure(); 1742 if (splitPoint < 0) 1743 splitPoint += shape.size(); 1744 Builder builder(operands[0].getContext()); 1745 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); 1746 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); 1747 return success(); 1748 } 1749 1750 //===----------------------------------------------------------------------===// 1751 // ToExtentTensorOp 1752 //===----------------------------------------------------------------------===// 1753 1754 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) { 1755 if (!operands[0]) 1756 return OpFoldResult(); 1757 Builder builder(getContext()); 1758 auto shape = llvm::to_vector<6>( 1759 operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>()); 1760 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())}, 1761 builder.getIndexType()); 1762 return DenseIntElementsAttr::get(type, shape); 1763 } 1764 1765 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 1766 if (inputs.size() != 1 || outputs.size() != 1) 1767 return false; 1768 if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) { 1769 if (!inputTensor.getElementType().isa<IndexType>() || 1770 inputTensor.getRank() != 1) 1771 return false; 1772 } else if (!inputs[0].isa<ShapeType>()) { 1773 return false; 1774 } 1775 1776 TensorType outputTensor = outputs[0].dyn_cast<TensorType>(); 1777 return outputTensor && outputTensor.getElementType().isa<IndexType>(); 1778 } 1779 1780 //===----------------------------------------------------------------------===// 1781 // ReduceOp 1782 //===----------------------------------------------------------------------===// 1783 1784 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, 1785 ValueRange initVals) { 1786 result.addOperands(shape); 1787 result.addOperands(initVals); 1788 1789 Region *bodyRegion = result.addRegion(); 1790 bodyRegion->push_back(new Block); 1791 Block &bodyBlock = bodyRegion->front(); 1792 bodyBlock.addArgument(builder.getIndexType(), result.location); 1793 1794 Type elementType; 1795 if (auto tensorType = shape.getType().dyn_cast<TensorType>()) 1796 elementType = tensorType.getElementType(); 1797 else 1798 elementType = SizeType::get(builder.getContext()); 1799 bodyBlock.addArgument(elementType, shape.getLoc()); 1800 1801 for (Value initVal : initVals) { 1802 bodyBlock.addArgument(initVal.getType(), initVal.getLoc()); 1803 result.addTypes(initVal.getType()); 1804 } 1805 } 1806 1807 LogicalResult ReduceOp::verify() { 1808 // Verify block arg types. 1809 Block &block = getRegion().front(); 1810 1811 // The block takes index, extent, and aggregated values as arguments. 1812 auto blockArgsCount = getInitVals().size() + 2; 1813 if (block.getNumArguments() != blockArgsCount) 1814 return emitOpError() << "ReduceOp body is expected to have " 1815 << blockArgsCount << " arguments"; 1816 1817 // The first block argument is the index and must always be of type `index`. 1818 if (!block.getArgument(0).getType().isa<IndexType>()) 1819 return emitOpError( 1820 "argument 0 of ReduceOp body is expected to be of IndexType"); 1821 1822 // The second block argument is the extent and must be of type `size` or 1823 // `index`, depending on whether the reduce operation is applied to a shape or 1824 // to an extent tensor. 1825 Type extentTy = block.getArgument(1).getType(); 1826 if (getShape().getType().isa<ShapeType>()) { 1827 if (!extentTy.isa<SizeType>()) 1828 return emitOpError("argument 1 of ReduceOp body is expected to be of " 1829 "SizeType if the ReduceOp operates on a ShapeType"); 1830 } else { 1831 if (!extentTy.isa<IndexType>()) 1832 return emitOpError( 1833 "argument 1 of ReduceOp body is expected to be of IndexType if the " 1834 "ReduceOp operates on an extent tensor"); 1835 } 1836 1837 for (const auto &type : llvm::enumerate(getInitVals())) 1838 if (block.getArgument(type.index() + 2).getType() != type.value().getType()) 1839 return emitOpError() << "type mismatch between argument " 1840 << type.index() + 2 1841 << " of ReduceOp body and initial value " 1842 << type.index(); 1843 return success(); 1844 } 1845 1846 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { 1847 // Parse operands. 1848 SmallVector<OpAsmParser::OperandType, 3> operands; 1849 Type shapeOrExtentTensorType; 1850 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, 1851 OpAsmParser::Delimiter::Paren) || 1852 parser.parseColonType(shapeOrExtentTensorType) || 1853 parser.parseOptionalArrowTypeList(result.types)) 1854 return failure(); 1855 1856 // Resolve operands. 1857 auto initVals = llvm::makeArrayRef(operands).drop_front(); 1858 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, 1859 result.operands) || 1860 parser.resolveOperands(initVals, result.types, parser.getNameLoc(), 1861 result.operands)) 1862 return failure(); 1863 1864 // Parse the body. 1865 Region *body = result.addRegion(); 1866 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) 1867 return failure(); 1868 1869 // Parse attributes. 1870 if (parser.parseOptionalAttrDict(result.attributes)) 1871 return failure(); 1872 1873 return success(); 1874 } 1875 1876 void ReduceOp::print(OpAsmPrinter &p) { 1877 p << '(' << getShape() << ", " << getInitVals() 1878 << ") : " << getShape().getType(); 1879 p.printOptionalArrowTypeList(getResultTypes()); 1880 p << ' '; 1881 p.printRegion(getRegion()); 1882 p.printOptionalAttrDict((*this)->getAttrs()); 1883 } 1884 1885 #define GET_OP_CLASSES 1886 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" 1887