1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/SetOperations.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::shape;
31 
32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
33 
34 namespace {
35 #include "ShapeCanonicalization.inc"
36 } // namespace
37 
38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
39   return RankedTensorType::get({rank}, IndexType::get(ctx));
40 }
41 
42 bool shape::isExtentTensorType(Type type) {
43   auto ranked = type.dyn_cast<RankedTensorType>();
44   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
45 }
46 
47 LogicalResult shape::getShapeVec(Value input,
48                                  SmallVectorImpl<int64_t> &shapeValues) {
49   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
50     auto type = inputOp.getArg().getType().cast<ShapedType>();
51     if (!type.hasRank())
52       return failure();
53     llvm::append_range(shapeValues, type.getShape());
54     return success();
55   }
56   DenseIntElementsAttr attr;
57   if (matchPattern(input, m_Constant(&attr))) {
58     llvm::append_range(shapeValues, attr.getValues<int64_t>());
59     return success();
60   }
61   return failure();
62 }
63 
64 static bool isErrorPropagationPossible(TypeRange operandTypes) {
65   return llvm::any_of(operandTypes, [](Type ty) {
66     return ty.isa<SizeType, ShapeType, ValueShapeType>();
67   });
68 }
69 
70 static LogicalResult verifySizeOrIndexOp(Operation *op) {
71   assert(op != nullptr && op->getNumResults() == 1);
72   Type resultTy = op->getResultTypes().front();
73   if (isErrorPropagationPossible(op->getOperandTypes())) {
74     if (!resultTy.isa<SizeType>())
75       return op->emitOpError()
76              << "if at least one of the operands can hold error values then "
77                 "the result must be of type `size` to propagate them";
78   }
79   return success();
80 }
81 
82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
83   assert(op != nullptr && op->getNumResults() == 1);
84   Type resultTy = op->getResultTypes().front();
85   if (isErrorPropagationPossible(op->getOperandTypes())) {
86     if (!resultTy.isa<ShapeType>())
87       return op->emitOpError()
88              << "if at least one of the operands can hold error values then "
89                 "the result must be of type `shape` to propagate them";
90   }
91   return success();
92 }
93 
94 template <typename... Ty>
95 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
96   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
97 }
98 
99 template <typename... Ty, typename... ranges>
100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
101   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // InlinerInterface
106 //===----------------------------------------------------------------------===//
107 
108 namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Inlining shape dialect regions is unconditionally allowed.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Inlining shape dialect ops
  // is unconditionally allowed.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
128 } // namespace
129 
void ShapeDialect::initialize() {
  // Register all ops generated from the tablegen definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
142 
/// Materialize a single constant operation for the given attribute/type pair,
/// used by folding. Returns nullptr when no shape or arith constant op can
/// represent the value.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // `shape.const_shape` covers both `!shape.shape` and extent tensor results.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to the arith dialect for plain index/integer constants.
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}
157 
158 /// Parse a type registered to this dialect.
159 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
160   StringRef keyword;
161   if (parser.parseKeyword(&keyword))
162     return Type();
163 
164   if (keyword == "shape")
165     return ShapeType::get(getContext());
166   if (keyword == "size")
167     return SizeType::get(getContext());
168   if (keyword == "value_shape")
169     return ValueShapeType::get(getContext());
170   if (keyword == "witness")
171     return WitnessType::get(getContext());
172 
173   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
174   return Type();
175 }
176 
177 /// Print a type registered to this dialect.
178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
179   TypeSwitch<Type>(type)
180       .Case<ShapeType>([&](Type) { os << "shape"; })
181       .Case<SizeType>([&](Type) { os << "size"; })
182       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
183       .Case<WitnessType>([&](Type) { os << "witness"; })
184       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
185 }
186 
187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
188                                                      NamedAttribute attribute) {
189   // Verify shape.lib attribute.
190   if (attribute.getName() == "shape.lib") {
191     if (!op->hasTrait<OpTrait::SymbolTable>())
192       return op->emitError(
193           "shape.lib attribute may only be on op implementing SymbolTable");
194 
195     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
196       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
197       if (!symbol)
198         return op->emitError("shape function library ")
199                << symbolRef << " not found";
200       return isa<shape::FunctionLibraryOp>(symbol)
201                  ? success()
202                  : op->emitError()
203                        << symbolRef << " required to be shape function library";
204     }
205 
206     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
207       // Verify all entries are function libraries and mappings in libraries
208       // refer to unique ops.
209       DenseSet<StringAttr> key;
210       for (auto it : arr) {
211         if (!it.isa<SymbolRefAttr>())
212           return op->emitError(
213               "only SymbolRefAttr allowed in shape.lib attribute array");
214 
215         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
216             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
217         if (!shapeFnLib)
218           return op->emitError()
219                  << it << " does not refer to FunctionLibraryOp";
220         for (auto mapping : shapeFnLib.getMapping()) {
221           if (!key.insert(mapping.getName()).second) {
222             return op->emitError("only one op to shape mapping allowed, found "
223                                  "multiple for `")
224                    << mapping.getName() << "`";
225           }
226         }
227       }
228       return success();
229     }
230 
231     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
232                          "allowed as shape.lib attribute");
233   }
234   return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AnyOp
239 //===----------------------------------------------------------------------===//
240 
241 // TODO: Canonicalization should be implemented for shapes that can be
242 // determined through mixtures of the known dimensions of the inputs.
243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
244   // Only the last operand is checked because AnyOp is commutative.
245   if (operands.back())
246     return operands.back();
247 
248   return nullptr;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // AssumingOp
253 //===----------------------------------------------------------------------===//
254 
/// Custom parser: `shape.assuming %witness [-> (result-types)] region
/// [attr-dict]`.
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is the witness; its type is always `!shape.witness`
  // and therefore not spelled in the assembly.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
280 
281 void AssumingOp::print(OpAsmPrinter &p) {
282   bool yieldsResults = !getResults().empty();
283 
284   p << " " << getWitness();
285   if (yieldsResults)
286     p << " -> (" << getResultTypes() << ")";
287   p << ' ';
288   p.printRegion(getDoRegion(),
289                 /*printEntryBlockArgs=*/false,
290                 /*printBlockTerminators=*/yieldsResults);
291   p.printOptionalAttrDict((*this)->getAttrs());
292 }
293 
294 namespace {
295 // Removes AssumingOp with a passing witness and inlines the region.
296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
297   using OpRewritePattern<AssumingOp>::OpRewritePattern;
298 
299   LogicalResult matchAndRewrite(AssumingOp op,
300                                 PatternRewriter &rewriter) const override {
301     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
302     if (!witness || !witness.getPassingAttr())
303       return failure();
304 
305     AssumingOp::inlineRegionIntoParent(op, rewriter);
306     return success();
307   }
308 };
309 
// Rebuilds an AssumingOp without its unused results: the yield is reduced to
// the operands whose corresponding results have uses, and the region is moved
// into a new, narrower AssumingOp.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones. Unused results
    // are replaced by a null value, which replaceOp accepts for results
    // without uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
355 } // namespace
356 
357 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
358                                              MLIRContext *context) {
359   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
360 }
361 
362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
363 void AssumingOp::getSuccessorRegions(
364     Optional<unsigned> index, ArrayRef<Attribute> operands,
365     SmallVectorImpl<RegionSuccessor> &regions) {
366   // AssumingOp has unconditional control flow into the region and back to the
367   // parent, so return the correct RegionSuccessor purely based on the index
368   // being None or 0.
369   if (index.hasValue()) {
370     regions.push_back(RegionSuccessor(getResults()));
371     return;
372   }
373 
374   regions.push_back(RegionSuccessor(&getDoRegion()));
375 }
376 
/// Replaces `op` by the contents of its body: the body block is spliced into
/// the parent block at the op's position and the op's results are replaced by
/// the yield's operands.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the parent block at the op so the body can be inlined in between.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
396 
397 void AssumingOp::build(
398     OpBuilder &builder, OperationState &result, Value witness,
399     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
400 
401   result.addOperands(witness);
402   Region *bodyRegion = result.addRegion();
403   bodyRegion->push_back(new Block);
404   Block &bodyBlock = bodyRegion->front();
405 
406   // Build body.
407   OpBuilder::InsertionGuard guard(builder);
408   builder.setInsertionPointToStart(&bodyBlock);
409   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
410   builder.create<AssumingYieldOp>(result.location, yieldValues);
411 
412   SmallVector<Type, 2> assumingTypes;
413   for (Value v : yieldValues)
414     assumingTypes.push_back(v.getType());
415   result.addTypes(assumingTypes);
416 }
417 
LogicalResult AssumingOp::verify() {
  // Delegate to the region branch interface, which checks that the region's
  // yielded types line up with the op's result types.
  return RegionBranchOpInterface::verifyTypes(*this);
}
421 
422 //===----------------------------------------------------------------------===//
423 // AddOp
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult mlir::shape::AddOp::inferReturnTypes(
427     MLIRContext *context, Optional<Location> location, ValueRange operands,
428     DictionaryAttr attributes, RegionRange regions,
429     SmallVectorImpl<Type> &inferredReturnTypes) {
430   if (operands[0].getType().isa<SizeType>() ||
431       operands[1].getType().isa<SizeType>())
432     inferredReturnTypes.assign({SizeType::get(context)});
433   else
434     inferredReturnTypes.assign({IndexType::get(context)});
435   return success();
436 }
437 
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType: each side must be exactly one
  // type drawn from {size, index}.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
442 
443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
444   // add(x, 0) -> x
445   if (matchPattern(getRhs(), m_Zero()))
446     return getLhs();
447 
448   return constFoldBinaryOp<IntegerAttr>(
449       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
450 }
451 
// The result must be `size` whenever an operand can carry an error value.
LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
453 
454 //===----------------------------------------------------------------------===//
455 // AssumingAllOp
456 //===----------------------------------------------------------------------===//
457 
458 namespace {
459 
// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    // Flatten operands that are themselves produced by an `assuming_all`.
    for (Value operand : op.getInputs()) {
      if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assume_all.operand_begin(), assume_all->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};
491 
492 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
493 // are subsumed by others.
494 //
495 //   %0 = shape.cstr_broadcastable %shape0, %shape1
496 //   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
497 //
498 //   %2 = shape.cstr_broadcastable %shape3, %shape4
499 //   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
500 //
501 //   %4 = shape.assuming_all %0, %1, %2, %3
502 //
503 // to:
504 //
505 //   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
506 //   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
507 //   %2 = shape.assuming_all %0, %1
508 //
509 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
510 // shapes [0, 1] are broadcastable too, and can be removed from the list of
511 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
512 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
513 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
514   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
515 
516   LogicalResult matchAndRewrite(AssumingAllOp op,
517                                 PatternRewriter &rewriter) const override {
518     // Collect all `CstrBroadcastableOp` operands first.
519     SetVector<CstrBroadcastableOp> operands;
520     for (Value operand : op.getInputs()) {
521       // TODO: Apply this optimization if some of the witnesses are not
522       // produced by the `cstr_broadcastable`.
523       auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
524       if (!broadcastable)
525         return failure();
526 
527       operands.insert(broadcastable);
528     }
529 
530     // Skip trivial `assuming_all` operations.
531     if (operands.size() <= 1)
532       return failure();
533 
534     // Collect shapes checked by `cstr_broadcastable` operands.
535     SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
536     for (auto cstr : operands) {
537       DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end());
538       shapes.emplace_back(cstr, std::move(shapes_set));
539     }
540 
541     // Sort by the number of shape operands (larger to smaller).
542     llvm::sort(shapes, [](auto a, auto b) {
543       return a.first.getNumOperands() > b.first.getNumOperands();
544     });
545 
546     // We start from the `cst_broadcastable` operations with largest number of
547     // shape operands, and remove redundant `cst_broadcastable` operations. We
548     // do this until we find a set of `cst_broadcastable` operations with
549     // non-overlapping constraints.
550     SmallVector<CstrBroadcastableOp> marked_for_erase;
551 
552     for (unsigned i = 0; i < shapes.size(); ++i) {
553       auto isSubset = [&](auto pair) {
554         return llvm::set_is_subset(pair.second, shapes[i].second);
555       };
556 
557       // Keep redundant `cstr_broadcastable` operations to be erased.
558       auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
559       for (auto *it0 = it; it0 < shapes.end(); ++it0)
560         marked_for_erase.push_back(it0->first);
561       shapes.erase(it, shapes.end());
562     }
563 
564     // We didn't find any operands that could be removed.
565     if (marked_for_erase.empty())
566       return failure();
567 
568     // Collect non-overlapping `cst_broadcastable` constraints.
569     SmallVector<Value> unique_constraints;
570     for (auto &shape : shapes)
571       unique_constraints.push_back(shape.first.getResult());
572 
573     // Replace with a new `assuming_all` operation ...
574     rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints);
575 
576     // ... and maybe erase `cstr_broadcastable` ops without uses.
577     for (auto &op : marked_for_erase)
578       if (op->use_empty())
579         rewriter.eraseOp(op);
580 
581     return success();
582   }
583 };
584 
// Rewrites an `assuming_all` whose witnesses all come from `cstr_eq` ops into
// a single `cstr_eq` over the concatenation of the compared shapes. This is
// only done when each `cstr_eq` shares at least one shape with those already
// collected, so the merged equality constraint is equivalent.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // A `cstr_eq` with no shape in common with the accumulated set cannot
      // be merged soundly.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
607 
608 template <typename OpTy>
609 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
610   using OpRewritePattern<OpTy>::OpRewritePattern;
611 
612   LogicalResult matchAndRewrite(OpTy op,
613                                 PatternRewriter &rewriter) const override {
614     // Find unique operands.
615     SetVector<Value> unique(op.operand_begin(), op.operand_end());
616 
617     // Reduce op to equivalent with unique operands.
618     if (unique.size() < op.getNumOperands()) {
619       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
620                                         unique.takeVector(), op->getAttrs());
621       return success();
622     }
623 
624     return failure();
625   }
626 };
627 } // namespace
628 
629 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
630                                                 MLIRContext *context) {
631   patterns
632       .add<MergeAssumingAllOps, AssumingAllOneOp,
633            AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
634            RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
635 }
636 
/// Folds `assuming_all` over constant witnesses. Note that this fold mutates
/// the op in place: statically known-passing operands are erased as they are
/// consumed.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
657 
658 LogicalResult AssumingAllOp::verify() {
659   // Ensure that AssumingAllOp contains at least one operand
660   if (getNumOperands() == 0)
661     return emitOpError("no operands specified");
662 
663   return success();
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // BroadcastOp
668 //===----------------------------------------------------------------------===//
669 
/// Folds `broadcast` when it has a single operand of matching type, or when
/// it has exactly two constant operands whose static broadcast succeeds.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both operands must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
698 
// The result must be `shape` whenever an operand can carry an error value.
LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
702 
703 namespace {
704 template <typename OpTy>
705 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
706   using OpRewritePattern<OpTy>::OpRewritePattern;
707 
708   LogicalResult matchAndRewrite(OpTy op,
709                                 PatternRewriter &rewriter) const override {
710     auto isPotentiallyNonEmptyShape = [](Value shape) {
711       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
712         if (extentTensorTy.getDimSize(0) == 0)
713           return false;
714       }
715       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
716         if (constShape.getShape().empty())
717           return false;
718       }
719       return true;
720     };
721     auto newOperands = llvm::to_vector<8>(
722         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
723 
724     // Reduce op to equivalent without empty shape operands.
725     if (newOperands.size() < op.getNumOperands()) {
726       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
727                                         op->getAttrs());
728       return success();
729     }
730 
731     return failure();
732   }
733 };
734 
735 struct BroadcastForwardSingleOperandPattern
736     : public OpRewritePattern<BroadcastOp> {
737   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
738 
739   LogicalResult matchAndRewrite(BroadcastOp op,
740                                 PatternRewriter &rewriter) const override {
741     if (op.getNumOperands() != 1)
742       return failure();
743     Value replacement = op.getShapes().front();
744 
745     // Insert cast if needed.
746     if (replacement.getType() != op.getType()) {
747       auto loc = op.getLoc();
748       if (op.getType().isa<ShapeType>()) {
749         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
750       } else {
751         assert(!op.getType().isa<ShapeType>() &&
752                !replacement.getType().isa<ShapeType>() &&
753                "expect extent tensor cast");
754         replacement =
755             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
756       }
757     }
758 
759     rewriter.replaceOp(op, replacement);
760     return success();
761   }
762 };
763 
// Folds all constant (`const_shape`) operands of a `broadcast` into a single
// `const_shape` operand computed via the static broadcast rules; requires at
// least two foldable constants to fire.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    // Accumulate constant operands into `foldedConstantShape`; operands that
    // are not constant, or whose broadcast with the accumulator fails, are
    // kept as-is.
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    // Append the merged constant as a fresh `const_shape` operand.
    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
801 
802 template <typename OpTy>
803 struct CanonicalizeCastExtentTensorOperandsPattern
804     : public OpRewritePattern<OpTy> {
805   using OpRewritePattern<OpTy>::OpRewritePattern;
806 
807   LogicalResult matchAndRewrite(OpTy op,
808                                 PatternRewriter &rewriter) const override {
809     // Canonicalize operands.
810     bool anyChange = false;
811     auto canonicalizeOperand = [&](Value operand) {
812       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
813         // Only eliminate the cast if it holds no shape information.
814         bool isInformationLoosingCast =
815             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
816         if (isInformationLoosingCast) {
817           anyChange = true;
818           return castOp.source();
819         }
820       }
821       return operand;
822     };
823     auto newOperands = llvm::to_vector<8>(
824         llvm::map_range(op.getOperands(), canonicalizeOperand));
825 
826     // Rewrite op if any change required.
827     if (!anyChange)
828       return failure();
829     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
830     return success();
831   }
832 };
833 
// Replace a shape.broadcast whose result is a dynamic extent tensor with one
// returning a static extent tensor of the maximum operand rank, followed by a
// tensor.cast back to the original result type.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
      // NOTE(review): operands with non-RankedTensorType (e.g. !shape.shape)
      // are skipped and do not contribute to maxRank — confirm intended.
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
864 } // namespace
865 
// Register all canonicalization patterns for shape.broadcast.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
875 
876 //===----------------------------------------------------------------------===//
877 // ConcatOp
878 //===----------------------------------------------------------------------===//
879 
880 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
881   if (!operands[0] || !operands[1])
882     return nullptr;
883   auto lhsShape = llvm::to_vector<6>(
884       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
885   auto rhsShape = llvm::to_vector<6>(
886       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
887   SmallVector<int64_t, 6> resultShape;
888   resultShape.append(lhsShape.begin(), lhsShape.end());
889   resultShape.append(rhsShape.begin(), rhsShape.end());
890   Builder builder(getContext());
891   return builder.getIndexTensorAttr(resultShape);
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // ConstShapeOp
896 //===----------------------------------------------------------------------===//
897 
// Print as: attr-dict `[` comma-separated-extents `]` `:` result-type.
// The "shape" attribute is elided since its extents are printed inline.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
906 
// Parse the inverse of the printer above: attr-dict `[` extents `]` `:` type.
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every element of the array must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  // Store the extents canonically as an index tensor attribute.
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
935 
// A constant shape trivially folds to its "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
937 
// Register canonicalization patterns for shape.const_shape.
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}
942 
943 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
944     MLIRContext *context, Optional<Location> location, ValueRange operands,
945     DictionaryAttr attributes, RegionRange regions,
946     SmallVectorImpl<Type> &inferredReturnTypes) {
947   Builder b(context);
948   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
949   if (!shape)
950     return emitOptionalError(location, "missing shape attribute");
951   inferredReturnTypes.assign({RankedTensorType::get(
952       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
953   return success();
954 }
955 
956 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
957                                                         TypeRange r) {
958   if (l.size() != 1 || r.size() != 1)
959     return false;
960 
961   Type lhs = l.front();
962   Type rhs = r.front();
963 
964   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
965     // Shape type is compatible with all other valid return types.
966     return true;
967   return lhs == rhs;
968 }
969 
970 //===----------------------------------------------------------------------===//
971 // CstrBroadcastableOp
972 //===----------------------------------------------------------------------===//
973 
// Register canonicalization patterns for shape.cstr_broadcastable.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
984 
// Return true if there is at most one attribute not representing a scalar
// broadcast (i.e. a rank-0 shape). A null (unknown) attribute is
// conservatively treated as non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
998 
// Fold to a passing witness when broadcastability can be proven statically,
// trying three increasingly expensive strategies in turn.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, try to prove broadcastability from the constant operand attributes;
  // this requires every operand to be constant.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
1033 
1034 LogicalResult CstrBroadcastableOp::verify() {
1035   // Ensure that CstrBroadcastableOp contains at least two operands
1036   if (getNumOperands() < 2)
1037     return emitOpError("required at least 2 input shapes");
1038   return success();
1039 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // CstrEqOp
1043 //===----------------------------------------------------------------------===//
1044 
// Register canonicalization patterns for shape.cstr_eq.
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}
1050 
1051 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
1052   if (llvm::all_of(operands,
1053                    [&](Attribute a) { return a && a == operands[0]; }))
1054     return BoolAttr::get(getContext(), true);
1055 
1056   // Because a failing witness result here represents an eventual assertion
1057   // failure, we do not try to replace it with a constant witness. Similarly, we
1058   // cannot if there are any non-const inputs.
1059   return nullptr;
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // ConstSizeOp
1064 //===----------------------------------------------------------------------===//
1065 
// Convenience builder: wrap the integer value in an index attribute and
// delegate to the generated attribute-based builder.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
1070 
// A constant size trivially folds to its "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
1072 
1073 void ConstSizeOp::getAsmResultNames(
1074     llvm::function_ref<void(Value, StringRef)> setNameFn) {
1075   SmallString<4> buffer;
1076   llvm::raw_svector_ostream os(buffer);
1077   os << "c" << getValue();
1078   setNameFn(getResult(), os.str());
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // ConstWitnessOp
1083 //===----------------------------------------------------------------------===//
1084 
// A constant witness trivially folds to its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
1088 
1089 //===----------------------------------------------------------------------===//
1090 // CstrRequireOp
1091 //===----------------------------------------------------------------------===//
1092 
// Fold to the predicate operand's constant value when known (null, i.e. no
// fold, otherwise).
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
1096 
1097 //===----------------------------------------------------------------------===//
1098 // DivOp
1099 //===----------------------------------------------------------------------===//
1100 
// Fold shape.div with two constant operands using floor-division semantics.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  // Both operands must be constant integers to fold.
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  // A negative truncated quotient with a non-zero remainder is one above the
  // floor; adjust down to get floor semantics.
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}
1120 
1121 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1122     MLIRContext *context, Optional<Location> location, ValueRange operands,
1123     DictionaryAttr attributes, RegionRange regions,
1124     SmallVectorImpl<Type> &inferredReturnTypes) {
1125   if (operands[0].getType().isa<SizeType>() ||
1126       operands[1].getType().isa<SizeType>())
1127     inferredReturnTypes.assign({SizeType::get(context)});
1128   else
1129     inferredReturnTypes.assign({IndexType::get(context)});
1130   return success();
1131 }
1132 
// Return types are compatible when each side is exactly one size-or-index
// type.
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1137 
// Verify the shared size-vs-index typing invariants.
LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // ShapeEqOp
1142 //===----------------------------------------------------------------------===//
1143 
1144 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1145   bool allSame = true;
1146   if (!operands.empty() && !operands[0])
1147     return {};
1148   for (Attribute operand : operands.drop_front(1)) {
1149     if (!operand)
1150       return {};
1151     allSame = allSame && operand == operands[0];
1152   }
1153   return BoolAttr::get(getContext(), allSame);
1154 }
1155 
1156 //===----------------------------------------------------------------------===//
1157 // IndexToSizeOp
1158 //===----------------------------------------------------------------------===//
1159 
1160 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1161   // Constant values of both types, `shape.size` and `index`, are represented as
1162   // `IntegerAttr`s which makes constant folding simple.
1163   if (Attribute arg = operands[0])
1164     return arg;
1165   return {};
1166 }
1167 
// Register canonicalization patterns for shape.index_to_size.
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1172 
1173 //===----------------------------------------------------------------------===//
1174 // FromExtentsOp
1175 //===----------------------------------------------------------------------===//
1176 
1177 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1178   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1179     return nullptr;
1180   SmallVector<int64_t, 6> extents;
1181   for (auto attr : operands)
1182     extents.push_back(attr.cast<IntegerAttr>().getInt());
1183   Builder builder(getContext());
1184   return builder.getIndexTensorAttr(extents);
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // FunctionLibraryOp
1189 //===----------------------------------------------------------------------===//
1190 
// Build a function library with the given symbol name; the body region and
// the "mapping" attribute are expected to be populated by the caller/parser.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1196 
1197 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1198   auto attr = getMapping()
1199                   .get(op->getName().getIdentifier())
1200                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1201   if (!attr)
1202     return nullptr;
1203   return lookupSymbol<FuncOp>(attr);
1204 }
1205 
// Parse: symbol-name attr-dict-with-keyword region `mapping` dictionary-attr.
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // The trailing mapping dictionary associates op names with shape functions.
  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1231 
// Print the inverse of the parser above; the symbol name and mapping are
// printed explicitly and therefore elided from the attribute dictionary.
void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}
1243 
1244 //===----------------------------------------------------------------------===//
1245 // GetExtentOp
1246 //===----------------------------------------------------------------------===//
1247 
// Return the dim operand as a constant if it is defined by either a
// shape.const_size or an arith.constant op; llvm::None otherwise.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}
1255 
// Fold when the shape operand is a constant extent tensor and the dim operand
// is a known in-bounds constant.
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  // Do not fold out-of-bounds accesses; leave them to fail at runtime.
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
}
1267 
1268 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1269                         int64_t dim) {
1270   auto loc = result.location;
1271   auto dimAttr = builder.getIndexAttr(dim);
1272   if (shape.getType().isa<ShapeType>()) {
1273     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1274     build(builder, result, builder.getType<SizeType>(), shape, dim);
1275   } else {
1276     Value dim =
1277         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1278     build(builder, result, builder.getIndexType(), shape, dim);
1279   }
1280 }
1281 
1282 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1283     MLIRContext *context, Optional<Location> location, ValueRange operands,
1284     DictionaryAttr attributes, RegionRange regions,
1285     SmallVectorImpl<Type> &inferredReturnTypes) {
1286   inferredReturnTypes.assign({IndexType::get(context)});
1287   return success();
1288 }
1289 
// Return types are compatible when each side is exactly one size-or-index
// type.
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1295 
// Verify the shared size-vs-index typing invariants.
LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1297 
1298 //===----------------------------------------------------------------------===//
1299 // IsBroadcastableOp
1300 //===----------------------------------------------------------------------===//
1301 
// Register canonicalization patterns for shape.is_broadcastable.
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1306 
1307 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1308   // Can always broadcast fewer than two shapes.
1309   if (operands.size() < 2) {
1310     return BoolAttr::get(getContext(), true);
1311   }
1312 
1313   return nullptr;
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // MeetOp
1318 //===----------------------------------------------------------------------===//
1319 
1320 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1321     MLIRContext *context, Optional<Location> location, ValueRange operands,
1322     DictionaryAttr attributes, RegionRange regions,
1323     SmallVectorImpl<Type> &inferredReturnTypes) {
1324   inferredReturnTypes.assign({operands[0].getType()});
1325   return success();
1326 }
1327 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  // NOTE(review): from this point on lhs == rhs necessarily holds, so the
  // remaining checks can only return true. The verifyCompatibleShapes call
  // looks as if it was intended to run on *differing* types — confirm the
  // intended semantics before simplifying this to a plain equality check.
  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1347 
1348 //===----------------------------------------------------------------------===//
1349 // RankOp
1350 //===----------------------------------------------------------------------===//
1351 
1352 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1353   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1354   if (!shape)
1355     return {};
1356   int64_t rank = shape.getNumElements();
1357   Builder builder(getContext());
1358   return builder.getIndexAttr(rank);
1359 }
1360 
1361 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1362 /// Constant folding fails in cases where only the rank is constant, not the
1363 /// shape itself.
1364 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1365 ///
1366 /// Example:
1367 ///
1368 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1369 /// %rank = shape.rank %shape
1370 ///
1371 /// becomes
1372 ///
1373 /// %rank = shape.const_size 3
1374 
1375 namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only match a rank fed directly by shape_of of a ranked tensor.
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant rank in the result's domain: arith.constant
    // for index, shape.const_size for !shape.size.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
1401 } // namespace
1402 
// Register canonicalization patterns for shape.rank.
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1407 
1408 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1409     MLIRContext *context, Optional<Location> location, ValueRange operands,
1410     DictionaryAttr attributes, RegionRange regions,
1411     SmallVectorImpl<Type> &inferredReturnTypes) {
1412   if (operands[0].getType().isa<ShapeType>())
1413     inferredReturnTypes.assign({SizeType::get(context)});
1414   else
1415     inferredReturnTypes.assign({IndexType::get(context)});
1416   return success();
1417 }
1418 
// Return types are compatible when each side is exactly one size-or-index
// type.
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1423 
// Verify the shared size-vs-index typing invariants.
LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // NumElementsOp
1428 //===----------------------------------------------------------------------===//
1429 
1430 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1431 
1432   // Fold only when argument constant.
1433   Attribute shape = operands[0];
1434   if (!shape)
1435     return {};
1436 
1437   APInt product(64, 1);
1438   for (auto value : shape.cast<DenseIntElementsAttr>())
1439     product *= value;
1440   Builder builder(getContext());
1441   return builder.getIndexAttr(product.getLimitedValue());
1442 }
1443 
1444 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1445     MLIRContext *context, Optional<Location> location, ValueRange operands,
1446     DictionaryAttr attributes, RegionRange regions,
1447     SmallVectorImpl<Type> &inferredReturnTypes) {
1448   if (operands[0].getType().isa<ShapeType>())
1449     inferredReturnTypes.assign({SizeType::get(context)});
1450   else
1451     inferredReturnTypes.assign({IndexType::get(context)});
1452   return success();
1453 }
1454 
// Return types are compatible when each side is exactly one size-or-index
// type.
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1460 
// Verify the shared size-vs-index typing invariants.
LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}
1464 
1465 //===----------------------------------------------------------------------===//
1466 // MaxOp
1467 //===----------------------------------------------------------------------===//
1468 
1469 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1470   // If operands are equal, just propagate one.
1471   if (getLhs() == getRhs())
1472     return getLhs();
1473   return nullptr;
1474 }
1475 
1476 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1477     MLIRContext *context, Optional<Location> location, ValueRange operands,
1478     DictionaryAttr attributes, RegionRange regions,
1479     SmallVectorImpl<Type> &inferredReturnTypes) {
1480   if (operands[0].getType() == operands[1].getType())
1481     inferredReturnTypes.assign({operands[0].getType()});
1482   else
1483     inferredReturnTypes.assign({SizeType::get(context)});
1484   return success();
1485 }
1486 
1487 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1488   if (l.size() != 1 || r.size() != 1)
1489     return false;
1490   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1491     return true;
1492   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1493     return true;
1494   return false;
1495 }
1496 
1497 //===----------------------------------------------------------------------===//
1498 // MinOp
1499 //===----------------------------------------------------------------------===//
1500 
1501 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1502   // If operands are equal, just propagate one.
1503   if (getLhs() == getRhs())
1504     return getLhs();
1505   return nullptr;
1506 }
1507 
1508 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1509     MLIRContext *context, Optional<Location> location, ValueRange operands,
1510     DictionaryAttr attributes, RegionRange regions,
1511     SmallVectorImpl<Type> &inferredReturnTypes) {
1512   if (operands[0].getType() == operands[1].getType())
1513     inferredReturnTypes.assign({operands[0].getType()});
1514   else
1515     inferredReturnTypes.assign({SizeType::get(context)});
1516   return success();
1517 }
1518 
1519 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1520   if (l.size() != 1 || r.size() != 1)
1521     return false;
1522   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1523     return true;
1524   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1525     return true;
1526   return false;
1527 }
1528 
1529 //===----------------------------------------------------------------------===//
1530 // MulOp
1531 //===----------------------------------------------------------------------===//
1532 
1533 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1534   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1535   if (!lhs)
1536     return nullptr;
1537   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1538   if (!rhs)
1539     return nullptr;
1540   APInt folded = lhs.getValue() * rhs.getValue();
1541   Type indexTy = IndexType::get(getContext());
1542   return IntegerAttr::get(indexTy, folded);
1543 }
1544 
1545 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1546     MLIRContext *context, Optional<Location> location, ValueRange operands,
1547     DictionaryAttr attributes, RegionRange regions,
1548     SmallVectorImpl<Type> &inferredReturnTypes) {
1549   if (operands[0].getType().isa<SizeType>() ||
1550       operands[1].getType().isa<SizeType>())
1551     inferredReturnTypes.assign({SizeType::get(context)});
1552   else
1553     inferredReturnTypes.assign({IndexType::get(context)});
1554   return success();
1555 }
1556 
// Return types are compatible when each side is exactly one size-or-index
// type.
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1561 
// Verify the shared size-vs-index typing invariants.
LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1563 
1564 //===----------------------------------------------------------------------===//
1565 // ShapeOfOp
1566 //===----------------------------------------------------------------------===//
1567 
1568 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1569   auto type = getOperand().getType().dyn_cast<ShapedType>();
1570   if (!type || !type.hasStaticShape())
1571     return nullptr;
1572   Builder builder(getContext());
1573   return builder.getIndexTensorAttr(type.getShape());
1574 }
1575 
1576 namespace {
// Rebuild shape.shape_of when its operand is shaped but the result type is
// not (e.g. !shape.shape). NOTE(review): this relies on the recreated op's
// builder inferring a (different) result type for shaped operands — confirm.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    // Skip ops whose result type is already shaped (an extent tensor).
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};
1592 
1593 // Canonicalize
1594 // ```
1595 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1596 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1597 // ```
1598 // to
1599 // ```
1600 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1601 // ```
1602 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1603   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1604 
1605   LogicalResult matchAndRewrite(tensor::CastOp op,
1606                                 PatternRewriter &rewriter) const override {
1607     auto ty = op.getType().dyn_cast<RankedTensorType>();
1608     if (!ty || ty.getRank() != 1)
1609       return failure();
1610 
1611     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1612     if (!shapeOfOp)
1613       return failure();
1614 
1615     // Argument type must be ranked and must not conflict.
1616     auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1617     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1618       return failure();
1619 
1620     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1621     return success();
1622   }
1623 };
1624 } // namespace
1625 
// Registers the canonicalization patterns for `shape.shape_of`, including the
// DRR-generated ExtractFromShapeOfExtentTensor from ShapeCanonicalization.inc.
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1631 
1632 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1633     MLIRContext *context, Optional<Location> location, ValueRange operands,
1634     DictionaryAttr attributes, RegionRange regions,
1635     SmallVectorImpl<Type> &inferredReturnTypes) {
1636   if (operands[0].getType().isa<ValueShapeType>())
1637     inferredReturnTypes.assign({ShapeType::get(context)});
1638   else {
1639     auto shapedTy = operands[0].getType().cast<ShapedType>();
1640     int64_t rank =
1641         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1642     Type indexTy = IndexType::get(context);
1643     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1644     inferredReturnTypes.assign({extentTensorTy});
1645   }
1646   return success();
1647 }
1648 
1649 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1650   if (l.size() != 1 || r.size() != 1)
1651     return false;
1652   if (l == r)
1653     return true;
1654 
1655   Type lhs = l.front();
1656   Type rhs = r.front();
1657 
1658   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1659     return false;
1660 
1661   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1662     // Shape type is compatible with all other valid return types.
1663     return true;
1664 
1665   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1666     return true;
1667   return false;
1668 }
1669 
// Delegates to the shared shape-or-extent-tensor verifier.
LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
1673 
1674 //===----------------------------------------------------------------------===//
1675 // SizeToIndexOp
1676 //===----------------------------------------------------------------------===//
1677 
1678 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1679   // Constant values of both types, `shape.size` and `index`, are represented as
1680   // `IntegerAttr`s which makes constant folding simple.
1681   if (Attribute arg = operands[0])
1682     return arg;
1683   return OpFoldResult();
1684 }
1685 
// Registers the DRR-generated pattern that cancels index_cast round-trips
// (index -> size -> index) from ShapeCanonicalization.inc.
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1690 
1691 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1692   if (inputs.size() != 1 || outputs.size() != 1)
1693     return false;
1694   return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1695 }
1696 
1697 //===----------------------------------------------------------------------===//
1698 // YieldOp
1699 //===----------------------------------------------------------------------===//
1700 
1701 LogicalResult shape::YieldOp::verify() {
1702   auto *parentOp = (*this)->getParentOp();
1703   auto results = parentOp->getResults();
1704   auto operands = getOperands();
1705 
1706   if (parentOp->getNumResults() != getNumOperands())
1707     return emitOpError() << "number of operands does not match number of "
1708                             "results of its parent";
1709   for (auto e : llvm::zip(results, operands))
1710     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1711       return emitOpError() << "types mismatch between yield op and its parent";
1712 
1713   return success();
1714 }
1715 
1716 //===----------------------------------------------------------------------===//
1717 // SplitAtOp
1718 //===----------------------------------------------------------------------===//
1719 
1720 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1721                               SmallVectorImpl<OpFoldResult> &results) {
1722   if (!operands[0] || !operands[1])
1723     return failure();
1724   auto shapeVec = llvm::to_vector<6>(
1725       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1726   auto shape = llvm::makeArrayRef(shapeVec);
1727   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1728   // Verify that the split point is in the correct range.
1729   // TODO: Constant fold to an "error".
1730   int64_t rank = shape.size();
1731   if (!(-rank <= splitPoint && splitPoint <= rank))
1732     return failure();
1733   if (splitPoint < 0)
1734     splitPoint += shape.size();
1735   Builder builder(operands[0].getContext());
1736   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1737   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1738   return success();
1739 }
1740 
1741 //===----------------------------------------------------------------------===//
1742 // ToExtentTensorOp
1743 //===----------------------------------------------------------------------===//
1744 
1745 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1746   if (!operands[0])
1747     return OpFoldResult();
1748   Builder builder(getContext());
1749   auto shape = llvm::to_vector<6>(
1750       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1751   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1752                                     builder.getIndexType());
1753   return DenseIntElementsAttr::get(type, shape);
1754 }
1755 
1756 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1757   if (inputs.size() != 1 || outputs.size() != 1)
1758     return false;
1759   if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1760     if (!inputTensor.getElementType().isa<IndexType>() ||
1761         inputTensor.getRank() != 1)
1762       return false;
1763   } else if (!inputs[0].isa<ShapeType>()) {
1764     return false;
1765   }
1766 
1767   TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1768   return outputTensor && outputTensor.getElementType().isa<IndexType>();
1769 }
1770 
1771 //===----------------------------------------------------------------------===//
1772 // ReduceOp
1773 //===----------------------------------------------------------------------===//
1774 
1775 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1776                      ValueRange initVals) {
1777   result.addOperands(shape);
1778   result.addOperands(initVals);
1779 
1780   Region *bodyRegion = result.addRegion();
1781   bodyRegion->push_back(new Block);
1782   Block &bodyBlock = bodyRegion->front();
1783   bodyBlock.addArgument(builder.getIndexType(), result.location);
1784 
1785   Type elementType;
1786   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1787     elementType = tensorType.getElementType();
1788   else
1789     elementType = SizeType::get(builder.getContext());
1790   bodyBlock.addArgument(elementType, shape.getLoc());
1791 
1792   for (Value initVal : initVals) {
1793     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1794     result.addTypes(initVal.getType());
1795   }
1796 }
1797 
// Verifies the reduction body's block signature against the op: the block
// must take (index, extent, init-values...), with the extent type depending
// on whether the op reduces a `!shape.shape` or an extent tensor.
LogicalResult ReduceOp::verify() {
  // Verify block arg types.
  Block &block = getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (getShape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // Each remaining block argument carries an accumulated value and must match
  // the type of the corresponding initial value.
  for (const auto &type : llvm::enumerate(getInitVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << type.index() + 2
                           << " of ReduceOp body and initial value "
                           << type.index();
  return success();
}
1836 
// Parses the custom assembly form produced by ReduceOp::print:
//   `(` shape `,` init-vals `)` `:` shape-type (`->` result-types)? region
//   attr-dict?
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  // requiredOperandCount = -1 accepts any number of operands; the first is
  // the shape, the rest are initial values.
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the shape against the parsed colon-type, each initial
  // value against the corresponding result type.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1866 
// Prints the custom assembly form consumed by ReduceOp::parse; the emitted
// token order must stay in sync with that parser.
void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1875 
1876 #define GET_OP_CLASSES
1877 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1878