//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}

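/// If `input` is the shape of a ranked operand or a constant extent tensor,
/// appends its (possibly dynamic) extents to `shapeValues` and succeeds;
/// fails otherwise.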
LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = inputOp.getArg().getType().cast<ShapedType>();
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

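/// Returns true if each of the given type ranges has exactly one element and
/// that element is one of the `Ty...` alternatives.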
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}

/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}

/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<ShapeType>([&](Type) { os << "shape"; })
      .Case<SizeType>([&](Type) { os << "size"; })
      .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
      .Case<WitnessType>([&](Type) { os << "witness"; })
      .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
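//
// Schematic folding example (illustrative operand names; types elided as in
// the other examples in this file):
//
//   %c = shape.const_shape [1, 2]
//   %0 = shape.any %s, %c
//
// folds to the constant shape [1, 2], since any constant operand is a valid
// result for `shape.any`.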
OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
  // Only the last operand is checked because AnyOp is commutative.
  if (operands.back())
    return operands.back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
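//
// Schematic example:
//
//   %w = shape.const_witness true
//   %0 = shape.assuming %w -> ... { ...; shape.assuming_yield %s }
//
// becomes the inlined body, with uses of %0 replaced by %s.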
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

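// Removes results of an `assuming` op that have no uses, by rebuilding the op
// with only the yielded values that are actually used. Schematic example,
// where %1 is unused:
//
//   %0, %1 = shape.assuming %w -> ... { ...; shape.assuming_yield %a, %b }
//
// becomes:
//
//   %0 = shape.assuming %w -> ... { ...; shape.assuming_yield %a }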
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();

  // Build body.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&bodyBlock);
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands of an `assuming_all` operation that
// are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

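// Replaces an `assuming_all` op whose inputs are all `cstr_eq` witnesses over
// pairwise overlapping sets of shapes by a single `cstr_eq` op on all shapes.
// Schematic example:
//
//   %w0 = shape.cstr_eq %a, %b
//   %w1 = shape.cstr_eq %b, %c
//   %0 = shape.assuming_all %w0, %w1
//
// becomes:
//
//   %0 = shape.cstr_eq %a, %b, %b, %c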
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

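// Removes repeated operands from a variadic op whose result is invariant
// under duplication. Schematic example for `shape.broadcast`:
//
//   %0 = shape.broadcast %s0, %s1, %s0
//
// becomes:
//
//   %0 = shape.broadcast %s0, %s1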
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand.
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

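// Folds a broadcast over constant shapes to the constant result; e.g.
// broadcasting the shapes [2, 1] and [1, 3] folds to the constant extent
// tensor [2, 3]. A single-operand broadcast folds to its unique operand if no
// cast is required.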
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
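// Removes operands that are statically known to be empty (rank-0) shapes, as
// they are neutral with respect to broadcasting. Schematic example:
//
//   %e = shape.const_shape []
//   %0 = shape.broadcast %s, %e
//
// becomes:
//
//   %0 = shape.broadcast %s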
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

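// Forwards the unique operand of a single-operand broadcast, inserting a
// `from_extent_tensor` or `tensor.cast` op where the result type differs from
// the operand type.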
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

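// Partially folds a broadcast by combining all of its constant shape operands
// into a single constant operand. Schematic example:
//
//   %a = shape.const_shape [2, 1]
//   %b = shape.const_shape [1, 3]
//   %0 = shape.broadcast %s, %a, %b
//
// becomes:
//
//   %c = shape.const_shape [2, 3]
//   %0 = shape.broadcast %s, %c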
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

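// Replaces extent tensor operands that stem from a `tensor.cast` to a
// dynamically sized type by the cast's source, as such a cast carries no
// additional shape information.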
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.source();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

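// Concretizes a dynamically sized extent tensor result type when the sizes of
// all extent tensor operands are statically known, and casts the result back
// to the original type. Schematic example: a broadcast of tensor<2xindex> and
// tensor<3xindex> operands with result type tensor<?xindex> is rebuilt with
// result type tensor<3xindex>, followed by a `tensor.cast` to tensor<?xindex>.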
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

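// Folds a concatenation of constant shapes; e.g. concatenating the shapes
// [2, 3] and [4] folds to the constant extent tensor [2, 3, 4].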
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

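// The custom assembly format prints the extents inline, e.g.:
//
//   %0 = shape.const_shape [1, 2, 3] : !shape.shape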
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
  if (!shape)
    return emitOptionalError(location, "missing shape attribute");
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns overlap with the considerations during folding,
  // covering cases where additional shape information is inferred at some
  // point but does not suffice for complete folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if at most one of the attributes represents a shape that is not
// a scalar (unknown attributes are conservatively counted as non-scalar).
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands.
  if (getNumOperands() < 2)
    return emitOpError("requires at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not implement floor division when the result is
  // negative; rather, APInt rounds toward zero.
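  // For example, sdivrem(-7, 2) yields quotient -3 and remainder -1; the
  // quotient is then adjusted to -4 to implement floor division.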
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
  bool allSame = true;
  if (!operands.empty() && !operands[0])
    return {};
  for (Attribute operand : operands.drop_front(1)) {
    if (!operand)
      return {};
    allSame = allSame && operand == operands[0];
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

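// A function library maps ops to shape functions via its `mapping` attribute.
// Schematic example (op and function names are illustrative):
//
//   shape.function_library @shplib {
//     func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
//       ...
//     }
//   } mapping {
//     some_dialect.some_op = @same_result_shape
//   }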
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (shape.getType().isa<ShapeType>()) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // Can always broadcast fewer than two shapes.
  if (operands.size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
  // Fold only when the argument is constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
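
// Worked example of the fold above: for a constant shape [2, 3, 4], the
// product of extents is computed at fold time and the op becomes the index
// constant 24. Note that the 64-bit accumulator above wraps on overflow for
// pathologically large extents.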

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(ArrayRef<Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(ArrayRef<Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}
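
// Example of the fold above (illustrative IR; assembly syntax approximate):
//   %two = shape.const_size 2
//   %three = shape.const_size 3
//   %product = shape.mul %two, %three : !shape.size, !shape.size -> !shape.size
// folds to an index-typed attribute holding 6, materialized as a constant.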

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}
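
// Example of the fold above: with a statically shaped operand,
//   %shape = shape.shape_of %tensor : tensor<2x3xf32> -> tensor<2xindex>
// folds to the constant extent tensor dense<[2, 3]> : tensor<2xindex>.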

namespace {
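// Replace `shape.shape_of` on a tensor argument when its result is the opaque
// `!shape.shape` type: rebuilding the op lets return-type inference produce
// the more informative extent tensor type instead.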
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands[0].getType().isa<ValueShapeType>()) {
    inferredReturnTypes.assign({ShapeType::get(context)});
  } else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s, which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return OpFoldResult();
}
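
// Example of the fold above: when the operand is a constant such as
// `shape.const_size 42`, the underlying index-typed `IntegerAttr` is simply
// forwarded as the result, since `index` constants carry the same attribute.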

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "type mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
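
// Worked example of the fold above: splitting the constant shape [2, 3, 4] at
// index 1 yields the pair [2] and [3, 4]; a negative split point counts from
// the back, so index -1 yields [2, 3] and [4].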

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[0])
    return OpFoldResult();
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}
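
// Example of the fold above: a constant input shape [5, 7] folds directly to
// the constant extent tensor dense<[5, 7]> : tensor<2xindex>.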

bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
    if (!inputTensor.getElementType().isa<IndexType>() ||
        inputTensor.getRank() != 1)
      return false;
  } else if (!inputs[0].isa<ShapeType>()) {
    return false;
  }

  TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
  return outputTensor && outputTensor.getElementType().isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType(), result.location);

  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType, shape.getLoc());

  for (Value initVal : initVals) {
    bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}

LogicalResult ReduceOp::verify() {
  // Verify block arg types.
  Block &block = getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (getShape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  for (const auto &en : llvm::enumerate(getInitVals()))
    if (block.getArgument(en.index() + 2).getType() != en.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << en.index() + 2
                           << " of ReduceOp body and initial value "
                           << en.index();
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
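
// Sketch of the textual form handled by parse/print above (illustrative IR;
// exact details approximate): a reduction computing the number of elements of
// a shape might print as
//   %result = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
//     ^bb0(%index: index, %extent: !shape.size, %acc: !shape.size):
//       %new_acc = shape.mul %acc, %extent : !shape.size, !shape.size -> !shape.size
//       shape.yield %new_acc : !shape.size
//   }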

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
