1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/SetOperations.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::shape;
31 
32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
33 
34 namespace {
35 #include "ShapeCanonicalization.inc"
36 } // namespace
37 
38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
39   return RankedTensorType::get({rank}, IndexType::get(ctx));
40 }
41 
42 bool shape::isExtentTensorType(Type type) {
43   auto ranked = type.dyn_cast<RankedTensorType>();
44   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
45 }
46 
47 LogicalResult shape::getShapeVec(Value input,
48                                  SmallVectorImpl<int64_t> &shapeValues) {
49   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
50     auto type = inputOp.getArg().getType().cast<ShapedType>();
51     if (!type.hasRank())
52       return failure();
53     llvm::append_range(shapeValues, type.getShape());
54     return success();
55   }
56   DenseIntElementsAttr attr;
57   if (matchPattern(input, m_Constant(&attr))) {
58     llvm::append_range(shapeValues, attr.getValues<int64_t>());
59     return success();
60   }
61   return failure();
62 }
63 
64 static bool isErrorPropagationPossible(TypeRange operandTypes) {
65   return llvm::any_of(operandTypes, [](Type ty) {
66     return ty.isa<SizeType, ShapeType, ValueShapeType>();
67   });
68 }
69 
70 static LogicalResult verifySizeOrIndexOp(Operation *op) {
71   assert(op != nullptr && op->getNumResults() == 1);
72   Type resultTy = op->getResultTypes().front();
73   if (isErrorPropagationPossible(op->getOperandTypes())) {
74     if (!resultTy.isa<SizeType>())
75       return op->emitOpError()
76              << "if at least one of the operands can hold error values then "
77                 "the result must be of type `size` to propagate them";
78   }
79   return success();
80 }
81 
82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
83   assert(op != nullptr && op->getNumResults() == 1);
84   Type resultTy = op->getResultTypes().front();
85   if (isErrorPropagationPossible(op->getOperandTypes())) {
86     if (!resultTy.isa<ShapeType>())
87       return op->emitOpError()
88              << "if at least one of the operands can hold error values then "
89                 "the result must be of type `shape` to propagate them";
90   }
91   return success();
92 }
93 
94 template <typename... Ty>
95 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
96   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
97 }
98 
99 template <typename... Ty, typename... ranges>
100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
101   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // InlinerInterface
106 //===----------------------------------------------------------------------===//
107 
108 namespace {
109 /// This class defines the interface for inlining shape dialect ops.
110 struct ShapeInlinerInterface : public DialectInlinerInterface {
111   using DialectInlinerInterface::DialectInlinerInterface;
112 
113   // Returns true if the given region 'src' can be inlined into the region
114   // 'dest' that is attached to an operation registered to the current dialect.
115   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
116                        BlockAndValueMapping &) const final {
117     return true;
118   }
119 
120   // Returns true if the given operation 'op', that is registered to this
121   // dialect, can be inlined into the region 'dest' that is attached to an
122   // operation registered to the current dialect.
123   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
124                        BlockAndValueMapping &) const final {
125     return true;
126   }
127 };
128 } // namespace
129 
130 void ShapeDialect::initialize() {
131   addOperations<
132 #define GET_OP_LIST
133 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
134       >();
135   addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
136   addInterfaces<ShapeInlinerInterface>();
137   // Allow unknown operations during prototyping and testing. As the dialect is
138   // still evolving it makes it simple to start with an unregistered ops and
139   // try different variants before actually defining the op.
140   allowUnknownOperations();
141 }
142 
143 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
144                                              Attribute value, Type type,
145                                              Location loc) {
146   if (type.isa<ShapeType>() || isExtentTensorType(type))
147     return builder.create<ConstShapeOp>(loc, type,
148                                         value.cast<DenseIntElementsAttr>());
149   if (type.isa<SizeType>())
150     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
151   if (type.isa<WitnessType>())
152     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
153   if (arith::ConstantOp::isBuildableWith(value, type))
154     return builder.create<arith::ConstantOp>(loc, type, value);
155   return nullptr;
156 }
157 
158 /// Parse a type registered to this dialect.
159 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
160   StringRef keyword;
161   if (parser.parseKeyword(&keyword))
162     return Type();
163 
164   if (keyword == "shape")
165     return ShapeType::get(getContext());
166   if (keyword == "size")
167     return SizeType::get(getContext());
168   if (keyword == "value_shape")
169     return ValueShapeType::get(getContext());
170   if (keyword == "witness")
171     return WitnessType::get(getContext());
172 
173   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
174   return Type();
175 }
176 
177 /// Print a type registered to this dialect.
178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
179   TypeSwitch<Type>(type)
180       .Case<ShapeType>([&](Type) { os << "shape"; })
181       .Case<SizeType>([&](Type) { os << "size"; })
182       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
183       .Case<WitnessType>([&](Type) { os << "witness"; })
184       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
185 }
186 
187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
188                                                      NamedAttribute attribute) {
189   // Verify shape.lib attribute.
190   if (attribute.getName() == "shape.lib") {
191     if (!op->hasTrait<OpTrait::SymbolTable>())
192       return op->emitError(
193           "shape.lib attribute may only be on op implementing SymbolTable");
194 
195     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
196       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
197       if (!symbol)
198         return op->emitError("shape function library ")
199                << symbolRef << " not found";
200       return isa<shape::FunctionLibraryOp>(symbol)
201                  ? success()
202                  : op->emitError()
203                        << symbolRef << " required to be shape function library";
204     }
205 
206     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
207       // Verify all entries are function libraries and mappings in libraries
208       // refer to unique ops.
209       DenseSet<StringAttr> key;
210       for (auto it : arr) {
211         if (!it.isa<SymbolRefAttr>())
212           return op->emitError(
213               "only SymbolRefAttr allowed in shape.lib attribute array");
214 
215         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
216             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
217         if (!shapeFnLib)
218           return op->emitError()
219                  << it << " does not refer to FunctionLibraryOp";
220         for (auto mapping : shapeFnLib.getMapping()) {
221           if (!key.insert(mapping.getName()).second) {
222             return op->emitError("only one op to shape mapping allowed, found "
223                                  "multiple for `")
224                    << mapping.getName() << "`";
225           }
226         }
227       }
228       return success();
229     }
230 
231     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
232                          "allowed as shape.lib attribute");
233   }
234   return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AnyOp
239 //===----------------------------------------------------------------------===//
240 
241 // TODO: Canonicalization should be implemented for shapes that can be
242 // determined through mixtures of the known dimensions of the inputs.
243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
244   // Only the last operand is checked because AnyOp is commutative.
245   if (operands.back())
246     return operands.back();
247 
248   return nullptr;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // AssumingOp
253 //===----------------------------------------------------------------------===//
254 
255 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
256   result.regions.reserve(1);
257   Region *doRegion = result.addRegion();
258 
259   auto &builder = parser.getBuilder();
260   OpAsmParser::OperandType cond;
261   if (parser.parseOperand(cond) ||
262       parser.resolveOperand(cond, builder.getType<WitnessType>(),
263                             result.operands))
264     return failure();
265 
266   // Parse optional results type list.
267   if (parser.parseOptionalArrowTypeList(result.types))
268     return failure();
269 
270   // Parse the region and add a terminator if elided.
271   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
272     return failure();
273   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
274 
275   // Parse the optional attribute list.
276   if (parser.parseOptionalAttrDict(result.attributes))
277     return failure();
278   return success();
279 }
280 
281 void AssumingOp::print(OpAsmPrinter &p) {
282   bool yieldsResults = !getResults().empty();
283 
284   p << " " << getWitness();
285   if (yieldsResults)
286     p << " -> (" << getResultTypes() << ")";
287   p << ' ';
288   p.printRegion(getDoRegion(),
289                 /*printEntryBlockArgs=*/false,
290                 /*printBlockTerminators=*/yieldsResults);
291   p.printOptionalAttrDict((*this)->getAttrs());
292 }
293 
294 namespace {
295 // Removes AssumingOp with a passing witness and inlines the region.
296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
297   using OpRewritePattern<AssumingOp>::OpRewritePattern;
298 
299   LogicalResult matchAndRewrite(AssumingOp op,
300                                 PatternRewriter &rewriter) const override {
301     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
302     if (!witness || !witness.getPassingAttr())
303       return failure();
304 
305     AssumingOp::inlineRegionIntoParent(op, rewriter);
306     return success();
307   }
308 };
309 
310 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
311   using OpRewritePattern<AssumingOp>::OpRewritePattern;
312 
313   LogicalResult matchAndRewrite(AssumingOp op,
314                                 PatternRewriter &rewriter) const override {
315     Block *body = op.getBody();
316     auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
317 
318     // Find used values.
319     SmallVector<Value, 4> newYieldOperands;
320     Value opResult, yieldOperand;
321     for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
322       std::tie(opResult, yieldOperand) = it;
323       if (!opResult.getUses().empty()) {
324         newYieldOperands.push_back(yieldOperand);
325       }
326     }
327 
328     // Rewrite only if redundant results exist.
329     if (newYieldOperands.size() == yieldOp->getNumOperands())
330       return failure();
331 
332     // Replace yield op in the old assuming op's body and move the entire region
333     // to the new assuming op.
334     rewriter.setInsertionPointToEnd(body);
335     auto newYieldOp =
336         rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
337     rewriter.setInsertionPoint(op);
338     auto newOp = rewriter.create<AssumingOp>(
339         op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
340     newOp.getDoRegion().takeBody(op.getDoRegion());
341 
342     // Use the new results to replace the previously used ones.
343     SmallVector<Value, 4> replacementValues;
344     auto src = newOp.getResults().begin();
345     for (auto it : op.getResults()) {
346       if (it.getUses().empty())
347         replacementValues.push_back(nullptr);
348       else
349         replacementValues.push_back(*src++);
350     }
351     rewriter.replaceOp(op, replacementValues);
352     return success();
353   }
354 };
355 } // namespace
356 
357 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
358                                              MLIRContext *context) {
359   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
360 }
361 
362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
363 void AssumingOp::getSuccessorRegions(
364     Optional<unsigned> index, ArrayRef<Attribute> operands,
365     SmallVectorImpl<RegionSuccessor> &regions) {
366   // AssumingOp has unconditional control flow into the region and back to the
367   // parent, so return the correct RegionSuccessor purely based on the index
368   // being None or 0.
369   if (index.hasValue()) {
370     regions.push_back(RegionSuccessor(getResults()));
371     return;
372   }
373 
374   regions.push_back(RegionSuccessor(&getDoRegion()));
375 }
376 
377 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
378                                         PatternRewriter &rewriter) {
379   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
380   auto *assumingBlock = op.getBody();
381   auto initPosition = rewriter.getInsertionPoint();
382   auto *blockAfterAssuming =
383       rewriter.splitBlock(blockBeforeAssuming, initPosition);
384 
385   // Remove the AssumingOp and AssumingYieldOp.
386   auto &yieldOp = assumingBlock->back();
387   rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
388   rewriter.replaceOp(op, yieldOp.getOperands());
389   rewriter.eraseOp(&yieldOp);
390 
391   // Merge blocks together as there was no branching behavior from the
392   // AssumingOp.
393   rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
394   rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
395 }
396 
397 void AssumingOp::build(
398     OpBuilder &builder, OperationState &result, Value witness,
399     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
400 
401   result.addOperands(witness);
402   Region *bodyRegion = result.addRegion();
403   bodyRegion->push_back(new Block);
404   Block &bodyBlock = bodyRegion->front();
405 
406   // Build body.
407   OpBuilder::InsertionGuard guard(builder);
408   builder.setInsertionPointToStart(&bodyBlock);
409   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
410   builder.create<AssumingYieldOp>(result.location, yieldValues);
411 
412   SmallVector<Type, 2> assumingTypes;
413   for (Value v : yieldValues)
414     assumingTypes.push_back(v.getType());
415   result.addTypes(assumingTypes);
416 }
417 
418 LogicalResult AssumingOp::verify() {
419   return RegionBranchOpInterface::verifyTypes(*this);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // AddOp
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult mlir::shape::AddOp::inferReturnTypes(
427     MLIRContext *context, Optional<Location> location, ValueRange operands,
428     DictionaryAttr attributes, RegionRange regions,
429     SmallVectorImpl<Type> &inferredReturnTypes) {
430   if (operands[0].getType().isa<SizeType>() ||
431       operands[1].getType().isa<SizeType>())
432     inferredReturnTypes.assign({SizeType::get(context)});
433   else
434     inferredReturnTypes.assign({IndexType::get(context)});
435   return success();
436 }
437 
438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
439   // SizeType is compatible with IndexType.
440   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
441 }
442 
443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
444   // add(x, 0) -> x
445   if (matchPattern(getRhs(), m_Zero()))
446     return getLhs();
447 
448   return constFoldBinaryOp<IntegerAttr>(
449       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
450 }
451 
452 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
453 
454 //===----------------------------------------------------------------------===//
455 // AssumingAllOp
456 //===----------------------------------------------------------------------===//
457 
458 namespace {
459 
460 // Merge multiple `shape.assuming_all` operations together.
461 //
462 //   %0 = shape.assuming_all %w0, %w1
463 //   %1 = shape.assuming_all %w2, %0
464 //
465 // to:
466 //
467 //   %0 = shape.assuming_all %w0, %w2, %w2
468 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
469   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
470 
471   LogicalResult matchAndRewrite(AssumingAllOp op,
472                                 PatternRewriter &rewriter) const override {
473     SmallVector<Value> operands;
474 
475     for (Value operand : op.getInputs()) {
476       if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
477         operands.append(assume_all.operand_begin(), assume_all->operand_end());
478       else
479         operands.push_back(operand);
480     }
481 
482     // We didn't find any other `assuming_all` ops to merge with.
483     if (operands.size() == op.getNumOperands())
484       return failure();
485 
486     // Replace with a new `assuming_all` operation with merged constraints.
487     rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
488     return success();
489   }
490 };
491 
492 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
493 // are subsumed by others.
494 //
495 //   %0 = shape.cstr_broadcastable %shape0, %shape1
496 //   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
497 //
498 //   %2 = shape.cstr_broadcastable %shape3, %shape4
499 //   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
500 //
501 //   %4 = shape.assuming_all %0, %1, %2, %3
502 //
503 // to:
504 //
505 //   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
506 //   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
507 //   %2 = shape.assuming_all %0, %1
508 //
509 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
510 // shapes [0, 1] are broadcastable too, and can be removed from the list of
511 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
512 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
513 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
514   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
515 
516   LogicalResult matchAndRewrite(AssumingAllOp op,
517                                 PatternRewriter &rewriter) const override {
518     // Collect all `CstrBroadcastableOp` operands first.
519     SetVector<CstrBroadcastableOp> operands;
520     for (Value operand : op.getInputs()) {
521       // TODO: Apply this optimization if some of the witnesses are not
522       // produced by the `cstr_broadcastable`.
523       auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
524       if (!broadcastable)
525         return failure();
526 
527       operands.insert(broadcastable);
528     }
529 
530     // Skip trivial `assuming_all` operations.
531     if (operands.size() <= 1)
532       return failure();
533 
534     // Collect shapes checked by `cstr_broadcastable` operands.
535     SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
536     for (auto cstr : operands) {
537       DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end());
538       shapes.emplace_back(cstr, std::move(shapes_set));
539     }
540 
541     // Sort by the number of shape operands (larger to smaller).
542     llvm::sort(shapes, [](auto a, auto b) {
543       return a.first.getNumOperands() > b.first.getNumOperands();
544     });
545 
546     // We start from the `cst_broadcastable` operations with largest number of
547     // shape operands, and remove redundant `cst_broadcastable` operations. We
548     // do this until we find a set of `cst_broadcastable` operations with
549     // non-overlapping constraints.
550     SmallVector<CstrBroadcastableOp> marked_for_erase;
551 
552     for (unsigned i = 0; i < shapes.size(); ++i) {
553       auto isSubset = [&](auto pair) {
554         return llvm::set_is_subset(pair.second, shapes[i].second);
555       };
556 
557       // Keep redundant `cstr_broadcastable` operations to be erased.
558       auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
559       for (auto *it0 = it; it0 < shapes.end(); ++it0)
560         marked_for_erase.push_back(it0->first);
561       shapes.erase(it, shapes.end());
562     }
563 
564     // We didn't find any operands that could be removed.
565     if (marked_for_erase.empty())
566       return failure();
567 
568     // Collect non-overlapping `cst_broadcastable` constraints.
569     SmallVector<Value> unique_constraints;
570     for (auto &shape : shapes)
571       unique_constraints.push_back(shape.first.getResult());
572 
573     // Replace with a new `assuming_all` operation ...
574     rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints);
575 
576     // ... and maybe erase `cstr_broadcastable` ops without uses.
577     for (auto &op : marked_for_erase)
578       if (op->use_empty())
579         rewriter.eraseOp(op);
580 
581     return success();
582   }
583 };
584 
585 struct AssumingAllToCstrEqCanonicalization
586     : public OpRewritePattern<AssumingAllOp> {
587   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
588 
589   LogicalResult matchAndRewrite(AssumingAllOp op,
590                                 PatternRewriter &rewriter) const override {
591     SmallVector<Value, 8> shapes;
592     for (Value w : op.getInputs()) {
593       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
594       if (!cstrEqOp)
595         return failure();
596       bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
597         return llvm::is_contained(shapes, s);
598       });
599       if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
600         return failure();
601       shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
602     }
603     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
604     return success();
605   }
606 };
607 
608 template <typename OpTy>
609 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
610   using OpRewritePattern<OpTy>::OpRewritePattern;
611 
612   LogicalResult matchAndRewrite(OpTy op,
613                                 PatternRewriter &rewriter) const override {
614     // Find unique operands.
615     SetVector<Value> unique(op.operand_begin(), op.operand_end());
616 
617     // Reduce op to equivalent with unique operands.
618     if (unique.size() < op.getNumOperands()) {
619       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
620                                         unique.takeVector(), op->getAttrs());
621       return success();
622     }
623 
624     return failure();
625   }
626 };
627 } // namespace
628 
629 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
630                                                 MLIRContext *context) {
631   patterns
632       .add<MergeAssumingAllOps, AssumingAllOneOp,
633            AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
634            RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
635 }
636 
637 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
638   // Iterate in reverse to first handle all constant operands. They are
639   // guaranteed to be the tail of the inputs because this is commutative.
640   for (int idx = operands.size() - 1; idx >= 0; idx--) {
641     Attribute a = operands[idx];
642     // Cannot fold if any inputs are not constant;
643     if (!a)
644       return nullptr;
645 
646     // We do not need to keep statically known values after handling them in
647     // this method.
648     getOperation()->eraseOperand(idx);
649 
650     // Always false if any input is statically known false
651     if (!a.cast<BoolAttr>().getValue())
652       return a;
653   }
654   // If this is reached, all inputs were statically known passing.
655   return BoolAttr::get(getContext(), true);
656 }
657 
658 LogicalResult AssumingAllOp::verify() {
659   // Ensure that AssumingAllOp contains at least one operand
660   if (getNumOperands() == 0)
661     return emitOpError("no operands specified");
662 
663   return success();
664 }
665 
666 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
667                           ValueRange inputs) {
668   build(b, state, b.getType<WitnessType>(), inputs);
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // BroadcastOp
673 //===----------------------------------------------------------------------===//
674 
675 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
676   if (getShapes().size() == 1) {
677     // Otherwise, we need a cast which would be a canonicalization, not folding.
678     if (getShapes().front().getType() != getType())
679       return nullptr;
680     return getShapes().front();
681   }
682 
683   // TODO: Support folding with more than 2 input shapes
684   if (getShapes().size() > 2)
685     return nullptr;
686 
687   if (!operands[0] || !operands[1])
688     return nullptr;
689   auto lhsShape = llvm::to_vector<6>(
690       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
691   auto rhsShape = llvm::to_vector<6>(
692       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
693   SmallVector<int64_t, 6> resultShape;
694 
695   // If the shapes are not compatible, we can't fold it.
696   // TODO: Fold to an "error".
697   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
698     return nullptr;
699 
700   Builder builder(getContext());
701   return builder.getIndexTensorAttr(resultShape);
702 }
703 
704 LogicalResult BroadcastOp::verify() {
705   return verifyShapeOrExtentTensorOp(*this);
706 }
707 
708 namespace {
709 template <typename OpTy>
710 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
711   using OpRewritePattern<OpTy>::OpRewritePattern;
712 
713   LogicalResult matchAndRewrite(OpTy op,
714                                 PatternRewriter &rewriter) const override {
715     auto isPotentiallyNonEmptyShape = [](Value shape) {
716       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
717         if (extentTensorTy.getDimSize(0) == 0)
718           return false;
719       }
720       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
721         if (constShape.getShape().empty())
722           return false;
723       }
724       return true;
725     };
726     auto newOperands = llvm::to_vector<8>(
727         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
728 
729     // Reduce op to equivalent without empty shape operands.
730     if (newOperands.size() < op.getNumOperands()) {
731       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
732                                         op->getAttrs());
733       return success();
734     }
735 
736     return failure();
737   }
738 };
739 
740 struct BroadcastForwardSingleOperandPattern
741     : public OpRewritePattern<BroadcastOp> {
742   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
743 
744   LogicalResult matchAndRewrite(BroadcastOp op,
745                                 PatternRewriter &rewriter) const override {
746     if (op.getNumOperands() != 1)
747       return failure();
748     Value replacement = op.getShapes().front();
749 
750     // Insert cast if needed.
751     if (replacement.getType() != op.getType()) {
752       auto loc = op.getLoc();
753       if (op.getType().isa<ShapeType>()) {
754         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
755       } else {
756         assert(!op.getType().isa<ShapeType>() &&
757                !replacement.getType().isa<ShapeType>() &&
758                "expect extent tensor cast");
759         replacement =
760             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
761       }
762     }
763 
764     rewriter.replaceOp(op, replacement);
765     return success();
766   }
767 };
768 
769 struct BroadcastFoldConstantOperandsPattern
770     : public OpRewritePattern<BroadcastOp> {
771   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
772 
773   LogicalResult matchAndRewrite(BroadcastOp op,
774                                 PatternRewriter &rewriter) const override {
775     SmallVector<int64_t, 8> foldedConstantShape;
776     SmallVector<Value, 8> newShapeOperands;
777     for (Value shape : op.getShapes()) {
778       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
779         SmallVector<int64_t, 8> newFoldedConstantShape;
780         if (OpTrait::util::getBroadcastedShape(
781                 foldedConstantShape,
782                 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
783                 newFoldedConstantShape)) {
784           foldedConstantShape = newFoldedConstantShape;
785           continue;
786         }
787       }
788       newShapeOperands.push_back(shape);
789     }
790 
791     // Need at least two constant operands to fold anything.
792     if (op.getNumOperands() - newShapeOperands.size() < 2)
793       return failure();
794 
795     auto foldedConstantOperandsTy = RankedTensorType::get(
796         {static_cast<int64_t>(foldedConstantShape.size())},
797         rewriter.getIndexType());
798     newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
799         op.getLoc(), foldedConstantOperandsTy,
800         rewriter.getIndexTensorAttr(foldedConstantShape)));
801     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
802                                              newShapeOperands);
803     return success();
804   }
805 };
806 
807 template <typename OpTy>
808 struct CanonicalizeCastExtentTensorOperandsPattern
809     : public OpRewritePattern<OpTy> {
810   using OpRewritePattern<OpTy>::OpRewritePattern;
811 
812   LogicalResult matchAndRewrite(OpTy op,
813                                 PatternRewriter &rewriter) const override {
814     // Canonicalize operands.
815     bool anyChange = false;
816     auto canonicalizeOperand = [&](Value operand) {
817       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
818         // Only eliminate the cast if it holds no shape information.
819         bool isInformationLoosingCast =
820             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
821         if (isInformationLoosingCast) {
822           anyChange = true;
823           return castOp.source();
824         }
825       }
826       return operand;
827     };
828     auto newOperands = llvm::to_vector<8>(
829         llvm::map_range(op.getOperands(), canonicalizeOperand));
830 
831     // Rewrite op if any change required.
832     if (!anyChange)
833       return failure();
834     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
835     return success();
836   }
837 };
838 
839 struct BroadcastConcretizeResultTypePattern
840     : public OpRewritePattern<BroadcastOp> {
841   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
842 
843   LogicalResult matchAndRewrite(BroadcastOp op,
844                                 PatternRewriter &rewriter) const override {
845     // Only concretize dynamic extent tensor result types.
846     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
847     if (!resultTy || !resultTy.isDynamicDim(0))
848       return failure();
849 
850     // Infer resulting shape rank if possible.
851     int64_t maxRank = 0;
852     for (Value shape : op.getShapes()) {
853       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
854         // Cannot infer resulting shape rank if any operand is dynamically
855         // ranked.
856         if (extentTensorTy.isDynamicDim(0))
857           return failure();
858         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
859       }
860     }
861 
862     auto newOp = rewriter.create<BroadcastOp>(
863         op.getLoc(), getExtentTensorType(getContext(), maxRank),
864         op.getShapes());
865     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
866     return success();
867   }
868 };
869 } // namespace
870 
871 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
872                                               MLIRContext *context) {
873   patterns.add<BroadcastConcretizeResultTypePattern,
874                BroadcastFoldConstantOperandsPattern,
875                BroadcastForwardSingleOperandPattern,
876                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
877                RemoveDuplicateOperandsPattern<BroadcastOp>,
878                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // ConcatOp
883 //===----------------------------------------------------------------------===//
884 
885 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
886   if (!operands[0] || !operands[1])
887     return nullptr;
888   auto lhsShape = llvm::to_vector<6>(
889       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
890   auto rhsShape = llvm::to_vector<6>(
891       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
892   SmallVector<int64_t, 6> resultShape;
893   resultShape.append(lhsShape.begin(), lhsShape.end());
894   resultShape.append(rhsShape.begin(), rhsShape.end());
895   Builder builder(getContext());
896   return builder.getIndexTensorAttr(resultShape);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // ConstShapeOp
901 //===----------------------------------------------------------------------===//
902 
903 void ConstShapeOp::print(OpAsmPrinter &p) {
904   p << " ";
905   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
906   p << "[";
907   interleaveComma(getShape().getValues<int64_t>(), p);
908   p << "] : ";
909   p.printType(getType());
910 }
911 
912 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
913   if (parser.parseOptionalAttrDict(result.attributes))
914     return failure();
915   // We piggy-back on ArrayAttr parsing, though we don't internally store the
916   // shape as an ArrayAttr.
917   // TODO: Implement custom parser and maybe make syntax a bit more concise.
918   Attribute extentsRaw;
919   NamedAttrList dummy;
920   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
921     return failure();
922   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
923   if (!extentsArray)
924     return failure();
925   SmallVector<int64_t, 6> ints;
926   for (Attribute extent : extentsArray) {
927     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
928     if (!attr)
929       return failure();
930     ints.push_back(attr.getInt());
931   }
932   Builder &builder = parser.getBuilder();
933   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
934   Type resultTy;
935   if (parser.parseColonType(resultTy))
936     return failure();
937   result.types.push_back(resultTy);
938   return success();
939 }
940 
941 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
942 
943 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
944                                                MLIRContext *context) {
945   patterns.add<TensorCastConstShape>(context);
946 }
947 
948 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
949     MLIRContext *context, Optional<Location> location, ValueRange operands,
950     DictionaryAttr attributes, RegionRange regions,
951     SmallVectorImpl<Type> &inferredReturnTypes) {
952   Builder b(context);
953   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
954   if (!shape)
955     return emitOptionalError(location, "missing shape attribute");
956   inferredReturnTypes.assign({RankedTensorType::get(
957       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
958   return success();
959 }
960 
961 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
962                                                         TypeRange r) {
963   if (l.size() != 1 || r.size() != 1)
964     return false;
965 
966   Type lhs = l.front();
967   Type rhs = r.front();
968 
969   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
970     // Shape type is compatible with all other valid return types.
971     return true;
972   return lhs == rhs;
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // CstrBroadcastableOp
977 //===----------------------------------------------------------------------===//
978 
979 void CstrBroadcastableOp::getCanonicalizationPatterns(
980     RewritePatternSet &patterns, MLIRContext *context) {
981   // Canonicalization patterns have overlap with the considerations during
982   // folding in case additional shape information is inferred at some point that
983   // does not result in folding.
984   patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
985                CstrBroadcastableEqOps,
986                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
987                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
988 }
989 
990 // Return true if there is exactly one attribute not representing a scalar
991 // broadcast.
992 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
993   bool nonScalarSeen = false;
994   for (Attribute a : attributes) {
995     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
996       if (nonScalarSeen)
997         return false;
998       nonScalarSeen = true;
999     }
1000   }
1001   return true;
1002 }
1003 
1004 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1005   // No broadcasting is needed if all operands but one are scalar.
1006   if (hasAtMostSingleNonScalar(operands))
1007     return BoolAttr::get(getContext(), true);
1008 
1009   if ([&] {
1010         SmallVector<SmallVector<int64_t, 6>, 6> extents;
1011         for (const auto &operand : operands) {
1012           if (!operand)
1013             return false;
1014           extents.push_back(llvm::to_vector<6>(
1015               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
1016         }
1017         return OpTrait::util::staticallyKnownBroadcastable(extents);
1018       }())
1019     return BoolAttr::get(getContext(), true);
1020 
1021   // Lastly, see if folding can be completed based on what constraints are known
1022   // on the input shapes.
1023   if ([&] {
1024         SmallVector<SmallVector<int64_t, 6>, 6> extents;
1025         for (auto shapeValue : getShapes()) {
1026           extents.emplace_back();
1027           if (failed(getShapeVec(shapeValue, extents.back())))
1028             return false;
1029         }
1030         return OpTrait::util::staticallyKnownBroadcastable(extents);
1031       }())
1032     return BoolAttr::get(getContext(), true);
1033 
1034   // Because a failing witness result here represents an eventual assertion
1035   // failure, we do not replace it with a constant witness.
1036   return nullptr;
1037 }
1038 
1039 LogicalResult CstrBroadcastableOp::verify() {
1040   // Ensure that CstrBroadcastableOp contains at least two operands
1041   if (getNumOperands() < 2)
1042     return emitOpError("required at least 2 input shapes");
1043   return success();
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // CstrEqOp
1048 //===----------------------------------------------------------------------===//
1049 
1050 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1051                                            MLIRContext *context) {
1052   // If inputs are equal, return passing witness
1053   patterns.add<CstrEqEqOps>(context);
1054 }
1055 
1056 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
1057   if (llvm::all_of(operands,
1058                    [&](Attribute a) { return a && a == operands[0]; }))
1059     return BoolAttr::get(getContext(), true);
1060 
1061   // Because a failing witness result here represents an eventual assertion
1062   // failure, we do not try to replace it with a constant witness. Similarly, we
1063   // cannot if there are any non-const inputs.
1064   return nullptr;
1065 }
1066 
1067 //===----------------------------------------------------------------------===//
1068 // ConstSizeOp
1069 //===----------------------------------------------------------------------===//
1070 
1071 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1072                         int64_t value) {
1073   build(builder, result, builder.getIndexAttr(value));
1074 }
1075 
1076 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
1077 
1078 void ConstSizeOp::getAsmResultNames(
1079     llvm::function_ref<void(Value, StringRef)> setNameFn) {
1080   SmallString<4> buffer;
1081   llvm::raw_svector_ostream os(buffer);
1082   os << "c" << getValue();
1083   setNameFn(getResult(), os.str());
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // ConstWitnessOp
1088 //===----------------------------------------------------------------------===//
1089 
1090 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
1091   return getPassingAttr();
1092 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // CstrRequireOp
1096 //===----------------------------------------------------------------------===//
1097 
1098 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
1099   return operands[0];
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // DivOp
1104 //===----------------------------------------------------------------------===//
1105 
1106 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1107   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1108   if (!lhs)
1109     return nullptr;
1110   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1111   if (!rhs)
1112     return nullptr;
1113 
1114   // Division in APInt does not follow floor(lhs, rhs) when the result is
1115   // negative. Rather, APInt rounds toward zero.
1116   APInt quotient, remainder;
1117   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1118   if (quotient.isNegative() && !remainder.isNullValue()) {
1119     quotient -= 1;
1120   }
1121 
1122   Type indexTy = IndexType::get(getContext());
1123   return IntegerAttr::get(indexTy, quotient);
1124 }
1125 
1126 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1127     MLIRContext *context, Optional<Location> location, ValueRange operands,
1128     DictionaryAttr attributes, RegionRange regions,
1129     SmallVectorImpl<Type> &inferredReturnTypes) {
1130   if (operands[0].getType().isa<SizeType>() ||
1131       operands[1].getType().isa<SizeType>())
1132     inferredReturnTypes.assign({SizeType::get(context)});
1133   else
1134     inferredReturnTypes.assign({IndexType::get(context)});
1135   return success();
1136 }
1137 
1138 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1139   // SizeType is compatible with IndexType.
1140   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1141 }
1142 
1143 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // ShapeEqOp
1147 //===----------------------------------------------------------------------===//
1148 
1149 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1150   bool allSame = true;
1151   if (!operands.empty() && !operands[0])
1152     return {};
1153   for (Attribute operand : operands.drop_front(1)) {
1154     if (!operand)
1155       return {};
1156     allSame = allSame && operand == operands[0];
1157   }
1158   return BoolAttr::get(getContext(), allSame);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // IndexToSizeOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1166   // Constant values of both types, `shape.size` and `index`, are represented as
1167   // `IntegerAttr`s which makes constant folding simple.
1168   if (Attribute arg = operands[0])
1169     return arg;
1170   return {};
1171 }
1172 
1173 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1174                                                 MLIRContext *context) {
1175   patterns.add<SizeToIndexToSizeCanonicalization>(context);
1176 }
1177 
1178 //===----------------------------------------------------------------------===//
1179 // FromExtentsOp
1180 //===----------------------------------------------------------------------===//
1181 
1182 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1183   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1184     return nullptr;
1185   SmallVector<int64_t, 6> extents;
1186   for (auto attr : operands)
1187     extents.push_back(attr.cast<IntegerAttr>().getInt());
1188   Builder builder(getContext());
1189   return builder.getIndexTensorAttr(extents);
1190 }
1191 
1192 //===----------------------------------------------------------------------===//
1193 // FunctionLibraryOp
1194 //===----------------------------------------------------------------------===//
1195 
1196 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1197                               StringRef name) {
1198   result.attributes.push_back(builder.getNamedAttr(
1199       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1200 }
1201 
1202 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1203   auto attr = getMapping()
1204                   .get(op->getName().getIdentifier())
1205                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1206   if (!attr)
1207     return nullptr;
1208   return lookupSymbol<FuncOp>(attr);
1209 }
1210 
1211 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1212                                      OperationState &result) {
1213   // Parse the op name.
1214   StringAttr nameAttr;
1215   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1216                              result.attributes))
1217     return failure();
1218 
1219   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1220     return failure();
1221 
1222   auto *bodyRegion = result.addRegion();
1223   if (parser.parseRegion(*bodyRegion))
1224     return failure();
1225 
1226   if (parser.parseKeyword("mapping"))
1227     return failure();
1228 
1229   DictionaryAttr mappingAttr;
1230   if (parser.parseAttribute(mappingAttr,
1231                             parser.getBuilder().getType<NoneType>(), "mapping",
1232                             result.attributes))
1233     return failure();
1234   return success();
1235 }
1236 
1237 void FunctionLibraryOp::print(OpAsmPrinter &p) {
1238   p << ' ';
1239   p.printSymbolName(getName());
1240   p.printOptionalAttrDictWithKeyword(
1241       (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1242   p << ' ';
1243   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1244                 /*printBlockTerminators=*/false);
1245   p << " mapping ";
1246   p.printAttributeWithoutType(getMappingAttr());
1247 }
1248 
1249 //===----------------------------------------------------------------------===//
1250 // GetExtentOp
1251 //===----------------------------------------------------------------------===//
1252 
1253 Optional<int64_t> GetExtentOp::getConstantDim() {
1254   if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1255     return constSizeOp.getValue().getLimitedValue();
1256   if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1257     return constantOp.getValue().cast<IntegerAttr>().getInt();
1258   return llvm::None;
1259 }
1260 
1261 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1262   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1263   if (!elements)
1264     return nullptr;
1265   Optional<int64_t> dim = getConstantDim();
1266   if (!dim.hasValue())
1267     return nullptr;
1268   if (dim.getValue() >= elements.getNumElements())
1269     return nullptr;
1270   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1271 }
1272 
1273 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1274                         int64_t dim) {
1275   auto loc = result.location;
1276   auto dimAttr = builder.getIndexAttr(dim);
1277   if (shape.getType().isa<ShapeType>()) {
1278     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1279     build(builder, result, builder.getType<SizeType>(), shape, dim);
1280   } else {
1281     Value dim =
1282         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1283     build(builder, result, builder.getIndexType(), shape, dim);
1284   }
1285 }
1286 
1287 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1288     MLIRContext *context, Optional<Location> location, ValueRange operands,
1289     DictionaryAttr attributes, RegionRange regions,
1290     SmallVectorImpl<Type> &inferredReturnTypes) {
1291   inferredReturnTypes.assign({IndexType::get(context)});
1292   return success();
1293 }
1294 
1295 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1296                                                        TypeRange r) {
1297   // SizeType is compatible with IndexType.
1298   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1299 }
1300 
1301 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // IsBroadcastableOp
1305 //===----------------------------------------------------------------------===//
1306 
1307 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1308                                                     MLIRContext *context) {
1309   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1310 }
1311 
1312 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1313   // Can always broadcast fewer than two shapes.
1314   if (operands.size() < 2) {
1315     return BoolAttr::get(getContext(), true);
1316   }
1317 
1318   return nullptr;
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // MeetOp
1323 //===----------------------------------------------------------------------===//
1324 
1325 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1326     MLIRContext *context, Optional<Location> location, ValueRange operands,
1327     DictionaryAttr attributes, RegionRange regions,
1328     SmallVectorImpl<Type> &inferredReturnTypes) {
1329   inferredReturnTypes.assign({operands[0].getType()});
1330   return success();
1331 }
1332 
1333 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1334   if (l.size() != 1 || r.size() != 1)
1335     return false;
1336   if (l == r)
1337     return true;
1338 
1339   Type lhs = l.front();
1340   Type rhs = r.front();
1341 
1342   if (lhs != rhs)
1343     return false;
1344 
1345   if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1346     return true;
1347 
1348   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1349     return true;
1350   return false;
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // RankOp
1355 //===----------------------------------------------------------------------===//
1356 
1357 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1358   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1359   if (!shape)
1360     return {};
1361   int64_t rank = shape.getNumElements();
1362   Builder builder(getContext());
1363   return builder.getIndexAttr(rank);
1364 }
1365 
1366 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1367 /// Constant folding fails in cases where only the rank is constant, not the
1368 /// shape itself.
1369 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1370 ///
1371 /// Example:
1372 ///
1373 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1374 /// %rank = shape.rank %shape
1375 ///
1376 /// becomes
1377 ///
1378 /// %rank = shape.const_size 3
1379 
1380 namespace {
1381 struct RankShapeOfCanonicalizationPattern
1382     : public OpRewritePattern<shape::RankOp> {
1383   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1384 
1385   LogicalResult matchAndRewrite(shape::RankOp op,
1386                                 PatternRewriter &rewriter) const override {
1387     auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1388     if (!shapeOfOp)
1389       return failure();
1390     auto rankedTensorType =
1391         shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1392     if (!rankedTensorType)
1393       return failure();
1394     int64_t rank = rankedTensorType.getRank();
1395     if (op.getType().isa<IndexType>()) {
1396       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1397                                                           rank);
1398     } else if (op.getType().isa<shape::SizeType>()) {
1399       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1400     } else {
1401       return failure();
1402     }
1403     return success();
1404   }
1405 };
1406 } // namespace
1407 
1408 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1409                                                 MLIRContext *context) {
1410   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1411 }
1412 
1413 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1414     MLIRContext *context, Optional<Location> location, ValueRange operands,
1415     DictionaryAttr attributes, RegionRange regions,
1416     SmallVectorImpl<Type> &inferredReturnTypes) {
1417   if (operands[0].getType().isa<ShapeType>())
1418     inferredReturnTypes.assign({SizeType::get(context)});
1419   else
1420     inferredReturnTypes.assign({IndexType::get(context)});
1421   return success();
1422 }
1423 
1424 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1425   // SizeType is compatible with IndexType.
1426   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1427 }
1428 
1429 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1430 
1431 //===----------------------------------------------------------------------===//
1432 // NumElementsOp
1433 //===----------------------------------------------------------------------===//
1434 
1435 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1436 
1437   // Fold only when argument constant.
1438   Attribute shape = operands[0];
1439   if (!shape)
1440     return {};
1441 
1442   APInt product(64, 1);
1443   for (auto value : shape.cast<DenseIntElementsAttr>())
1444     product *= value;
1445   Builder builder(getContext());
1446   return builder.getIndexAttr(product.getLimitedValue());
1447 }
1448 
1449 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1450     MLIRContext *context, Optional<Location> location, ValueRange operands,
1451     DictionaryAttr attributes, RegionRange regions,
1452     SmallVectorImpl<Type> &inferredReturnTypes) {
1453   if (operands[0].getType().isa<ShapeType>())
1454     inferredReturnTypes.assign({SizeType::get(context)});
1455   else
1456     inferredReturnTypes.assign({IndexType::get(context)});
1457   return success();
1458 }
1459 
1460 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1461                                                          TypeRange r) {
1462   // SizeType is compatible with IndexType.
1463   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1464 }
1465 
1466 LogicalResult shape::NumElementsOp::verify() {
1467   return verifySizeOrIndexOp(*this);
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // MaxOp
1472 //===----------------------------------------------------------------------===//
1473 
1474 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1475   // If operands are equal, just propagate one.
1476   if (getLhs() == getRhs())
1477     return getLhs();
1478   return nullptr;
1479 }
1480 
1481 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1482     MLIRContext *context, Optional<Location> location, ValueRange operands,
1483     DictionaryAttr attributes, RegionRange regions,
1484     SmallVectorImpl<Type> &inferredReturnTypes) {
1485   if (operands[0].getType() == operands[1].getType())
1486     inferredReturnTypes.assign({operands[0].getType()});
1487   else
1488     inferredReturnTypes.assign({SizeType::get(context)});
1489   return success();
1490 }
1491 
1492 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1493   if (l.size() != 1 || r.size() != 1)
1494     return false;
1495   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1496     return true;
1497   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1498     return true;
1499   return false;
1500 }
1501 
1502 //===----------------------------------------------------------------------===//
1503 // MinOp
1504 //===----------------------------------------------------------------------===//
1505 
1506 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1507   // If operands are equal, just propagate one.
1508   if (getLhs() == getRhs())
1509     return getLhs();
1510   return nullptr;
1511 }
1512 
1513 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1514     MLIRContext *context, Optional<Location> location, ValueRange operands,
1515     DictionaryAttr attributes, RegionRange regions,
1516     SmallVectorImpl<Type> &inferredReturnTypes) {
1517   if (operands[0].getType() == operands[1].getType())
1518     inferredReturnTypes.assign({operands[0].getType()});
1519   else
1520     inferredReturnTypes.assign({SizeType::get(context)});
1521   return success();
1522 }
1523 
1524 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1525   if (l.size() != 1 || r.size() != 1)
1526     return false;
1527   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1528     return true;
1529   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1530     return true;
1531   return false;
1532 }
1533 
1534 //===----------------------------------------------------------------------===//
1535 // MulOp
1536 //===----------------------------------------------------------------------===//
1537 
1538 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1539   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1540   if (!lhs)
1541     return nullptr;
1542   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1543   if (!rhs)
1544     return nullptr;
1545   APInt folded = lhs.getValue() * rhs.getValue();
1546   Type indexTy = IndexType::get(getContext());
1547   return IntegerAttr::get(indexTy, folded);
1548 }
1549 
1550 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1551     MLIRContext *context, Optional<Location> location, ValueRange operands,
1552     DictionaryAttr attributes, RegionRange regions,
1553     SmallVectorImpl<Type> &inferredReturnTypes) {
1554   if (operands[0].getType().isa<SizeType>() ||
1555       operands[1].getType().isa<SizeType>())
1556     inferredReturnTypes.assign({SizeType::get(context)});
1557   else
1558     inferredReturnTypes.assign({IndexType::get(context)});
1559   return success();
1560 }
1561 
1562 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1563   // SizeType is compatible with IndexType.
1564   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1565 }
1566 
1567 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1568 
1569 //===----------------------------------------------------------------------===//
1570 // ShapeOfOp
1571 //===----------------------------------------------------------------------===//
1572 
1573 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1574   auto type = getOperand().getType().dyn_cast<ShapedType>();
1575   if (!type || !type.hasStaticShape())
1576     return nullptr;
1577   Builder builder(getContext());
1578   return builder.getIndexTensorAttr(type.getShape());
1579 }
1580 
1581 namespace {
1582 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1583   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1584 
1585   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1586                                 PatternRewriter &rewriter) const override {
1587     if (!op.getArg().getType().isa<ShapedType>())
1588       return failure();
1589     if (op.getType().isa<ShapedType>())
1590       return failure();
1591 
1592     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1593                                                   op.getArg());
1594     return success();
1595   }
1596 };
1597 
1598 // Canonicalize
1599 // ```
1600 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1601 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1602 // ```
1603 // to
1604 // ```
1605 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1606 // ```
1607 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1608   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1609 
1610   LogicalResult matchAndRewrite(tensor::CastOp op,
1611                                 PatternRewriter &rewriter) const override {
1612     auto ty = op.getType().dyn_cast<RankedTensorType>();
1613     if (!ty || ty.getRank() != 1)
1614       return failure();
1615 
1616     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1617     if (!shapeOfOp)
1618       return failure();
1619 
1620     // Argument type must be ranked and must not conflict.
1621     auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1622     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1623       return failure();
1624 
1625     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1626     return success();
1627   }
1628 };
1629 } // namespace
1630 
1631 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1632                                             MLIRContext *context) {
1633   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1634                ExtractFromShapeOfExtentTensor>(context);
1635 }
1636 
1637 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1638     MLIRContext *context, Optional<Location> location, ValueRange operands,
1639     DictionaryAttr attributes, RegionRange regions,
1640     SmallVectorImpl<Type> &inferredReturnTypes) {
1641   if (operands[0].getType().isa<ValueShapeType>())
1642     inferredReturnTypes.assign({ShapeType::get(context)});
1643   else {
1644     auto shapedTy = operands[0].getType().cast<ShapedType>();
1645     int64_t rank =
1646         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1647     Type indexTy = IndexType::get(context);
1648     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1649     inferredReturnTypes.assign({extentTensorTy});
1650   }
1651   return success();
1652 }
1653 
1654 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1655   if (l.size() != 1 || r.size() != 1)
1656     return false;
1657   if (l == r)
1658     return true;
1659 
1660   Type lhs = l.front();
1661   Type rhs = r.front();
1662 
1663   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1664     return false;
1665 
1666   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1667     // Shape type is compatible with all other valid return types.
1668     return true;
1669 
1670   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1671     return true;
1672   return false;
1673 }
1674 
1675 LogicalResult shape::ShapeOfOp::verify() {
1676   return verifyShapeOrExtentTensorOp(*this);
1677 }
1678 
1679 //===----------------------------------------------------------------------===//
1680 // SizeToIndexOp
1681 //===----------------------------------------------------------------------===//
1682 
1683 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1684   // Constant values of both types, `shape.size` and `index`, are represented as
1685   // `IntegerAttr`s which makes constant folding simple.
1686   if (Attribute arg = operands[0])
1687     return arg;
1688   return OpFoldResult();
1689 }
1690 
1691 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1692                                                 MLIRContext *context) {
1693   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1694 }
1695 
1696 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1697   if (inputs.size() != 1 || outputs.size() != 1)
1698     return false;
1699   return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1700 }
1701 
1702 //===----------------------------------------------------------------------===//
1703 // YieldOp
1704 //===----------------------------------------------------------------------===//
1705 
1706 LogicalResult shape::YieldOp::verify() {
1707   auto *parentOp = (*this)->getParentOp();
1708   auto results = parentOp->getResults();
1709   auto operands = getOperands();
1710 
1711   if (parentOp->getNumResults() != getNumOperands())
1712     return emitOpError() << "number of operands does not match number of "
1713                             "results of its parent";
1714   for (auto e : llvm::zip(results, operands))
1715     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1716       return emitOpError() << "types mismatch between yield op and its parent";
1717 
1718   return success();
1719 }
1720 
1721 //===----------------------------------------------------------------------===//
1722 // SplitAtOp
1723 //===----------------------------------------------------------------------===//
1724 
1725 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1726                               SmallVectorImpl<OpFoldResult> &results) {
1727   if (!operands[0] || !operands[1])
1728     return failure();
1729   auto shapeVec = llvm::to_vector<6>(
1730       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1731   auto shape = llvm::makeArrayRef(shapeVec);
1732   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1733   // Verify that the split point is in the correct range.
1734   // TODO: Constant fold to an "error".
1735   int64_t rank = shape.size();
1736   if (!(-rank <= splitPoint && splitPoint <= rank))
1737     return failure();
1738   if (splitPoint < 0)
1739     splitPoint += shape.size();
1740   Builder builder(operands[0].getContext());
1741   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1742   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1743   return success();
1744 }
1745 
1746 //===----------------------------------------------------------------------===//
1747 // ToExtentTensorOp
1748 //===----------------------------------------------------------------------===//
1749 
1750 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1751   if (!operands[0])
1752     return OpFoldResult();
1753   Builder builder(getContext());
1754   auto shape = llvm::to_vector<6>(
1755       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1756   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1757                                     builder.getIndexType());
1758   return DenseIntElementsAttr::get(type, shape);
1759 }
1760 
1761 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1762   if (inputs.size() != 1 || outputs.size() != 1)
1763     return false;
1764   if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1765     if (!inputTensor.getElementType().isa<IndexType>() ||
1766         inputTensor.getRank() != 1)
1767       return false;
1768   } else if (!inputs[0].isa<ShapeType>()) {
1769     return false;
1770   }
1771 
1772   TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1773   return outputTensor && outputTensor.getElementType().isa<IndexType>();
1774 }
1775 
1776 //===----------------------------------------------------------------------===//
1777 // ReduceOp
1778 //===----------------------------------------------------------------------===//
1779 
1780 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1781                      ValueRange initVals) {
1782   result.addOperands(shape);
1783   result.addOperands(initVals);
1784 
1785   Region *bodyRegion = result.addRegion();
1786   bodyRegion->push_back(new Block);
1787   Block &bodyBlock = bodyRegion->front();
1788   bodyBlock.addArgument(builder.getIndexType(), result.location);
1789 
1790   Type elementType;
1791   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1792     elementType = tensorType.getElementType();
1793   else
1794     elementType = SizeType::get(builder.getContext());
1795   bodyBlock.addArgument(elementType, shape.getLoc());
1796 
1797   for (Value initVal : initVals) {
1798     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1799     result.addTypes(initVal.getType());
1800   }
1801 }
1802 
1803 LogicalResult ReduceOp::verify() {
1804   // Verify block arg types.
1805   Block &block = getRegion().front();
1806 
1807   // The block takes index, extent, and aggregated values as arguments.
1808   auto blockArgsCount = getInitVals().size() + 2;
1809   if (block.getNumArguments() != blockArgsCount)
1810     return emitOpError() << "ReduceOp body is expected to have "
1811                          << blockArgsCount << " arguments";
1812 
1813   // The first block argument is the index and must always be of type `index`.
1814   if (!block.getArgument(0).getType().isa<IndexType>())
1815     return emitOpError(
1816         "argument 0 of ReduceOp body is expected to be of IndexType");
1817 
1818   // The second block argument is the extent and must be of type `size` or
1819   // `index`, depending on whether the reduce operation is applied to a shape or
1820   // to an extent tensor.
1821   Type extentTy = block.getArgument(1).getType();
1822   if (getShape().getType().isa<ShapeType>()) {
1823     if (!extentTy.isa<SizeType>())
1824       return emitOpError("argument 1 of ReduceOp body is expected to be of "
1825                          "SizeType if the ReduceOp operates on a ShapeType");
1826   } else {
1827     if (!extentTy.isa<IndexType>())
1828       return emitOpError(
1829           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1830           "ReduceOp operates on an extent tensor");
1831   }
1832 
1833   for (const auto &type : llvm::enumerate(getInitVals()))
1834     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1835       return emitOpError() << "type mismatch between argument "
1836                            << type.index() + 2
1837                            << " of ReduceOp body and initial value "
1838                            << type.index();
1839   return success();
1840 }
1841 
1842 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1843   // Parse operands.
1844   SmallVector<OpAsmParser::OperandType, 3> operands;
1845   Type shapeOrExtentTensorType;
1846   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1847                               OpAsmParser::Delimiter::Paren) ||
1848       parser.parseColonType(shapeOrExtentTensorType) ||
1849       parser.parseOptionalArrowTypeList(result.types))
1850     return failure();
1851 
1852   // Resolve operands.
1853   auto initVals = llvm::makeArrayRef(operands).drop_front();
1854   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1855                             result.operands) ||
1856       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1857                              result.operands))
1858     return failure();
1859 
1860   // Parse the body.
1861   Region *body = result.addRegion();
1862   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1863     return failure();
1864 
1865   // Parse attributes.
1866   if (parser.parseOptionalAttrDict(result.attributes))
1867     return failure();
1868 
1869   return success();
1870 }
1871 
1872 void ReduceOp::print(OpAsmPrinter &p) {
1873   p << '(' << getShape() << ", " << getInitVals()
1874     << ") : " << getShape().getType();
1875   p.printOptionalArrowTypeList(getResultTypes());
1876   p << ' ';
1877   p.printRegion(getRegion());
1878   p.printOptionalAttrDict((*this)->getAttrs());
1879 }
1880 
1881 #define GET_OP_CLASSES
1882 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1883