1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/SetOperations.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::shape;
31 
32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
33 
34 namespace {
35 #include "ShapeCanonicalization.inc"
36 } // namespace
37 
38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
39   return RankedTensorType::get({rank}, IndexType::get(ctx));
40 }
41 
42 bool shape::isExtentTensorType(Type type) {
43   auto ranked = type.dyn_cast<RankedTensorType>();
44   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
45 }
46 
47 LogicalResult shape::getShapeVec(Value input,
48                                  SmallVectorImpl<int64_t> &shapeValues) {
49   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
50     auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
51     if (!type.hasRank())
52       return failure();
53     shapeValues = llvm::to_vector<6>(type.getShape());
54     return success();
55   }
56   if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
57     shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
58     return success();
59   }
60   if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
61     shapeValues = llvm::to_vector<6>(
62         inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
63     return success();
64   }
65   return failure();
66 }
67 
68 static bool isErrorPropagationPossible(TypeRange operandTypes) {
69   return llvm::any_of(operandTypes, [](Type ty) {
70     return ty.isa<SizeType, ShapeType, ValueShapeType>();
71   });
72 }
73 
74 static LogicalResult verifySizeOrIndexOp(Operation *op) {
75   assert(op != nullptr && op->getNumResults() == 1);
76   Type resultTy = op->getResultTypes().front();
77   if (isErrorPropagationPossible(op->getOperandTypes())) {
78     if (!resultTy.isa<SizeType>())
79       return op->emitOpError()
80              << "if at least one of the operands can hold error values then "
81                 "the result must be of type `size` to propagate them";
82   }
83   return success();
84 }
85 
86 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
87   assert(op != nullptr && op->getNumResults() == 1);
88   Type resultTy = op->getResultTypes().front();
89   if (isErrorPropagationPossible(op->getOperandTypes())) {
90     if (!resultTy.isa<ShapeType>())
91       return op->emitOpError()
92              << "if at least one of the operands can hold error values then "
93                 "the result must be of type `shape` to propagate them";
94   }
95   return success();
96 }
97 
/// Base case: a type range qualifies iff it holds exactly one type and that
/// type is one of `Ty...`.
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

/// Variadic case: every given type range must individually qualify.
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}
107 
108 //===----------------------------------------------------------------------===//
109 // InlinerInterface
110 //===----------------------------------------------------------------------===//
111 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// All shape ops are side-effect-free computations on shapes, so inlining is
/// unconditionally legal in both hooks below.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
133 
void ShapeDialect::initialize() {
  // Register all ODS-generated shape ops.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
146 
/// Materializes a constant of the requested `type` from `value`, used by
/// folding to turn attributes back into SSA values. Returns null when the
/// (value, type) combination is not supported.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Shapes (including extent tensors) come from dense index attributes.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to the arith dialect for plain index/integer constants.
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}
161 
162 /// Parse a type registered to this dialect.
163 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
164   StringRef keyword;
165   if (parser.parseKeyword(&keyword))
166     return Type();
167 
168   if (keyword == "shape")
169     return ShapeType::get(getContext());
170   if (keyword == "size")
171     return SizeType::get(getContext());
172   if (keyword == "value_shape")
173     return ValueShapeType::get(getContext());
174   if (keyword == "witness")
175     return WitnessType::get(getContext());
176 
177   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
178   return Type();
179 }
180 
181 /// Print a type registered to this dialect.
182 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
183   TypeSwitch<Type>(type)
184       .Case<ShapeType>([&](Type) { os << "shape"; })
185       .Case<SizeType>([&](Type) { os << "size"; })
186       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
187       .Case<WitnessType>([&](Type) { os << "witness"; })
188       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
189 }
190 
191 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
192                                                      NamedAttribute attribute) {
193   // Verify shape.lib attribute.
194   if (attribute.getName() == "shape.lib") {
195     if (!op->hasTrait<OpTrait::SymbolTable>())
196       return op->emitError(
197           "shape.lib attribute may only be on op implementing SymbolTable");
198 
199     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
200       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
201       if (!symbol)
202         return op->emitError("shape function library ")
203                << symbolRef << " not found";
204       return isa<shape::FunctionLibraryOp>(symbol)
205                  ? success()
206                  : op->emitError()
207                        << symbolRef << " required to be shape function library";
208     }
209 
210     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
211       // Verify all entries are function libraries and mappings in libraries
212       // refer to unique ops.
213       DenseSet<StringAttr> key;
214       for (auto it : arr) {
215         if (!it.isa<SymbolRefAttr>())
216           return op->emitError(
217               "only SymbolRefAttr allowed in shape.lib attribute array");
218 
219         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
220             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
221         if (!shapeFnLib)
222           return op->emitError()
223                  << it << " does not refer to FunctionLibraryOp";
224         for (auto mapping : shapeFnLib.getMapping()) {
225           if (!key.insert(mapping.getName()).second) {
226             return op->emitError("only one op to shape mapping allowed, found "
227                                  "multiple for `")
228                    << mapping.getName() << "`";
229           }
230         }
231       }
232       return success();
233     }
234 
235     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
236                          "allowed as shape.lib attribute");
237   }
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // AnyOp
243 //===----------------------------------------------------------------------===//
244 
245 // TODO: Canonicalization should be implemented for shapes that can be
246 // determined through mixtures of the known dimensions of the inputs.
247 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
248   // Only the last operand is checked because AnyOp is commutative.
249   if (operands.back())
250     return operands.back();
251 
252   return nullptr;
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // AssumingOp
257 //===----------------------------------------------------------------------===//
258 
/// Parses `shape.assuming`: witness operand, optional `-> (types)` result
/// list, the body region, and an optional attribute dict. A missing
/// terminator in the region is added implicitly.
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The witness operand is always of !shape.witness type.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
284 
285 void AssumingOp::print(OpAsmPrinter &p) {
286   bool yieldsResults = !getResults().empty();
287 
288   p << " " << getWitness();
289   if (yieldsResults)
290     p << " -> (" << getResultTypes() << ")";
291   p << ' ';
292   p.printRegion(getDoRegion(),
293                 /*printEntryBlockArgs=*/false,
294                 /*printBlockTerminators=*/yieldsResults);
295   p.printOptionalAttrDict((*this)->getAttrs());
296 }
297 
298 namespace {
299 // Removes AssumingOp with a passing witness and inlines the region.
300 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
301   using OpRewritePattern<AssumingOp>::OpRewritePattern;
302 
303   LogicalResult matchAndRewrite(AssumingOp op,
304                                 PatternRewriter &rewriter) const override {
305     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
306     if (!witness || !witness.getPassingAttr())
307       return failure();
308 
309     AssumingOp::inlineRegionIntoParent(op, rewriter);
310     return success();
311   }
312 };
313 
// Shrinks an `assuming` op by dropping results that have no uses, rebuilding
// the op with only the yield operands that feed used results.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results are replaced with null, which is legal since they
    // have no uses; used ones consume the new op's results in order.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
359 } // namespace
360 
361 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
362                                              MLIRContext *context) {
363   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
364 }
365 
366 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
367 void AssumingOp::getSuccessorRegions(
368     Optional<unsigned> index, ArrayRef<Attribute> operands,
369     SmallVectorImpl<RegionSuccessor> &regions) {
370   // AssumingOp has unconditional control flow into the region and back to the
371   // parent, so return the correct RegionSuccessor purely based on the index
372   // being None or 0.
373   if (index.hasValue()) {
374     regions.push_back(RegionSuccessor(getResults()));
375     return;
376   }
377 
378   regions.push_back(RegionSuccessor(&getDoRegion()));
379 }
380 
/// Inlines the body of `op` into its parent block, replacing the op's results
/// with the yielded values. The parent block is split at the insertion point
/// so the body lands between the code before and after the op.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // Note: the yield's operands must be captured (via replaceOp) before the
  // yield itself is erased.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
400 
401 void AssumingOp::build(
402     OpBuilder &builder, OperationState &result, Value witness,
403     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
404 
405   result.addOperands(witness);
406   Region *bodyRegion = result.addRegion();
407   bodyRegion->push_back(new Block);
408   Block &bodyBlock = bodyRegion->front();
409 
410   // Build body.
411   OpBuilder::InsertionGuard guard(builder);
412   builder.setInsertionPointToStart(&bodyBlock);
413   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
414   builder.create<AssumingYieldOp>(result.location, yieldValues);
415 
416   SmallVector<Type, 2> assumingTypes;
417   for (Value v : yieldValues)
418     assumingTypes.push_back(v.getType());
419   result.addTypes(assumingTypes);
420 }
421 
// Region/result type agreement is delegated to the interface verifier.
LogicalResult AssumingOp::verify() {
  return RegionBranchOpInterface::verifyTypes(*this);
}
425 
426 //===----------------------------------------------------------------------===//
427 // AddOp
428 //===----------------------------------------------------------------------===//
429 
430 LogicalResult mlir::shape::AddOp::inferReturnTypes(
431     MLIRContext *context, Optional<Location> location, ValueRange operands,
432     DictionaryAttr attributes, RegionRange regions,
433     SmallVectorImpl<Type> &inferredReturnTypes) {
434   if (operands[0].getType().isa<SizeType>() ||
435       operands[1].getType().isa<SizeType>())
436     inferredReturnTypes.assign({SizeType::get(context)});
437   else
438     inferredReturnTypes.assign({IndexType::get(context)});
439   return success();
440 }
441 
/// Two result type lists are compatible when each holds exactly one type
/// that is either `size` or `index`.
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
446 
447 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
448   // add(x, 0) -> x
449   if (matchPattern(getRhs(), m_Zero()))
450     return getLhs();
451 
452   return constFoldBinaryOp<IntegerAttr>(
453       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
454 }
455 
// The result must be `size` if any operand can carry an error value.
LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
457 
458 //===----------------------------------------------------------------------===//
459 // AssumingAllOp
460 //===----------------------------------------------------------------------===//
461 
462 namespace {
463 
464 // Merge multiple `shape.assuming_all` operations together.
465 //
466 //   %0 = shape.assuming_all %w0, %w1
467 //   %1 = shape.assuming_all %w2, %0
468 //
469 // to:
470 //
471 //   %0 = shape.assuming_all %w0, %w2, %w2
472 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
473   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
474 
475   LogicalResult matchAndRewrite(AssumingAllOp op,
476                                 PatternRewriter &rewriter) const override {
477     SmallVector<Value> operands;
478 
479     for (Value operand : op.getInputs()) {
480       if (auto assume_all = operand.getDefiningOp<AssumingAllOp>())
481         operands.append(assume_all.operand_begin(), assume_all->operand_end());
482       else
483         operands.push_back(operand);
484     }
485 
486     // We didn't find any other `assuming_all` ops to merge with.
487     if (operands.size() == op.getNumOperands())
488       return failure();
489 
490     // Replace with a new `assuming_all` operation with merged constraints.
491     rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
492     return success();
493   }
494 };
495 
496 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
497 // are subsumed by others.
498 //
499 //   %0 = shape.cstr_broadcastable %shape0, %shape1
500 //   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
501 //
502 //   %2 = shape.cstr_broadcastable %shape3, %shape4
503 //   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
504 //
505 //   %4 = shape.assuming_all %0, %1, %2, %3
506 //
507 // to:
508 //
509 //   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
510 //   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
511 //   %2 = shape.assuming_all %0, %1
512 //
513 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
514 // shapes [0, 1] are broadcastable too, and can be removed from the list of
515 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
516 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapes_set));
    }

    // Sort by the number of shape operands (larger to smaller).
    // NOTE(review): the comparator takes the pairs by value, copying each
    // DenseSet per comparison — consider `const auto &` parameters.
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cst_broadcastable` operations with largest number of
    // shape operands, and remove redundant `cst_broadcastable` operations. We
    // do this until we find a set of `cst_broadcastable` operations with
    // non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> marked_for_erase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      // A constraint is redundant if its shape set is a subset of an earlier
      // (larger) constraint's shape set.
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        marked_for_erase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (marked_for_erase.empty())
      return failure();

    // Collect non-overlapping `cst_broadcastable` constraints.
    SmallVector<Value> unique_constraints;
    for (auto &shape : shapes)
      unique_constraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : marked_for_erase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};
588 
589 struct AssumingAllToCstrEqCanonicalization
590     : public OpRewritePattern<AssumingAllOp> {
591   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
592 
593   LogicalResult matchAndRewrite(AssumingAllOp op,
594                                 PatternRewriter &rewriter) const override {
595     SmallVector<Value, 8> shapes;
596     for (Value w : op.getInputs()) {
597       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
598       if (!cstrEqOp)
599         return failure();
600       bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
601         return llvm::is_contained(shapes, s);
602       });
603       if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
604         return failure();
605       shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
606     }
607     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
608     return success();
609   }
610 };
611 
612 template <typename OpTy>
613 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
614   using OpRewritePattern<OpTy>::OpRewritePattern;
615 
616   LogicalResult matchAndRewrite(OpTy op,
617                                 PatternRewriter &rewriter) const override {
618     // Find unique operands.
619     SetVector<Value> unique(op.operand_begin(), op.operand_end());
620 
621     // Reduce op to equivalent with unique operands.
622     if (unique.size() < op.getNumOperands()) {
623       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
624                                         unique.takeVector(), op->getAttrs());
625       return success();
626     }
627 
628     return failure();
629   }
630 };
631 } // namespace
632 
633 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
634                                                 MLIRContext *context) {
635   patterns
636       .add<MergeAssumingAllOps, AssumingAllOneOp,
637            AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
638            RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
639 }
640 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE: this mutates the op in place during folding by dropping the
    // already-handled constant operand.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
661 
662 LogicalResult AssumingAllOp::verify() {
663   // Ensure that AssumingAllOp contains at least one operand
664   if (getNumOperands() == 0)
665     return emitOpError("no operands specified");
666 
667   return success();
668 }
669 
/// Convenience builder: the result is always a single !shape.witness value.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
674 
675 //===----------------------------------------------------------------------===//
676 // BroadcastOp
677 //===----------------------------------------------------------------------===//
678 
679 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
680   if (getShapes().size() == 1) {
681     // Otherwise, we need a cast which would be a canonicalization, not folding.
682     if (getShapes().front().getType() != getType())
683       return nullptr;
684     return getShapes().front();
685   }
686 
687   // TODO: Support folding with more than 2 input shapes
688   if (getShapes().size() > 2)
689     return nullptr;
690 
691   if (!operands[0] || !operands[1])
692     return nullptr;
693   auto lhsShape = llvm::to_vector<6>(
694       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
695   auto rhsShape = llvm::to_vector<6>(
696       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
697   SmallVector<int64_t, 6> resultShape;
698 
699   // If the shapes are not compatible, we can't fold it.
700   // TODO: Fold to an "error".
701   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
702     return nullptr;
703 
704   Builder builder(getContext());
705   return builder.getIndexTensorAttr(resultShape);
706 }
707 
// The result must be `shape` if any operand can carry an error value.
LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
711 
712 namespace {
713 template <typename OpTy>
714 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
715   using OpRewritePattern<OpTy>::OpRewritePattern;
716 
717   LogicalResult matchAndRewrite(OpTy op,
718                                 PatternRewriter &rewriter) const override {
719     auto isPotentiallyNonEmptyShape = [](Value shape) {
720       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
721         if (extentTensorTy.getDimSize(0) == 0)
722           return false;
723       }
724       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
725         if (constShape.getShape().empty())
726           return false;
727       }
728       return true;
729     };
730     auto newOperands = llvm::to_vector<8>(
731         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
732 
733     // Reduce op to equivalent without empty shape operands.
734     if (newOperands.size() < op.getNumOperands()) {
735       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
736                                         op->getAttrs());
737       return success();
738     }
739 
740     return failure();
741   }
742 };
743 
744 struct BroadcastForwardSingleOperandPattern
745     : public OpRewritePattern<BroadcastOp> {
746   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
747 
748   LogicalResult matchAndRewrite(BroadcastOp op,
749                                 PatternRewriter &rewriter) const override {
750     if (op.getNumOperands() != 1)
751       return failure();
752     Value replacement = op.getShapes().front();
753 
754     // Insert cast if needed.
755     if (replacement.getType() != op.getType()) {
756       auto loc = op.getLoc();
757       if (op.getType().isa<ShapeType>()) {
758         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
759       } else {
760         assert(!op.getType().isa<ShapeType>() &&
761                !replacement.getType().isa<ShapeType>() &&
762                "expect extent tensor cast");
763         replacement =
764             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
765       }
766     }
767 
768     rewriter.replaceOp(op, replacement);
769     return success();
770   }
771 };
772 
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  // Partially folds `shape.broadcast`: all constant (`shape.const_shape`)
  // operands that broadcast together are combined into a single constant
  // operand, which is appended after the remaining non-constant operands.
  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Running broadcast of all foldable constant operands seen so far.
    SmallVector<int64_t, 8> foldedConstantShape;
    // Operands that could not be folded (non-constant, or constants that do
    // not broadcast with the accumulated shape).
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Accumulate this constant into the running broadcast. On broadcast
        // incompatibility, fall through and keep the operand as-is.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
810 
811 template <typename OpTy>
812 struct CanonicalizeCastExtentTensorOperandsPattern
813     : public OpRewritePattern<OpTy> {
814   using OpRewritePattern<OpTy>::OpRewritePattern;
815 
816   LogicalResult matchAndRewrite(OpTy op,
817                                 PatternRewriter &rewriter) const override {
818     // Canonicalize operands.
819     bool anyChange = false;
820     auto canonicalizeOperand = [&](Value operand) {
821       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
822         // Only eliminate the cast if it holds no shape information.
823         bool isInformationLoosingCast =
824             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
825         if (isInformationLoosingCast) {
826           anyChange = true;
827           return castOp.source();
828         }
829       }
830       return operand;
831     };
832     auto newOperands = llvm::to_vector<8>(
833         llvm::map_range(op.getOperands(), canonicalizeOperand));
834 
835     // Rewrite op if any change required.
836     if (!anyChange)
837       return failure();
838     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
839     return success();
840   }
841 };
842 
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  // Replaces a broadcast whose result is a dynamically sized extent tensor by
  // one with a statically sized result when all operand ranks are known,
  // inserting a `tensor.cast` back to the original type for existing users.
  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        // Broadcasting yields the maximum rank over all operands.
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
873 } // namespace
874 
// Registers the broadcast canonicalizations defined above plus the shared
// duplicate/empty-operand cleanups.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
884 
885 //===----------------------------------------------------------------------===//
886 // ConcatOp
887 //===----------------------------------------------------------------------===//
888 
889 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
890   if (!operands[0] || !operands[1])
891     return nullptr;
892   auto lhsShape = llvm::to_vector<6>(
893       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
894   auto rhsShape = llvm::to_vector<6>(
895       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
896   SmallVector<int64_t, 6> resultShape;
897   resultShape.append(lhsShape.begin(), lhsShape.end());
898   resultShape.append(rhsShape.begin(), rhsShape.end());
899   Builder builder(getContext());
900   return builder.getIndexTensorAttr(resultShape);
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // ConstShapeOp
905 //===----------------------------------------------------------------------===//
906 
// Prints as e.g. `shape.const_shape [1, 2, 3] : !shape.shape`.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  // The extents are printed with custom bracket syntax below, so elide the
  // backing "shape" attribute from the generic attribute dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
915 
// Parses `[1, 2, 3] : <type>` (after optional attributes) and stores the
// extents as an index-tensor "shape" attribute.
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  // Re-materialize the extents as the canonical index-tensor attribute.
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
944 
945 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
946 
// Registers the declarative TensorCastConstShape pattern (generated from
// ShapeCanonicalization.inc).
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}
951 
952 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
953     MLIRContext *context, Optional<Location> location, ValueRange operands,
954     DictionaryAttr attributes, RegionRange regions,
955     SmallVectorImpl<Type> &inferredReturnTypes) {
956   Builder b(context);
957   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
958   if (!shape)
959     return emitOptionalError(location, "missing shape attribute");
960   inferredReturnTypes.assign({RankedTensorType::get(
961       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
962   return success();
963 }
964 
965 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
966                                                         TypeRange r) {
967   if (l.size() != 1 || r.size() != 1)
968     return false;
969 
970   Type lhs = l.front();
971   Type rhs = r.front();
972 
973   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
974     // Shape type is compatible with all other valid return types.
975     return true;
976   return lhs == rhs;
977 }
978 
979 //===----------------------------------------------------------------------===//
980 // CstrBroadcastableOp
981 //===----------------------------------------------------------------------===//
982 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
993 
994 // Return true if there is exactly one attribute not representing a scalar
995 // broadcast.
996 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
997   bool nonScalarSeen = false;
998   for (Attribute a : attributes) {
999     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
1000       if (nonScalarSeen)
1001         return false;
1002       nonScalarSeen = true;
1003     }
1004   }
1005   return true;
1006 }
1007 
// Folds to a passing witness when broadcastability can be established
// statically: (1) all but one operand are scalar, (2) all operand shapes are
// constant and broadcastable, or (3) the operand shapes can be recovered from
// `shape_of` producers and are broadcastable.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Try folding from constant operand attributes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
1042 
1043 LogicalResult CstrBroadcastableOp::verify() {
1044   // Ensure that CstrBroadcastableOp contains at least two operands
1045   if (getNumOperands() < 2)
1046     return emitOpError("required at least 2 input shapes");
1047   return success();
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // CstrEqOp
1052 //===----------------------------------------------------------------------===//
1053 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness (declarative pattern from
  // ShapeCanonicalization.inc).
  patterns.add<CstrEqEqOps>(context);
}
1059 
1060 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
1061   if (llvm::all_of(operands,
1062                    [&](Attribute a) { return a && a == operands[0]; }))
1063     return BoolAttr::get(getContext(), true);
1064 
1065   // Because a failing witness result here represents an eventual assertion
1066   // failure, we do not try to replace it with a constant witness. Similarly, we
1067   // cannot if there are any non-const inputs.
1068   return nullptr;
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // ConstSizeOp
1073 //===----------------------------------------------------------------------===//
1074 
// Convenience builder that wraps the raw int64_t in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
1079 
1080 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
1081 
1082 void ConstSizeOp::getAsmResultNames(
1083     llvm::function_ref<void(Value, StringRef)> setNameFn) {
1084   SmallString<4> buffer;
1085   llvm::raw_svector_ostream os(buffer);
1086   os << "c" << getValue();
1087   setNameFn(getResult(), os.str());
1088 }
1089 
1090 //===----------------------------------------------------------------------===//
1091 // ConstWitnessOp
1092 //===----------------------------------------------------------------------===//
1093 
// A constant witness folds to its statically known passing/failing value.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
1097 
1098 //===----------------------------------------------------------------------===//
1099 // CstrRequireOp
1100 //===----------------------------------------------------------------------===//
1101 
// cstr_require folds to its boolean operand when that operand is constant
// (a null attribute here means "not constant" and prevents folding).
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
1105 
1106 //===----------------------------------------------------------------------===//
1107 // DivOp
1108 //===----------------------------------------------------------------------===//
1109 
1110 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1111   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1112   if (!lhs)
1113     return nullptr;
1114   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1115   if (!rhs)
1116     return nullptr;
1117 
1118   // Division in APInt does not follow floor(lhs, rhs) when the result is
1119   // negative. Rather, APInt rounds toward zero.
1120   APInt quotient, remainder;
1121   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1122   if (quotient.isNegative() && !remainder.isNullValue()) {
1123     quotient -= 1;
1124   }
1125 
1126   Type indexTy = IndexType::get(getContext());
1127   return IntegerAttr::get(indexTy, quotient);
1128 }
1129 
1130 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1131     MLIRContext *context, Optional<Location> location, ValueRange operands,
1132     DictionaryAttr attributes, RegionRange regions,
1133     SmallVectorImpl<Type> &inferredReturnTypes) {
1134   if (operands[0].getType().isa<SizeType>() ||
1135       operands[1].getType().isa<SizeType>())
1136     inferredReturnTypes.assign({SizeType::get(context)});
1137   else
1138     inferredReturnTypes.assign({IndexType::get(context)});
1139   return success();
1140 }
1141 
// Declared vs. inferred result types are compatible when each side is either
// `shape.size` or `index`.
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1146 
1147 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // ShapeEqOp
1151 //===----------------------------------------------------------------------===//
1152 
1153 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1154   bool allSame = true;
1155   if (!operands.empty() && !operands[0])
1156     return {};
1157   for (Attribute operand : operands.drop_front(1)) {
1158     if (!operand)
1159       return {};
1160     allSame = allSame && operand == operands[0];
1161   }
1162   return BoolAttr::get(getContext(), allSame);
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // IndexToSizeOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1170   // Constant values of both types, `shape.size` and `index`, are represented as
1171   // `IntegerAttr`s which makes constant folding simple.
1172   if (Attribute arg = operands[0])
1173     return arg;
1174   return {};
1175 }
1176 
// Registers the declarative size->index->size round-trip elimination pattern
// (from ShapeCanonicalization.inc).
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1181 
1182 //===----------------------------------------------------------------------===//
1183 // FromExtentsOp
1184 //===----------------------------------------------------------------------===//
1185 
1186 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1187   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1188     return nullptr;
1189   SmallVector<int64_t, 6> extents;
1190   for (auto attr : operands)
1191     extents.push_back(attr.cast<IntegerAttr>().getInt());
1192   Builder builder(getContext());
1193   return builder.getIndexTensorAttr(extents);
1194 }
1195 
1196 //===----------------------------------------------------------------------===//
1197 // FunctionLibraryOp
1198 //===----------------------------------------------------------------------===//
1199 
// Builds a function library with the given symbol name; the body region and
// the "mapping" attribute are expected to be populated by the caller/parser.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1205 
// Looks up the shape function registered for `op`'s operation name in the
// library's "mapping" dictionary. Returns nullptr when there is no entry or
// the symbol cannot be resolved.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1214 
// Parses `shape.function_library @name attributes {...} { ... } mapping {...}`.
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // The region holding the shape functions.
  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // The trailing `mapping` dictionary associates op names with the shape
  // functions defined in the body.
  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1240 
// Inverse of FunctionLibraryOp::parse: symbol name, attribute dict, body
// region, then the trailing `mapping` dictionary.
void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  // Symbol name and mapping are printed with dedicated syntax, so elide them
  // from the generic attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}
1252 
1253 //===----------------------------------------------------------------------===//
1254 // GetExtentOp
1255 //===----------------------------------------------------------------------===//
1256 
1257 Optional<int64_t> GetExtentOp::getConstantDim() {
1258   if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1259     return constSizeOp.getValue().getLimitedValue();
1260   if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1261     return constantOp.getValue().cast<IntegerAttr>().getInt();
1262   return llvm::None;
1263 }
1264 
1265 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1266   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1267   if (!elements)
1268     return nullptr;
1269   Optional<int64_t> dim = getConstantDim();
1270   if (!dim.hasValue())
1271     return nullptr;
1272   if (dim.getValue() >= elements.getNumElements())
1273     return nullptr;
1274   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1275 }
1276 
1277 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1278                         int64_t dim) {
1279   auto loc = result.location;
1280   auto dimAttr = builder.getIndexAttr(dim);
1281   if (shape.getType().isa<ShapeType>()) {
1282     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1283     build(builder, result, builder.getType<SizeType>(), shape, dim);
1284   } else {
1285     Value dim =
1286         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1287     build(builder, result, builder.getIndexType(), shape, dim);
1288   }
1289 }
1290 
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Inference always yields `index`; builders that want a `shape.size` result
  // set the type explicitly (see the custom build above).
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1298 
// Declared vs. inferred result types are compatible when each side is either
// `shape.size` or `index`.
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1304 
1305 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // IsBroadcastableOp
1309 //===----------------------------------------------------------------------===//
1310 
// Duplicate operands do not change broadcastability, so they can be removed.
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1315 
1316 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1317   // Can always broadcast fewer than two shapes.
1318   if (operands.size() < 2) {
1319     return BoolAttr::get(getContext(), true);
1320   }
1321 
1322   return nullptr;
1323 }
1324 
1325 //===----------------------------------------------------------------------===//
1326 // MeetOp
1327 //===----------------------------------------------------------------------===//
1328 
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The meet result carries the type of the first operand.
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
1336 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  // NOTE(review): after the `lhs != rhs` bail-out above, the checks below only
  // run with lhs == rhs and therefore always succeed, so this function
  // effectively reduces to an equality check. Confirm whether compatible but
  // unequal shaped types were meant to be accepted here.
  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1356 
1357 //===----------------------------------------------------------------------===//
1358 // RankOp
1359 //===----------------------------------------------------------------------===//
1360 
1361 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1362   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1363   if (!shape)
1364     return {};
1365   int64_t rank = shape.getNumElements();
1366   Builder builder(getContext());
1367   return builder.getIndexAttr(rank);
1368 }
1369 
1370 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1371 /// Constant folding fails in cases where only the rank is constant, not the
1372 /// shape itself.
1373 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1374 ///
1375 /// Example:
1376 ///
1377 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1378 /// %rank = shape.rank %shape
1379 ///
1380 /// becomes
1381 ///
1382 /// %rank = shape.const_size 3
1383 
1384 namespace {
1385 struct RankShapeOfCanonicalizationPattern
1386     : public OpRewritePattern<shape::RankOp> {
1387   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1388 
1389   LogicalResult matchAndRewrite(shape::RankOp op,
1390                                 PatternRewriter &rewriter) const override {
1391     auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1392     if (!shapeOfOp)
1393       return failure();
1394     auto rankedTensorType =
1395         shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1396     if (!rankedTensorType)
1397       return failure();
1398     int64_t rank = rankedTensorType.getRank();
1399     if (op.getType().isa<IndexType>()) {
1400       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1401                                                           rank);
1402     } else if (op.getType().isa<shape::SizeType>()) {
1403       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1404     } else {
1405       return failure();
1406     }
1407     return success();
1408   }
1409 };
1410 } // namespace
1411 
// Registers the rank-of-shape-of pattern documented above.
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1416 
1417 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1418     MLIRContext *context, Optional<Location> location, ValueRange operands,
1419     DictionaryAttr attributes, RegionRange regions,
1420     SmallVectorImpl<Type> &inferredReturnTypes) {
1421   if (operands[0].getType().isa<ShapeType>())
1422     inferredReturnTypes.assign({SizeType::get(context)});
1423   else
1424     inferredReturnTypes.assign({IndexType::get(context)});
1425   return success();
1426 }
1427 
// Declared vs. inferred result types are compatible when each side is either
// `shape.size` or `index`.
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1432 
1433 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // NumElementsOp
1437 //===----------------------------------------------------------------------===//
1438 
1439 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1440 
1441   // Fold only when argument constant.
1442   Attribute shape = operands[0];
1443   if (!shape)
1444     return {};
1445 
1446   APInt product(64, 1);
1447   for (auto value : shape.cast<DenseIntElementsAttr>())
1448     product *= value;
1449   Builder builder(getContext());
1450   return builder.getIndexAttr(product.getLimitedValue());
1451 }
1452 
1453 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1454     MLIRContext *context, Optional<Location> location, ValueRange operands,
1455     DictionaryAttr attributes, RegionRange regions,
1456     SmallVectorImpl<Type> &inferredReturnTypes) {
1457   if (operands[0].getType().isa<ShapeType>())
1458     inferredReturnTypes.assign({SizeType::get(context)});
1459   else
1460     inferredReturnTypes.assign({IndexType::get(context)});
1461   return success();
1462 }
1463 
// Declared vs. inferred result types are compatible when each side is either
// `shape.size` or `index`.
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1469 
// Delegates to the shared size-or-index operand/result check.
LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}
1473 
1474 //===----------------------------------------------------------------------===//
1475 // MaxOp
1476 //===----------------------------------------------------------------------===//
1477 
1478 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1479   // If operands are equal, just propagate one.
1480   if (getLhs() == getRhs())
1481     return getLhs();
1482   return nullptr;
1483 }
1484 
1485 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1486     MLIRContext *context, Optional<Location> location, ValueRange operands,
1487     DictionaryAttr attributes, RegionRange regions,
1488     SmallVectorImpl<Type> &inferredReturnTypes) {
1489   if (operands[0].getType() == operands[1].getType())
1490     inferredReturnTypes.assign({operands[0].getType()});
1491   else
1492     inferredReturnTypes.assign({SizeType::get(context)});
1493   return success();
1494 }
1495 
1496 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1497   if (l.size() != 1 || r.size() != 1)
1498     return false;
1499   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1500     return true;
1501   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1502     return true;
1503   return false;
1504 }
1505 
1506 //===----------------------------------------------------------------------===//
1507 // MinOp
1508 //===----------------------------------------------------------------------===//
1509 
1510 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1511   // If operands are equal, just propagate one.
1512   if (getLhs() == getRhs())
1513     return getLhs();
1514   return nullptr;
1515 }
1516 
1517 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1518     MLIRContext *context, Optional<Location> location, ValueRange operands,
1519     DictionaryAttr attributes, RegionRange regions,
1520     SmallVectorImpl<Type> &inferredReturnTypes) {
1521   if (operands[0].getType() == operands[1].getType())
1522     inferredReturnTypes.assign({operands[0].getType()});
1523   else
1524     inferredReturnTypes.assign({SizeType::get(context)});
1525   return success();
1526 }
1527 
1528 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1529   if (l.size() != 1 || r.size() != 1)
1530     return false;
1531   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1532     return true;
1533   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1534     return true;
1535   return false;
1536 }
1537 
1538 //===----------------------------------------------------------------------===//
1539 // MulOp
1540 //===----------------------------------------------------------------------===//
1541 
1542 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1543   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1544   if (!lhs)
1545     return nullptr;
1546   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1547   if (!rhs)
1548     return nullptr;
1549   APInt folded = lhs.getValue() * rhs.getValue();
1550   Type indexTy = IndexType::get(getContext());
1551   return IntegerAttr::get(indexTy, folded);
1552 }
1553 
1554 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1555     MLIRContext *context, Optional<Location> location, ValueRange operands,
1556     DictionaryAttr attributes, RegionRange regions,
1557     SmallVectorImpl<Type> &inferredReturnTypes) {
1558   if (operands[0].getType().isa<SizeType>() ||
1559       operands[1].getType().isa<SizeType>())
1560     inferredReturnTypes.assign({SizeType::get(context)});
1561   else
1562     inferredReturnTypes.assign({IndexType::get(context)});
1563   return success();
1564 }
1565 
// Declared vs. inferred result types are compatible when each side is either
// `shape.size` or `index`.
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1570 
1571 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1572 
1573 //===----------------------------------------------------------------------===//
1574 // ShapeOfOp
1575 //===----------------------------------------------------------------------===//
1576 
1577 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1578   auto type = getOperand().getType().dyn_cast<ShapedType>();
1579   if (!type || !type.hasStaticShape())
1580     return nullptr;
1581   Builder builder(getContext());
1582   return builder.getIndexTensorAttr(type.getShape());
1583 }
1584 
1585 namespace {
1586 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1587   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1588 
1589   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1590                                 PatternRewriter &rewriter) const override {
1591     if (!op.getArg().getType().isa<ShapedType>())
1592       return failure();
1593     if (op.getType().isa<ShapedType>())
1594       return failure();
1595 
1596     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
1597                                                   op.getArg());
1598     return success();
1599   }
1600 };
1601 
1602 // Canonicalize
1603 // ```
1604 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1605 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1606 // ```
1607 // to
1608 // ```
1609 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1610 // ```
1611 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1612   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1613 
1614   LogicalResult matchAndRewrite(tensor::CastOp op,
1615                                 PatternRewriter &rewriter) const override {
1616     auto ty = op.getType().dyn_cast<RankedTensorType>();
1617     if (!ty || ty.getRank() != 1)
1618       return failure();
1619 
1620     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1621     if (!shapeOfOp)
1622       return failure();
1623 
1624     // Argument type must be ranked and must not conflict.
1625     auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1626     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1627       return failure();
1628 
1629     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1630     return success();
1631   }
1632 };
1633 } // namespace
1634 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // ExtractFromShapeOfExtentTensor is a declarative pattern generated into
  // ShapeCanonicalization.inc; the other two are defined above.
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1640 
1641 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1642     MLIRContext *context, Optional<Location> location, ValueRange operands,
1643     DictionaryAttr attributes, RegionRange regions,
1644     SmallVectorImpl<Type> &inferredReturnTypes) {
1645   if (operands[0].getType().isa<ValueShapeType>())
1646     inferredReturnTypes.assign({ShapeType::get(context)});
1647   else {
1648     auto shapedTy = operands[0].getType().cast<ShapedType>();
1649     int64_t rank =
1650         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1651     Type indexTy = IndexType::get(context);
1652     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1653     inferredReturnTypes.assign({extentTensorTy});
1654   }
1655   return success();
1656 }
1657 
1658 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1659   if (l.size() != 1 || r.size() != 1)
1660     return false;
1661   if (l == r)
1662     return true;
1663 
1664   Type lhs = l.front();
1665   Type rhs = r.front();
1666 
1667   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1668     return false;
1669 
1670   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1671     // Shape type is compatible with all other valid return types.
1672     return true;
1673 
1674   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1675     return true;
1676   return false;
1677 }
1678 
// Verification is shared with the other ops in this dialect that produce a
// `!shape.shape` or extent-tensor result.
LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
1682 
1683 //===----------------------------------------------------------------------===//
1684 // SizeToIndexOp
1685 //===----------------------------------------------------------------------===//
1686 
1687 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1688   // Constant values of both types, `shape.size` and `index`, are represented as
1689   // `IntegerAttr`s which makes constant folding simple.
1690   if (Attribute arg = operands[0])
1691     return arg;
1692   return OpFoldResult();
1693 }
1694 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // IndexToSizeToIndexCanonicalization is a declarative pattern generated
  // into ShapeCanonicalization.inc.
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1699 
1700 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1701   if (inputs.size() != 1 || outputs.size() != 1)
1702     return false;
1703   return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1704 }
1705 
1706 //===----------------------------------------------------------------------===//
1707 // YieldOp
1708 //===----------------------------------------------------------------------===//
1709 
1710 LogicalResult shape::YieldOp::verify() {
1711   auto *parentOp = (*this)->getParentOp();
1712   auto results = parentOp->getResults();
1713   auto operands = getOperands();
1714 
1715   if (parentOp->getNumResults() != getNumOperands())
1716     return emitOpError() << "number of operands does not match number of "
1717                             "results of its parent";
1718   for (auto e : llvm::zip(results, operands))
1719     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1720       return emitOpError() << "types mismatch between yield op and its parent";
1721 
1722   return success();
1723 }
1724 
1725 //===----------------------------------------------------------------------===//
1726 // SplitAtOp
1727 //===----------------------------------------------------------------------===//
1728 
1729 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1730                               SmallVectorImpl<OpFoldResult> &results) {
1731   if (!operands[0] || !operands[1])
1732     return failure();
1733   auto shapeVec = llvm::to_vector<6>(
1734       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1735   auto shape = llvm::makeArrayRef(shapeVec);
1736   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1737   // Verify that the split point is in the correct range.
1738   // TODO: Constant fold to an "error".
1739   int64_t rank = shape.size();
1740   if (!(-rank <= splitPoint && splitPoint <= rank))
1741     return failure();
1742   if (splitPoint < 0)
1743     splitPoint += shape.size();
1744   Builder builder(operands[0].getContext());
1745   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1746   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1747   return success();
1748 }
1749 
1750 //===----------------------------------------------------------------------===//
1751 // ToExtentTensorOp
1752 //===----------------------------------------------------------------------===//
1753 
1754 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1755   if (!operands[0])
1756     return OpFoldResult();
1757   Builder builder(getContext());
1758   auto shape = llvm::to_vector<6>(
1759       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1760   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1761                                     builder.getIndexType());
1762   return DenseIntElementsAttr::get(type, shape);
1763 }
1764 
1765 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1766   if (inputs.size() != 1 || outputs.size() != 1)
1767     return false;
1768   if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1769     if (!inputTensor.getElementType().isa<IndexType>() ||
1770         inputTensor.getRank() != 1)
1771       return false;
1772   } else if (!inputs[0].isa<ShapeType>()) {
1773     return false;
1774   }
1775 
1776   TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1777   return outputTensor && outputTensor.getElementType().isa<IndexType>();
1778 }
1779 
1780 //===----------------------------------------------------------------------===//
1781 // ReduceOp
1782 //===----------------------------------------------------------------------===//
1783 
1784 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1785                      ValueRange initVals) {
1786   result.addOperands(shape);
1787   result.addOperands(initVals);
1788 
1789   Region *bodyRegion = result.addRegion();
1790   bodyRegion->push_back(new Block);
1791   Block &bodyBlock = bodyRegion->front();
1792   bodyBlock.addArgument(builder.getIndexType(), result.location);
1793 
1794   Type elementType;
1795   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1796     elementType = tensorType.getElementType();
1797   else
1798     elementType = SizeType::get(builder.getContext());
1799   bodyBlock.addArgument(elementType, shape.getLoc());
1800 
1801   for (Value initVal : initVals) {
1802     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1803     result.addTypes(initVal.getType());
1804   }
1805 }
1806 
1807 LogicalResult ReduceOp::verify() {
1808   // Verify block arg types.
1809   Block &block = getRegion().front();
1810 
1811   // The block takes index, extent, and aggregated values as arguments.
1812   auto blockArgsCount = getInitVals().size() + 2;
1813   if (block.getNumArguments() != blockArgsCount)
1814     return emitOpError() << "ReduceOp body is expected to have "
1815                          << blockArgsCount << " arguments";
1816 
1817   // The first block argument is the index and must always be of type `index`.
1818   if (!block.getArgument(0).getType().isa<IndexType>())
1819     return emitOpError(
1820         "argument 0 of ReduceOp body is expected to be of IndexType");
1821 
1822   // The second block argument is the extent and must be of type `size` or
1823   // `index`, depending on whether the reduce operation is applied to a shape or
1824   // to an extent tensor.
1825   Type extentTy = block.getArgument(1).getType();
1826   if (getShape().getType().isa<ShapeType>()) {
1827     if (!extentTy.isa<SizeType>())
1828       return emitOpError("argument 1 of ReduceOp body is expected to be of "
1829                          "SizeType if the ReduceOp operates on a ShapeType");
1830   } else {
1831     if (!extentTy.isa<IndexType>())
1832       return emitOpError(
1833           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1834           "ReduceOp operates on an extent tensor");
1835   }
1836 
1837   for (const auto &type : llvm::enumerate(getInitVals()))
1838     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1839       return emitOpError() << "type mismatch between argument "
1840                            << type.index() + 2
1841                            << " of ReduceOp body and initial value "
1842                            << type.index();
1843   return success();
1844 }
1845 
// Parses the custom form printed by ReduceOp::print:
// `(` shape `,` init-vals `)` `:` shape-type (`->` result-types)? region
// attr-dict?
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The first operand is the shape (or extent tensor); the
  // remaining ones are the initial values, typed like the results.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body. Block arguments are spelled inside the region itself.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1875 
// Prints the custom form parsed by ReduceOp::parse:
// `(` shape `,` init-vals `) :` shape-type (`->` result-types)? region
// attr-dict?
void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1884 
1885 #define GET_OP_CLASSES
1886 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1887