1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/SetOperations.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::shape;
31 
32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
33 
34 namespace {
35 #include "ShapeCanonicalization.inc"
36 } // namespace
37 
38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
39   return RankedTensorType::get({rank}, IndexType::get(ctx));
40 }
41 
42 bool shape::isExtentTensorType(Type type) {
43   auto ranked = type.dyn_cast<RankedTensorType>();
44   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
45 }
46 
47 LogicalResult shape::getShapeVec(Value input,
48                                  SmallVectorImpl<int64_t> &shapeValues) {
49   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
50     auto type = inputOp.getArg().getType().cast<ShapedType>();
51     if (!type.hasRank())
52       return failure();
53     llvm::append_range(shapeValues, type.getShape());
54     return success();
55   }
56   DenseIntElementsAttr attr;
57   if (matchPattern(input, m_Constant(&attr))) {
58     llvm::append_range(shapeValues, attr.getValues<int64_t>());
59     return success();
60   }
61   return failure();
62 }
63 
64 static bool isErrorPropagationPossible(TypeRange operandTypes) {
65   return llvm::any_of(operandTypes, [](Type ty) {
66     return ty.isa<SizeType, ShapeType, ValueShapeType>();
67   });
68 }
69 
70 static LogicalResult verifySizeOrIndexOp(Operation *op) {
71   assert(op != nullptr && op->getNumResults() == 1);
72   Type resultTy = op->getResultTypes().front();
73   if (isErrorPropagationPossible(op->getOperandTypes())) {
74     if (!resultTy.isa<SizeType>())
75       return op->emitOpError()
76              << "if at least one of the operands can hold error values then "
77                 "the result must be of type `size` to propagate them";
78   }
79   return success();
80 }
81 
82 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
83   assert(op != nullptr && op->getNumResults() == 1);
84   Type resultTy = op->getResultTypes().front();
85   if (isErrorPropagationPossible(op->getOperandTypes())) {
86     if (!resultTy.isa<ShapeType>())
87       return op->emitOpError()
88              << "if at least one of the operands can hold error values then "
89                 "the result must be of type `shape` to propagate them";
90   }
91   return success();
92 }
93 
/// Returns true if the type range holds exactly one type and that type is one
/// of `Ty...`.
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}
98 
99 template <typename... Ty, typename... ranges>
100 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
101   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // InlinerInterface
106 //===----------------------------------------------------------------------===//
107 
namespace {
/// This class defines the interface for inlining shape dialect ops. Both
/// hooks unconditionally return true: all shape dialect regions and ops are
/// legal to inline anywhere.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
129 
/// Registers the dialect's ops (from the generated op list), types, and the
/// inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
142 
/// Materializes a single constant operation for the given attribute/type pair
/// (used to reify folding results). Returns null if no constant op can be
/// built for the combination.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Shapes — both `!shape.shape` and extent tensors — become shape.const_shape.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to the arith dialect for plain index/integer constants.
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}
157 
158 /// Parse a type registered to this dialect.
159 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
160   StringRef keyword;
161   if (parser.parseKeyword(&keyword))
162     return Type();
163 
164   if (keyword == "shape")
165     return ShapeType::get(getContext());
166   if (keyword == "size")
167     return SizeType::get(getContext());
168   if (keyword == "value_shape")
169     return ValueShapeType::get(getContext());
170   if (keyword == "witness")
171     return WitnessType::get(getContext());
172 
173   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
174   return Type();
175 }
176 
177 /// Print a type registered to this dialect.
178 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
179   TypeSwitch<Type>(type)
180       .Case<ShapeType>([&](Type) { os << "shape"; })
181       .Case<SizeType>([&](Type) { os << "size"; })
182       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
183       .Case<WitnessType>([&](Type) { os << "witness"; })
184       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
185 }
186 
187 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
188                                                      NamedAttribute attribute) {
189   // Verify shape.lib attribute.
190   if (attribute.getName() == "shape.lib") {
191     if (!op->hasTrait<OpTrait::SymbolTable>())
192       return op->emitError(
193           "shape.lib attribute may only be on op implementing SymbolTable");
194 
195     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
196       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
197       if (!symbol)
198         return op->emitError("shape function library ")
199                << symbolRef << " not found";
200       return isa<shape::FunctionLibraryOp>(symbol)
201                  ? success()
202                  : op->emitError()
203                        << symbolRef << " required to be shape function library";
204     }
205 
206     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
207       // Verify all entries are function libraries and mappings in libraries
208       // refer to unique ops.
209       DenseSet<StringAttr> key;
210       for (auto it : arr) {
211         if (!it.isa<SymbolRefAttr>())
212           return op->emitError(
213               "only SymbolRefAttr allowed in shape.lib attribute array");
214 
215         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
216             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
217         if (!shapeFnLib)
218           return op->emitError()
219                  << it << " does not refer to FunctionLibraryOp";
220         for (auto mapping : shapeFnLib.getMapping()) {
221           if (!key.insert(mapping.getName()).second) {
222             return op->emitError("only one op to shape mapping allowed, found "
223                                  "multiple for `")
224                    << mapping.getName() << "`";
225           }
226         }
227       }
228       return success();
229     }
230 
231     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
232                          "allowed as shape.lib attribute");
233   }
234   return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AnyOp
239 //===----------------------------------------------------------------------===//
240 
241 // TODO: Canonicalization should be implemented for shapes that can be
242 // determined through mixtures of the known dimensions of the inputs.
243 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
244   // Only the last operand is checked because AnyOp is commutative.
245   if (operands.back())
246     return operands.back();
247 
248   return nullptr;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // AssumingOp
253 //===----------------------------------------------------------------------===//
254 
/// Parses `shape.assuming` in the form:
///   shape.assuming %witness [-> (result-types)] { region } [attr-dict]
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // Parse the witness operand; it is always of `!shape.witness` type, so no
  // type is parsed for it.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
280 
/// Prints `shape.assuming`, mirroring the custom form accepted by `parse`.
void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  // Print the result type list only when the op yields values.
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  // Elide the terminator when there are no results, matching the parser's
  // `ensureTerminator` behavior.
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}
293 
294 namespace {
295 // Removes AssumingOp with a passing witness and inlines the region.
296 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
297   using OpRewritePattern<AssumingOp>::OpRewritePattern;
298 
299   LogicalResult matchAndRewrite(AssumingOp op,
300                                 PatternRewriter &rewriter) const override {
301     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
302     if (!witness || !witness.getPassingAttr())
303       return failure();
304 
305     AssumingOp::inlineRegionIntoParent(op, rewriter);
306     return success();
307   }
308 };
309 
// Rewrites an `assuming` op to yield only the results that are actually used,
// dropping unused yields and their corresponding op results.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results are replaced with null, which is valid since they
    // have no remaining uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
355 } // namespace
356 
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // Drop unused results, and inline regions guarded by a passing witness.
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
361 
362 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
363 void AssumingOp::getSuccessorRegions(
364     Optional<unsigned> index, ArrayRef<Attribute> operands,
365     SmallVectorImpl<RegionSuccessor> &regions) {
366   // AssumingOp has unconditional control flow into the region and back to the
367   // parent, so return the correct RegionSuccessor purely based on the index
368   // being None or 0.
369   if (index.hasValue()) {
370     regions.push_back(RegionSuccessor(getResults()));
371     return;
372   }
373 
374   regions.push_back(RegionSuccessor(&getDoRegion()));
375 }
376 
/// Inlines the single-block body of `op` into its parent block at the
/// rewriter's insertion point, replacing the op's results with the operands of
/// the body's terminator.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split off everything after the insertion point so the body can be spliced
  // in between the two halves.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
396 
397 void AssumingOp::build(
398     OpBuilder &builder, OperationState &result, Value witness,
399     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
400 
401   result.addOperands(witness);
402   Region *bodyRegion = result.addRegion();
403   bodyRegion->push_back(new Block);
404   Block &bodyBlock = bodyRegion->front();
405 
406   // Build body.
407   OpBuilder::InsertionGuard guard(builder);
408   builder.setInsertionPointToStart(&bodyBlock);
409   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
410   builder.create<AssumingYieldOp>(result.location, yieldValues);
411 
412   SmallVector<Type, 2> assumingTypes;
413   for (Value v : yieldValues)
414     assumingTypes.push_back(v.getType());
415   result.addTypes(assumingTypes);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // AddOp
420 //===----------------------------------------------------------------------===//
421 
422 LogicalResult mlir::shape::AddOp::inferReturnTypes(
423     MLIRContext *context, Optional<Location> location, ValueRange operands,
424     DictionaryAttr attributes, RegionRange regions,
425     SmallVectorImpl<Type> &inferredReturnTypes) {
426   if (operands[0].getType().isa<SizeType>() ||
427       operands[1].getType().isa<SizeType>())
428     inferredReturnTypes.assign({SizeType::get(context)});
429   else
430     inferredReturnTypes.assign({IndexType::get(context)});
431   return success();
432 }
433 
/// Two inferred/declared result type lists are compatible if each holds a
/// single type drawn from {SizeType, IndexType}.
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
438 
/// Folds `shape.add`: identity on a zero rhs, and constant folding when both
/// operands are constant integers.
OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
}
447 
// Delegates to the shared verifier for size-or-index producing ops.
LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
449 
450 //===----------------------------------------------------------------------===//
451 // AssumingAllOp
452 //===----------------------------------------------------------------------===//
453 
454 namespace {
455 
// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %1 = shape.assuming_all %w2, %w0, %w1
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Flatten the operand list: inline the operands of any input that is
    // itself produced by an `assuming_all`, keeping other witnesses as-is.
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};
487 
488 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
489 // are subsumed by others.
490 //
491 //   %0 = shape.cstr_broadcastable %shape0, %shape1
492 //   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
493 //
494 //   %2 = shape.cstr_broadcastable %shape3, %shape4
495 //   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
496 //
497 //   %4 = shape.assuming_all %0, %1, %2, %3
498 //
499 // to:
500 //
501 //   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
502 //   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
503 //   %2 = shape.assuming_all %0, %1
504 //
505 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
506 // shapes [0, 1] are broadcastable too, and can be removed from the list of
507 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
508 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
509 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
510   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
511 
512   LogicalResult matchAndRewrite(AssumingAllOp op,
513                                 PatternRewriter &rewriter) const override {
514     // Collect all `CstrBroadcastableOp` operands first.
515     SetVector<CstrBroadcastableOp> operands;
516     for (Value operand : op.getInputs()) {
517       // TODO: Apply this optimization if some of the witnesses are not
518       // produced by the `cstr_broadcastable`.
519       auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
520       if (!broadcastable)
521         return failure();
522 
523       operands.insert(broadcastable);
524     }
525 
526     // Skip trivial `assuming_all` operations.
527     if (operands.size() <= 1)
528       return failure();
529 
530     // Collect shapes checked by `cstr_broadcastable` operands.
531     SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
532     for (auto cstr : operands) {
533       DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
534       shapes.emplace_back(cstr, std::move(shapesSet));
535     }
536 
537     // Sort by the number of shape operands (larger to smaller).
538     llvm::sort(shapes, [](auto a, auto b) {
539       return a.first.getNumOperands() > b.first.getNumOperands();
540     });
541 
542     // We start from the `cst_broadcastable` operations with largest number of
543     // shape operands, and remove redundant `cst_broadcastable` operations. We
544     // do this until we find a set of `cst_broadcastable` operations with
545     // non-overlapping constraints.
546     SmallVector<CstrBroadcastableOp> markedForErase;
547 
548     for (unsigned i = 0; i < shapes.size(); ++i) {
549       auto isSubset = [&](auto pair) {
550         return llvm::set_is_subset(pair.second, shapes[i].second);
551       };
552 
553       // Keep redundant `cstr_broadcastable` operations to be erased.
554       auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
555       for (auto *it0 = it; it0 < shapes.end(); ++it0)
556         markedForErase.push_back(it0->first);
557       shapes.erase(it, shapes.end());
558     }
559 
560     // We didn't find any operands that could be removed.
561     if (markedForErase.empty())
562       return failure();
563 
564     // Collect non-overlapping `cst_broadcastable` constraints.
565     SmallVector<Value> uniqueConstraints;
566     for (auto &shape : shapes)
567       uniqueConstraints.push_back(shape.first.getResult());
568 
569     // Replace with a new `assuming_all` operation ...
570     rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
571 
572     // ... and maybe erase `cstr_broadcastable` ops without uses.
573     for (auto &op : markedForErase)
574       if (op->use_empty())
575         rewriter.eraseOp(op);
576 
577     return success();
578   }
579 };
580 
// Rewrites an `assuming_all` whose inputs are all `cstr_eq` witnesses into a
// single `cstr_eq` over the union of the shapes, provided the witnesses form
// one connected equality class.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // The merged `cstr_eq` asserts equality of *all* collected shapes, so
      // each witness must share at least one shape with those collected so
      // far; merging disjoint equality sets would strengthen the constraint.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
603 
604 template <typename OpTy>
605 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
606   using OpRewritePattern<OpTy>::OpRewritePattern;
607 
608   LogicalResult matchAndRewrite(OpTy op,
609                                 PatternRewriter &rewriter) const override {
610     // Find unique operands.
611     SetVector<Value> unique(op.operand_begin(), op.operand_end());
612 
613     // Reduce op to equivalent with unique operands.
614     if (unique.size() < op.getNumOperands()) {
615       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
616                                         unique.takeVector(), op->getAttrs());
617       return success();
618     }
619 
620     return failure();
621   }
622 };
623 } // namespace
624 
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Note: AssumingAllOneOp is not defined in this file; presumably it is a
  // declaratively specified pattern from ShapeCanonicalization.inc (included
  // above).
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
632 
/// Folds `assuming_all` over statically known witnesses. Constant operands
/// are erased from the op as they are processed; folding produces a value
/// only when all operands are statically known.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
653 
/// Verifies that `assuming_all` has at least one witness operand.
LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}
661 
662 //===----------------------------------------------------------------------===//
663 // BroadcastOp
664 //===----------------------------------------------------------------------===//
665 
/// Folds `broadcast`: forwards a single operand of matching type, and folds
/// the two-operand case when both shapes are constant and broadcastable.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both operand shapes must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
694 
// Delegates to the shared verifier for shape-or-extent-tensor producing ops.
LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
698 
699 namespace {
// Drops operands that are statically known to be empty shapes (rank-0 shapes
// as zero-element extent tensors or empty `const_shape` values), since they
// cannot contribute to the result.
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      // Statically zero-sized extent tensor, i.e. a known-empty shape.
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      // Constant shape with no extents.
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};
730 
// Forwards the only operand of a single-operand `broadcast`, inserting a cast
// when the operand and result types differ.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        // Extent tensor operand, `!shape.shape` result.
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        // Both sides are extent tensors of different static sizes.
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
759 
// Folds all constant operands of a `broadcast` into a single trailing
// `const_shape` operand, leaving non-constant operands untouched.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Accumulate constant shapes pairwise; the operand is only dropped
        // if it broadcasts with the accumulator so far.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
797 
798 template <typename OpTy>
799 struct CanonicalizeCastExtentTensorOperandsPattern
800     : public OpRewritePattern<OpTy> {
801   using OpRewritePattern<OpTy>::OpRewritePattern;
802 
803   LogicalResult matchAndRewrite(OpTy op,
804                                 PatternRewriter &rewriter) const override {
805     // Canonicalize operands.
806     bool anyChange = false;
807     auto canonicalizeOperand = [&](Value operand) {
808       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
809         // Only eliminate the cast if it holds no shape information.
810         bool isInformationLoosingCast =
811             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
812         if (isInformationLoosingCast) {
813           anyChange = true;
814           return castOp.source();
815         }
816       }
817       return operand;
818     };
819     auto newOperands = llvm::to_vector<8>(
820         llvm::map_range(op.getOperands(), canonicalizeOperand));
821 
822     // Rewrite op if any change required.
823     if (!anyChange)
824       return failure();
825     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
826     return success();
827   }
828 };
829 
// Concretizes a dynamic extent-tensor result type of `shape.broadcast` when
// all operand ranks are statically known: the result rank is the maximum
// operand rank. The original dynamic type is restored via `tensor.cast`.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
860 } // namespace
861 
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Simplify operands (drop duplicates, empty shapes, useless casts), fold
  // constant operands together, and concretize the result type if possible.
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
871 
872 //===----------------------------------------------------------------------===//
873 // ConcatOp
874 //===----------------------------------------------------------------------===//
875 
876 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
877   if (!operands[0] || !operands[1])
878     return nullptr;
879   auto lhsShape = llvm::to_vector<6>(
880       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
881   auto rhsShape = llvm::to_vector<6>(
882       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
883   SmallVector<int64_t, 6> resultShape;
884   resultShape.append(lhsShape.begin(), lhsShape.end());
885   resultShape.append(rhsShape.begin(), rhsShape.end());
886   Builder builder(getContext());
887   return builder.getIndexTensorAttr(resultShape);
888 }
889 
890 //===----------------------------------------------------------------------===//
891 // ConstShapeOp
892 //===----------------------------------------------------------------------===//
893 
// Prints as `shape.const_shape attr-dict [1, 2, 3] : type`; the `shape`
// attribute is elided from the attr-dict and printed inline as the list.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
902 
// Parses `shape.const_shape attr-dict [e0, e1, ...] : type`, mirroring the
// custom printer above.
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every element of the parsed array must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents canonically as a dense index-tensor "shape" attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
931 
// A constant shape folds to its own "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
933 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Fold `tensor.cast` of a constant shape into the constant itself.
  patterns.add<TensorCastConstShape>(context);
}
938 
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  // Result is a static 1-D extent tensor with one element per extent.
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}
951 
952 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
953                                                         TypeRange r) {
954   if (l.size() != 1 || r.size() != 1)
955     return false;
956 
957   Type lhs = l.front();
958   Type rhs = r.front();
959 
960   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
961     // Shape type is compatible with all other valid return types.
962     return true;
963   return lhs == rhs;
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // CstrBroadcastableOp
968 //===----------------------------------------------------------------------===//
969 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
980 
// Return true if there is at most one attribute not representing a scalar
// broadcast. Unknown (null) attributes are conservatively counted as
// non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
994 
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, fold when every operand is a constant shape and the extents are
  // statically known to broadcast.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
1029 
// A broadcastability constraint over fewer than two shapes is meaningless.
LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}
1036 
1037 //===----------------------------------------------------------------------===//
1038 // CstrEqOp
1039 //===----------------------------------------------------------------------===//
1040 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}
1046 
1047 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
1048   if (llvm::all_of(operands,
1049                    [&](Attribute a) { return a && a == operands[0]; }))
1050     return BoolAttr::get(getContext(), true);
1051 
1052   // Because a failing witness result here represents an eventual assertion
1053   // failure, we do not try to replace it with a constant witness. Similarly, we
1054   // cannot if there are any non-const inputs.
1055   return nullptr;
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // ConstSizeOp
1060 //===----------------------------------------------------------------------===//
1061 
// Convenience builder: wraps the raw integer in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
1066 
// A constant size folds to its own "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
1068 
// Suggests SSA names like %c3 for `shape.const_size 3` in the printed IR.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}
1076 
1077 //===----------------------------------------------------------------------===//
1078 // ConstWitnessOp
1079 //===----------------------------------------------------------------------===//
1080 
// A constant witness folds to its own "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
1084 
1085 //===----------------------------------------------------------------------===//
1086 // CstrRequireOp
1087 //===----------------------------------------------------------------------===//
1088 
// Folds to the (possibly null) constant of the predicate operand: a known
// i1 constant becomes a constant witness, otherwise no fold.
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
1092 
1093 //===----------------------------------------------------------------------===//
1094 // DivOp
1095 //===----------------------------------------------------------------------===//
1096 
// Folds `shape.div` of two constant integers using floor division (rounding
// toward negative infinity).
// NOTE(review): a constant zero divisor is not guarded here — confirm the
// op's verifier or callers rule that case out before folding.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}
1116 
1117 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1118     MLIRContext *context, Optional<Location> location, ValueRange operands,
1119     DictionaryAttr attributes, RegionRange regions,
1120     SmallVectorImpl<Type> &inferredReturnTypes) {
1121   if (operands[0].getType().isa<SizeType>() ||
1122       operands[1].getType().isa<SizeType>())
1123     inferredReturnTypes.assign({SizeType::get(context)});
1124   else
1125     inferredReturnTypes.assign({IndexType::get(context)});
1126   return success();
1127 }
1128 
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1133 
// Delegates to the shared size-or-index operand/result verification.
LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // ShapeEqOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1141   bool allSame = true;
1142   if (!operands.empty() && !operands[0])
1143     return {};
1144   for (Attribute operand : operands.drop_front(1)) {
1145     if (!operand)
1146       return {};
1147     allSame = allSame && operand == operands[0];
1148   }
1149   return BoolAttr::get(getContext(), allSame);
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // IndexToSizeOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1157   // Constant values of both types, `shape.size` and `index`, are represented as
1158   // `IntegerAttr`s which makes constant folding simple.
1159   if (Attribute arg = operands[0])
1160     return arg;
1161   return {};
1162 }
1163 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Cancel out round-trips through `size_to_index` followed by
  // `index_to_size`.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1168 
1169 //===----------------------------------------------------------------------===//
1170 // FromExtentsOp
1171 //===----------------------------------------------------------------------===//
1172 
1173 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1174   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1175     return nullptr;
1176   SmallVector<int64_t, 6> extents;
1177   for (auto attr : operands)
1178     extents.push_back(attr.cast<IntegerAttr>().getInt());
1179   Builder builder(getContext());
1180   return builder.getIndexTensorAttr(extents);
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // FunctionLibraryOp
1185 //===----------------------------------------------------------------------===//
1186 
// Builds a function library with the given symbol name; the mapping attribute
// and body region are populated by the parser or by the caller afterwards.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1192 
// Looks up the shape function registered for `op`'s name in the library's
// `mapping` attribute; returns null when no entry or no such symbol exists.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1201 
// Parses: `shape.function_library @name attr-dict-with-keyword region
// mapping { ... }` — the symbol name, optional attributes, the body holding
// the shape functions, and the trailing op-to-function mapping dictionary.
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The mapping dictionary is stored as the "mapping" attribute.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1227 
// Prints the inverse of the custom parse above; the symbol name and mapping
// attribute are elided from the attr-dict and printed in dedicated positions.
void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}
1239 
1240 //===----------------------------------------------------------------------===//
1241 // GetExtentOp
1242 //===----------------------------------------------------------------------===//
1243 
// Returns the dimension operand as a constant when it is produced by either
// `shape.const_size` or `arith.constant`; otherwise None.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}
1251 
// Folds `shape.get_extent` when both the shape and the dimension index are
// constant and the index is in range.
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  // Out-of-range access cannot be folded to a value.
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
}
1263 
1264 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1265                         int64_t dim) {
1266   auto loc = result.location;
1267   auto dimAttr = builder.getIndexAttr(dim);
1268   if (shape.getType().isa<ShapeType>()) {
1269     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1270     build(builder, result, builder.getType<SizeType>(), shape, dim);
1271   } else {
1272     Value dim =
1273         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1274     build(builder, result, builder.getIndexType(), shape, dim);
1275   }
1276 }
1277 
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // NOTE(review): always infers `index`; the custom builder above passes
  // `!shape.size` explicitly for shape-typed operands — confirm intended.
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1285 
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1291 
// Delegates to the shared size-or-index operand/result verification.
LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1293 
1294 //===----------------------------------------------------------------------===//
1295 // IsBroadcastableOp
1296 //===----------------------------------------------------------------------===//
1297 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate shape operands never change the broadcastability result.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1302 
1303 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1304   // Can always broadcast fewer than two shapes.
1305   if (operands.size() < 2) {
1306     return BoolAttr::get(getContext(), true);
1307   }
1308 
1309   return nullptr;
1310 }
1311 
1312 //===----------------------------------------------------------------------===//
1313 // MeetOp
1314 //===----------------------------------------------------------------------===//
1315 
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The meet of two values carries the type of the first operand.
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
1323 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // NOTE(review): with single-element ranges, `l == r` above already implies
  // `lhs == rhs`, making everything past the next early-exit unreachable as
  // written — confirm whether element-wise shape compatibility was intended
  // here instead of exact equality.
  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1343 
1344 //===----------------------------------------------------------------------===//
1345 // RankOp
1346 //===----------------------------------------------------------------------===//
1347 
1348 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1349   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1350   if (!shape)
1351     return {};
1352   int64_t rank = shape.getNumElements();
1353   Builder builder(getContext());
1354   return builder.getIndexAttr(rank);
1355 }
1356 
1357 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1358 /// Constant folding fails in cases where only the rank is constant, not the
1359 /// shape itself.
1360 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1361 ///
1362 /// Example:
1363 ///
1364 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1365 /// %rank = shape.rank %shape
1366 ///
1367 /// becomes
1368 ///
1369 /// %rank = shape.const_size 3
1370 
1371 namespace {
// Replaces `shape.rank(shape.shape_of(%t))` with a constant when %t is a
// ranked tensor (see the documentation comment above for an example).
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // The argument must have a statically known rank.
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant in whichever type world the result lives in.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
1397 } // namespace
1398 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Fold rank-of-shape-of for statically ranked tensors.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1403 
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // `!shape.shape` operands produce `!shape.size`; extent tensors produce
  // `index`.
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1414 
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1419 
// Delegates to the shared size-or-index operand/result verification.
LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // NumElementsOp
1424 //===----------------------------------------------------------------------===//
1425 
// Folds `shape.num_elements` of a constant shape to the product of its
// extents.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  // Multiply all extents in a fixed-width 64-bit accumulator.
  // NOTE(review): the product wraps on 64-bit overflow — confirm that is
  // acceptable for very large constant shapes.
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
1439 
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // `!shape.shape` operands produce `!shape.size`; extent tensors produce
  // `index`.
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1450 
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1456 
// Delegates to the shared size-or-index operand/result verification.
LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}
1460 
1461 //===----------------------------------------------------------------------===//
1462 // MaxOp
1463 //===----------------------------------------------------------------------===//
1464 
1465 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1466   // If operands are equal, just propagate one.
1467   if (getLhs() == getRhs())
1468     return getLhs();
1469   return nullptr;
1470 }
1471 
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Matching operand types are propagated; mixed types fall back to
  // `!shape.size`.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1482 
1483 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1484   if (l.size() != 1 || r.size() != 1)
1485     return false;
1486   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1487     return true;
1488   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1489     return true;
1490   return false;
1491 }
1492 
1493 //===----------------------------------------------------------------------===//
1494 // MinOp
1495 //===----------------------------------------------------------------------===//
1496 
1497 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1498   // If operands are equal, just propagate one.
1499   if (getLhs() == getRhs())
1500     return getLhs();
1501   return nullptr;
1502 }
1503 
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Matching operand types are propagated; mixed types fall back to
  // `!shape.size`.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1514 
1515 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1516   if (l.size() != 1 || r.size() != 1)
1517     return false;
1518   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1519     return true;
1520   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1521     return true;
1522   return false;
1523 }
1524 
1525 //===----------------------------------------------------------------------===//
1526 // MulOp
1527 //===----------------------------------------------------------------------===//
1528 
1529 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1530   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1531   if (!lhs)
1532     return nullptr;
1533   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1534   if (!rhs)
1535     return nullptr;
1536   APInt folded = lhs.getValue() * rhs.getValue();
1537   Type indexTy = IndexType::get(getContext());
1538   return IntegerAttr::get(indexTy, folded);
1539 }
1540 
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result is `!shape.size` if either operand is one, otherwise `index`.
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1552 
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1557 
// Delegates to the shared size-or-index operand/result verification.
LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1559 
1560 //===----------------------------------------------------------------------===//
1561 // ShapeOfOp
1562 //===----------------------------------------------------------------------===//
1563 
1564 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1565   auto type = getOperand().getType().dyn_cast<ShapedType>();
1566   if (!type || !type.hasStaticShape())
1567     return nullptr;
1568   Builder builder(getContext());
1569   return builder.getIndexTensorAttr(type.getShape());
1570 }
1571 
1572 namespace {
// Rebuilds `shape.shape_of` over a shaped argument so that type inference can
// assign it an extent-tensor result type instead of `!shape.shape`.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    // Already producing an extent tensor; nothing to do.
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};
1588 
1589 // Canonicalize
1590 // ```
1591 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1592 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1593 // ```
1594 // to
1595 // ```
1596 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1597 // ```
1598 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1599   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1600 
1601   LogicalResult matchAndRewrite(tensor::CastOp op,
1602                                 PatternRewriter &rewriter) const override {
1603     auto ty = op.getType().dyn_cast<RankedTensorType>();
1604     if (!ty || ty.getRank() != 1)
1605       return failure();
1606 
1607     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1608     if (!shapeOfOp)
1609       return failure();
1610 
1611     // Argument type must be ranked and must not conflict.
1612     auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
1613     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1614       return failure();
1615 
1616     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1617     return success();
1618   }
1619 };
1620 } // namespace
1621 
1622 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1623                                             MLIRContext *context) {
1624   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1625                ExtractFromShapeOfExtentTensor>(context);
1626 }
1627 
1628 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1629     MLIRContext *context, Optional<Location> location, ValueRange operands,
1630     DictionaryAttr attributes, RegionRange regions,
1631     SmallVectorImpl<Type> &inferredReturnTypes) {
1632   if (operands[0].getType().isa<ValueShapeType>())
1633     inferredReturnTypes.assign({ShapeType::get(context)});
1634   else {
1635     auto shapedTy = operands[0].getType().cast<ShapedType>();
1636     int64_t rank =
1637         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1638     Type indexTy = IndexType::get(context);
1639     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1640     inferredReturnTypes.assign({extentTensorTy});
1641   }
1642   return success();
1643 }
1644 
1645 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1646   if (l.size() != 1 || r.size() != 1)
1647     return false;
1648   if (l == r)
1649     return true;
1650 
1651   Type lhs = l.front();
1652   Type rhs = r.front();
1653 
1654   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1655     return false;
1656 
1657   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1658     // Shape type is compatible with all other valid return types.
1659     return true;
1660 
1661   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1662     return true;
1663   return false;
1664 }
1665 
1666 LogicalResult shape::ShapeOfOp::verify() {
1667   return verifyShapeOrExtentTensorOp(*this);
1668 }
1669 
1670 //===----------------------------------------------------------------------===//
1671 // SizeToIndexOp
1672 //===----------------------------------------------------------------------===//
1673 
1674 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1675   // Constant values of both types, `shape.size` and `index`, are represented as
1676   // `IntegerAttr`s which makes constant folding simple.
1677   if (Attribute arg = operands[0])
1678     return arg;
1679   return OpFoldResult();
1680 }
1681 
1682 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1683                                                 MLIRContext *context) {
1684   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1685 }
1686 
1687 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1688   if (inputs.size() != 1 || outputs.size() != 1)
1689     return false;
1690   return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
1691 }
1692 
1693 //===----------------------------------------------------------------------===//
1694 // YieldOp
1695 //===----------------------------------------------------------------------===//
1696 
1697 LogicalResult shape::YieldOp::verify() {
1698   auto *parentOp = (*this)->getParentOp();
1699   auto results = parentOp->getResults();
1700   auto operands = getOperands();
1701 
1702   if (parentOp->getNumResults() != getNumOperands())
1703     return emitOpError() << "number of operands does not match number of "
1704                             "results of its parent";
1705   for (auto e : llvm::zip(results, operands))
1706     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1707       return emitOpError() << "types mismatch between yield op and its parent";
1708 
1709   return success();
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // SplitAtOp
1714 //===----------------------------------------------------------------------===//
1715 
1716 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1717                               SmallVectorImpl<OpFoldResult> &results) {
1718   if (!operands[0] || !operands[1])
1719     return failure();
1720   auto shapeVec = llvm::to_vector<6>(
1721       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1722   auto shape = llvm::makeArrayRef(shapeVec);
1723   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1724   // Verify that the split point is in the correct range.
1725   // TODO: Constant fold to an "error".
1726   int64_t rank = shape.size();
1727   if (!(-rank <= splitPoint && splitPoint <= rank))
1728     return failure();
1729   if (splitPoint < 0)
1730     splitPoint += shape.size();
1731   Builder builder(operands[0].getContext());
1732   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1733   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1734   return success();
1735 }
1736 
1737 //===----------------------------------------------------------------------===//
1738 // ToExtentTensorOp
1739 //===----------------------------------------------------------------------===//
1740 
1741 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1742   if (!operands[0])
1743     return OpFoldResult();
1744   Builder builder(getContext());
1745   auto shape = llvm::to_vector<6>(
1746       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1747   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1748                                     builder.getIndexType());
1749   return DenseIntElementsAttr::get(type, shape);
1750 }
1751 
1752 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1753   if (inputs.size() != 1 || outputs.size() != 1)
1754     return false;
1755   if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
1756     if (!inputTensor.getElementType().isa<IndexType>() ||
1757         inputTensor.getRank() != 1)
1758       return false;
1759   } else if (!inputs[0].isa<ShapeType>()) {
1760     return false;
1761   }
1762 
1763   TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
1764   return outputTensor && outputTensor.getElementType().isa<IndexType>();
1765 }
1766 
1767 //===----------------------------------------------------------------------===//
1768 // ReduceOp
1769 //===----------------------------------------------------------------------===//
1770 
1771 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1772                      ValueRange initVals) {
1773   result.addOperands(shape);
1774   result.addOperands(initVals);
1775 
1776   Region *bodyRegion = result.addRegion();
1777   bodyRegion->push_back(new Block);
1778   Block &bodyBlock = bodyRegion->front();
1779   bodyBlock.addArgument(builder.getIndexType(), result.location);
1780 
1781   Type elementType;
1782   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1783     elementType = tensorType.getElementType();
1784   else
1785     elementType = SizeType::get(builder.getContext());
1786   bodyBlock.addArgument(elementType, shape.getLoc());
1787 
1788   for (Value initVal : initVals) {
1789     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1790     result.addTypes(initVal.getType());
1791   }
1792 }
1793 
1794 LogicalResult ReduceOp::verify() {
1795   // Verify block arg types.
1796   Block &block = getRegion().front();
1797 
1798   // The block takes index, extent, and aggregated values as arguments.
1799   auto blockArgsCount = getInitVals().size() + 2;
1800   if (block.getNumArguments() != blockArgsCount)
1801     return emitOpError() << "ReduceOp body is expected to have "
1802                          << blockArgsCount << " arguments";
1803 
1804   // The first block argument is the index and must always be of type `index`.
1805   if (!block.getArgument(0).getType().isa<IndexType>())
1806     return emitOpError(
1807         "argument 0 of ReduceOp body is expected to be of IndexType");
1808 
1809   // The second block argument is the extent and must be of type `size` or
1810   // `index`, depending on whether the reduce operation is applied to a shape or
1811   // to an extent tensor.
1812   Type extentTy = block.getArgument(1).getType();
1813   if (getShape().getType().isa<ShapeType>()) {
1814     if (!extentTy.isa<SizeType>())
1815       return emitOpError("argument 1 of ReduceOp body is expected to be of "
1816                          "SizeType if the ReduceOp operates on a ShapeType");
1817   } else {
1818     if (!extentTy.isa<IndexType>())
1819       return emitOpError(
1820           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1821           "ReduceOp operates on an extent tensor");
1822   }
1823 
1824   for (const auto &type : llvm::enumerate(getInitVals()))
1825     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1826       return emitOpError() << "type mismatch between argument "
1827                            << type.index() + 2
1828                            << " of ReduceOp body and initial value "
1829                            << type.index();
1830   return success();
1831 }
1832 
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands: a parenthesized list holding the shape/extent-tensor
  // operand followed by the init values, then the operand's type and an
  // optional `->` result-type list (matches ReduceOp::print).
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first operand is the shape; each init value is
  // resolved against the corresponding declared result type.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body region (block arguments are spelled inside the region).
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse the trailing optional attribute dictionary.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1862 
void ReduceOp::print(OpAsmPrinter &p) {
  // Mirrors ReduceOp::parse: `(%shape, %inits) : operand-type [-> results]`
  // followed by the body region and the optional attribute dictionary.
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1871 
1872 #define GET_OP_CLASSES
1873 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1874