1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
31   return RankedTensorType::get({rank}, IndexType::get(ctx));
32 }
33 
34 bool shape::isExtentTensorType(Type type) {
35   auto ranked = type.dyn_cast<RankedTensorType>();
36   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
37 }
38 
39 LogicalResult shape::getShapeVec(Value input,
40                                  SmallVectorImpl<int64_t> &shapeValues) {
41   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
42     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
43     if (!type.hasRank())
44       return failure();
45     shapeValues = llvm::to_vector<6>(type.getShape());
46     return success();
47   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
48     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
49     return success();
50   } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
51     shapeValues = llvm::to_vector<6>(
52         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
53     return success();
54   } else {
55     return failure();
56   }
57 }
58 
59 static bool isErrorPropagationPossible(TypeRange operandTypes) {
60   return llvm::any_of(operandTypes, [](Type ty) {
61     return ty.isa<SizeType, ShapeType, ValueShapeType>();
62   });
63 }
64 
65 static LogicalResult verifySizeOrIndexOp(Operation *op) {
66   assert(op != nullptr && op->getNumResults() == 1);
67   Type resultTy = op->getResultTypes().front();
68   if (isErrorPropagationPossible(op->getOperandTypes())) {
69     if (!resultTy.isa<SizeType>())
70       return op->emitOpError()
71              << "if at least one of the operands can hold error values then "
72                 "the result must be of type `size` to propagate them";
73   }
74   return success();
75 }
76 
77 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
78   assert(op != nullptr && op->getNumResults() == 1);
79   Type resultTy = op->getResultTypes().front();
80   if (isErrorPropagationPossible(op->getOperandTypes())) {
81     if (!resultTy.isa<ShapeType>())
82       return op->emitOpError()
83              << "if at least one of the operands can hold error values then "
84                 "the result must be of type `shape` to propagate them";
85   }
86   return success();
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // InlinerInterface
91 //===----------------------------------------------------------------------===//
92 
namespace {
/// This class defines the interface for inlining shape dialect ops. All shape
/// dialect ops and regions are side-effect-compatible with inlining, so both
/// hooks unconditionally allow it.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Always legal for shape dialect regions.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Always legal for shape ops.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
114 
void ShapeDialect::initialize() {
  // Register all ops generated from the ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
127 
/// Materializes a single constant operation for a folded `value` of the given
/// `type`, choosing the dedicated constant op for each shape dialect type and
/// falling back to the standard dialect constant. Returns null when no
/// constant op can represent the value/type combination.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and extent tensors materialize as const_shape.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to the standard dialect constant when it supports this
  // value/type combination.
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
142 
143 /// Parse a type registered to this dialect.
144 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
145   StringRef keyword;
146   if (parser.parseKeyword(&keyword))
147     return Type();
148 
149   if (keyword == "shape")
150     return ShapeType::get(getContext());
151   if (keyword == "size")
152     return SizeType::get(getContext());
153   if (keyword == "value_shape")
154     return ValueShapeType::get(getContext());
155   if (keyword == "witness")
156     return WitnessType::get(getContext());
157 
158   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
159   return Type();
160 }
161 
162 /// Print a type registered to this dialect.
163 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
164   TypeSwitch<Type>(type)
165       .Case<ShapeType>([&](Type) { os << "shape"; })
166       .Case<SizeType>([&](Type) { os << "size"; })
167       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
168       .Case<WitnessType>([&](Type) { os << "witness"; })
169       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
170 }
171 
172 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
173                                                      NamedAttribute attribute) {
174   // Verify shape.lib attribute.
175   if (attribute.first == "shape.lib") {
176     if (!op->hasTrait<OpTrait::SymbolTable>())
177       return op->emitError(
178           "shape.lib attribute may only be on op implementing SymbolTable");
179 
180     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
181       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
182       if (!symbol)
183         return op->emitError("shape function library ")
184                << symbolRef << " not found";
185       return isa<shape::FunctionLibraryOp>(symbol)
186                  ? success()
187                  : op->emitError()
188                        << symbolRef << " required to be shape function library";
189     }
190 
191     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
192       // Verify all entries are function libraries and mappings in libraries
193       // refer to unique ops.
194       DenseSet<Identifier> key;
195       for (auto it : arr) {
196         if (!it.isa<SymbolRefAttr>())
197           return op->emitError(
198               "only SymbolRefAttr allowed in shape.lib attribute array");
199 
200         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
201             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
202         if (!shapeFnLib)
203           return op->emitError()
204                  << it << " does not refer to FunctionLibraryOp";
205         for (auto mapping : shapeFnLib.mapping()) {
206           if (!key.insert(mapping.first).second) {
207             return op->emitError("only one op to shape mapping allowed, found "
208                                  "multiple for `")
209                    << mapping.first << "`";
210           }
211         }
212       }
213       return success();
214     }
215 
216     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
217                          "allowed as shape.lib attribute");
218   }
219   return success();
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // AnyOp
224 //===----------------------------------------------------------------------===//
225 
226 // TODO: Canonicalization should be implemented for shapes that can be
227 // determined through mixtures of the known dimensions of the inputs.
228 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
229   // Only the last operand is checked because AnyOp is commutative.
230   if (operands.back())
231     return operands.back();
232 
233   return nullptr;
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // AssumingOp
238 //===----------------------------------------------------------------------===//
239 
240 static ParseResult parseAssumingOp(OpAsmParser &parser,
241                                    OperationState &result) {
242   result.regions.reserve(1);
243   Region *doRegion = result.addRegion();
244 
245   auto &builder = parser.getBuilder();
246   OpAsmParser::OperandType cond;
247   if (parser.parseOperand(cond) ||
248       parser.resolveOperand(cond, builder.getType<WitnessType>(),
249                             result.operands))
250     return failure();
251 
252   // Parse optional results type list.
253   if (parser.parseOptionalArrowTypeList(result.types))
254     return failure();
255 
256   // Parse the region and add a terminator if elided.
257   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
258     return failure();
259   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
260 
261   // Parse the optional attribute list.
262   if (parser.parseOptionalAttrDict(result.attributes))
263     return failure();
264   return success();
265 }
266 
267 static void print(OpAsmPrinter &p, AssumingOp op) {
268   bool yieldsResults = !op.results().empty();
269 
270   p << AssumingOp::getOperationName() << " " << op.witness();
271   if (yieldsResults) {
272     p << " -> (" << op.getResultTypes() << ")";
273   }
274   p.printRegion(op.doRegion(),
275                 /*printEntryBlockArgs=*/false,
276                 /*printBlockTerminators=*/yieldsResults);
277   p.printOptionalAttrDict(op->getAttrs());
278 }
279 
280 namespace {
281 // Removes AssumingOp with a passing witness and inlines the region.
282 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
283   using OpRewritePattern<AssumingOp>::OpRewritePattern;
284 
285   LogicalResult matchAndRewrite(AssumingOp op,
286                                 PatternRewriter &rewriter) const override {
287     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
288     if (!witness || !witness.passingAttr())
289       return failure();
290 
291     AssumingOp::inlineRegionIntoParent(op, rewriter);
292     return success();
293   }
294 };
295 
// Rebuilds a `shape.assuming` op without the results (and corresponding yield
// operands) that have no uses, shrinking the op's result list.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values: keep only yield operands whose matching op result has
    // at least one use.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
    newOp.doRegion().takeBody(op.doRegion());

    // Use the new results to replace the previously used ones. Results that
    // were unused map to null, which replaceOp accepts since those results
    // have no uses to rewire.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
341 } // namespace
342 
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // Drop unused results and inline bodies guarded by a passing witness.
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
347 
348 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
349 void AssumingOp::getSuccessorRegions(
350     Optional<unsigned> index, ArrayRef<Attribute> operands,
351     SmallVectorImpl<RegionSuccessor> &regions) {
352   // AssumingOp has unconditional control flow into the region and back to the
353   // parent, so return the correct RegionSuccessor purely based on the index
354   // being None or 0.
355   if (index.hasValue()) {
356     regions.push_back(RegionSuccessor(getResults()));
357     return;
358   }
359 
360   regions.push_back(RegionSuccessor(&doRegion()));
361 }
362 
/// Splices the body of `op` into its parent block: the parent block is split
/// at the rewriter's insertion point, the assuming body is inlined in between,
/// the yield's operands replace the op's results, and the three blocks are
/// merged back into one.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
382 
383 void AssumingOp::build(
384     OpBuilder &builder, OperationState &result, Value witness,
385     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
386 
387   result.addOperands(witness);
388   Region *bodyRegion = result.addRegion();
389   bodyRegion->push_back(new Block);
390   Block &bodyBlock = bodyRegion->front();
391 
392   // Build body.
393   OpBuilder::InsertionGuard guard(builder);
394   builder.setInsertionPointToStart(&bodyBlock);
395   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
396   builder.create<AssumingYieldOp>(result.location, yieldValues);
397 
398   SmallVector<Type, 2> assumingTypes;
399   for (Value v : yieldValues)
400     assumingTypes.push_back(v.getType());
401   result.addTypes(assumingTypes);
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // AssumingAllOp
406 //===----------------------------------------------------------------------===//
407 
408 namespace {
409 struct AssumingAllToCstrEqCanonicalization
410     : public OpRewritePattern<AssumingAllOp> {
411   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
412 
413   LogicalResult matchAndRewrite(AssumingAllOp op,
414                                 PatternRewriter &rewriter) const override {
415     SmallVector<Value, 8> shapes;
416     for (Value w : op.inputs()) {
417       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
418       if (!cstrEqOp)
419         return failure();
420       bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
421         return llvm::is_contained(shapes, s);
422       });
423       if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
424         return failure();
425       shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
426     }
427     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
428     return success();
429   }
430 };
431 } // namespace
432 
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // AssumingAllOneOp is generated from ShapeCanonicalization.td.
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
}
437 
// Folds `assuming_all` over statically known witnesses. NOTE: this fold
// mutates the op in place — operands whose witness value is statically known
// are erased, since they carry no further information.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
458 
459 static LogicalResult verify(AssumingAllOp op) {
460   // Ensure that AssumingAllOp contains at least one operand
461   if (op.getNumOperands() == 0)
462     return op.emitOpError("no operands specified");
463 
464   return success();
465 }
466 
// Convenience builder: the result is always a single !shape.witness value.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
471 
472 //===----------------------------------------------------------------------===//
473 // BroadcastOp
474 //===----------------------------------------------------------------------===//
475 
476 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
477   if (shapes().size() == 1) {
478     // Otherwise, we need a cast which would be a canonicalization, not folding.
479     if (shapes().front().getType() != getType())
480       return nullptr;
481     return shapes().front();
482   }
483 
484   // TODO: Support folding with more than 2 input shapes
485   if (shapes().size() > 2)
486     return nullptr;
487 
488   if (!operands[0] || !operands[1])
489     return nullptr;
490   auto lhsShape = llvm::to_vector<6>(
491       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
492   auto rhsShape = llvm::to_vector<6>(
493       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
494   SmallVector<int64_t, 6> resultShape;
495 
496   // If the shapes are not compatible, we can't fold it.
497   // TODO: Fold to an "error".
498   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
499     return nullptr;
500 
501   Builder builder(getContext());
502   return builder.getIndexTensorAttr(resultShape);
503 }
504 
// The result must be `shape`-typed whenever an operand can carry an error
// value (see verifyShapeOrExtentTensorOp above).
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
508 
509 namespace {
510 template <typename OpTy>
511 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
512   using OpRewritePattern<OpTy>::OpRewritePattern;
513 
514   LogicalResult matchAndRewrite(OpTy op,
515                                 PatternRewriter &rewriter) const override {
516     // Find unique operands.
517     SmallVector<Value, 2> unique;
518     for (Value v : op.getOperands()) {
519       if (!llvm::is_contained(unique, v))
520         unique.push_back(v);
521     }
522 
523     // Reduce op to equivalent with unique operands.
524     if (unique.size() < op.getNumOperands()) {
525       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
526                                         op->getAttrs());
527       return success();
528     }
529 
530     return failure();
531   }
532 };
533 
534 template <typename OpTy>
535 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
536   using OpRewritePattern<OpTy>::OpRewritePattern;
537 
538   LogicalResult matchAndRewrite(OpTy op,
539                                 PatternRewriter &rewriter) const override {
540     auto isPotentiallyNonEmptyShape = [](Value shape) {
541       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
542         if (extentTensorTy.getDimSize(0) == 0)
543           return false;
544       }
545       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
546         if (constShape.shape().empty())
547           return false;
548       }
549       return true;
550     };
551     auto newOperands = llvm::to_vector<8>(
552         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
553 
554     // Reduce op to equivalent without empty shape operands.
555     if (newOperands.size() < op.getNumOperands()) {
556       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
557                                         op->getAttrs());
558       return success();
559     }
560 
561     return failure();
562   }
563 };
564 
565 struct BroadcastForwardSingleOperandPattern
566     : public OpRewritePattern<BroadcastOp> {
567   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
568 
569   LogicalResult matchAndRewrite(BroadcastOp op,
570                                 PatternRewriter &rewriter) const override {
571     if (op.getNumOperands() != 1)
572       return failure();
573     Value replacement = op.shapes().front();
574 
575     // Insert cast if needed.
576     if (replacement.getType() != op.getType()) {
577       auto loc = op.getLoc();
578       if (op.getType().isa<ShapeType>()) {
579         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
580       } else {
581         assert(!op.getType().isa<ShapeType>() &&
582                !replacement.getType().isa<ShapeType>() &&
583                "expect extent tensor cast");
584         replacement =
585             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
586       }
587     }
588 
589     rewriter.replaceOp(op, replacement);
590     return success();
591   }
592 };
593 
// Folds all constant-shape operands of a broadcast into a single const_shape
// operand (their broadcasted shape), keeping the non-constant operands as-is.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.shapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Accumulate this constant into the running broadcasted shape; if the
        // shapes are incompatible, keep the operand unfolded instead.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    // Append the folded constant as one const_shape operand.
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
631 
632 template <typename OpTy>
633 struct CanonicalizeCastExtentTensorOperandsPattern
634     : public OpRewritePattern<OpTy> {
635   using OpRewritePattern<OpTy>::OpRewritePattern;
636 
637   LogicalResult matchAndRewrite(OpTy op,
638                                 PatternRewriter &rewriter) const override {
639     // Canonicalize operands.
640     bool anyChange = false;
641     auto canonicalizeOperand = [&](Value operand) {
642       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
643         // Only eliminate the cast if it holds no shape information.
644         bool isInformationLoosingCast =
645             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
646         if (isInformationLoosingCast) {
647           anyChange = true;
648           return castOp.source();
649         }
650       }
651       return operand;
652     };
653     auto newOperands = llvm::to_vector<8>(
654         llvm::map_range(op.getOperands(), canonicalizeOperand));
655 
656     // Rewrite op if any change required.
657     if (!anyChange)
658       return failure();
659     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
660     return success();
661   }
662 };
663 
664 struct BroadcastConcretizeResultTypePattern
665     : public OpRewritePattern<BroadcastOp> {
666   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
667 
668   LogicalResult matchAndRewrite(BroadcastOp op,
669                                 PatternRewriter &rewriter) const override {
670     // Only concretize dynamic extent tensor result types.
671     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
672     if (!resultTy || !resultTy.isDynamicDim(0))
673       return failure();
674 
675     // Infer resulting shape rank if possible.
676     int64_t maxRank = 0;
677     for (Value shape : op.shapes()) {
678       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
679         // Cannot infer resulting shape rank if any operand is dynamically
680         // ranked.
681         if (extentTensorTy.isDynamicDim(0))
682           return failure();
683         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
684       }
685     }
686 
687     auto newOp = rewriter.create<BroadcastOp>(
688         op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
689     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
690     return success();
691   }
692 };
693 } // namespace
694 
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Patterns are listed alphabetically; application order is up to the driver.
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
704 
705 //===----------------------------------------------------------------------===//
706 // ConcatOp
707 //===----------------------------------------------------------------------===//
708 
709 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
710   if (!operands[0] || !operands[1])
711     return nullptr;
712   auto lhsShape = llvm::to_vector<6>(
713       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
714   auto rhsShape = llvm::to_vector<6>(
715       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
716   SmallVector<int64_t, 6> resultShape;
717   resultShape.append(lhsShape.begin(), lhsShape.end());
718   resultShape.append(rhsShape.begin(), rhsShape.end());
719   Builder builder(getContext());
720   return builder.getIndexTensorAttr(resultShape);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // ConstShapeOp
725 //===----------------------------------------------------------------------===//
726 
727 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
728   p << "shape.const_shape ";
729   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
730   p << "[";
731   interleaveComma(op.shape().getValues<int64_t>(), p,
732                   [&](int64_t i) { p << i; });
733   p << "] : ";
734   p.printType(op.getType());
735 }
736 
/// Parses `shape.const_shape` of the form:
///   shape.const_shape [attr-dict] [1, 2, 3] : type
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array entry must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents canonically as a rank-1 index tensor attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
766 
// Folds to the op's own constant `shape` attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
768 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // TensorCastConstShape is generated from ShapeCanonicalization.td.
  patterns.add<TensorCastConstShape>(context);
}
773 
774 //===----------------------------------------------------------------------===//
775 // CstrBroadcastableOp
776 //===----------------------------------------------------------------------===//
777 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  // CstrBroadcastableEqOps is generated from ShapeCanonicalization.td.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
788 
789 // Return true if there is exactly one attribute not representing a scalar
790 // broadcast.
791 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
792   bool nonScalarSeen = false;
793   for (Attribute a : attributes) {
794     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
795       if (nonScalarSeen)
796         return false;
797       nonScalarSeen = true;
798     }
799   }
800   return true;
801 }
802 
803 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
804   // No broadcasting is needed if all operands but one are scalar.
805   if (hasAtMostSingleNonScalar(operands))
806     return BoolAttr::get(getContext(), true);
807 
808   if ([&] {
809         SmallVector<SmallVector<int64_t, 6>, 6> extents;
810         for (const auto &operand : operands) {
811           if (!operand)
812             return false;
813           extents.push_back(llvm::to_vector<6>(
814               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
815         }
816         return OpTrait::util::staticallyKnownBroadcastable(extents);
817       }())
818     return BoolAttr::get(getContext(), true);
819 
820   // Lastly, see if folding can be completed based on what constraints are known
821   // on the input shapes.
822   if ([&] {
823         SmallVector<SmallVector<int64_t, 6>, 6> extents;
824         for (auto shapeValue : shapes()) {
825           extents.emplace_back();
826           if (failed(getShapeVec(shapeValue, extents.back())))
827             return false;
828         }
829         return OpTrait::util::staticallyKnownBroadcastable(extents);
830       }())
831     return BoolAttr::get(getContext(), true);
832 
833   // Because a failing witness result here represents an eventual assertion
834   // failure, we do not replace it with a constant witness.
835   return nullptr;
836 }
837 
// Verifies that a `cstr_broadcastable` op constrains at least two shapes;
// a broadcastability constraint on fewer operands is meaningless.
static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that CstrBroadcastableOp contains at least two shape operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
844 
845 //===----------------------------------------------------------------------===//
846 // CstrEqOp
847 //===----------------------------------------------------------------------===//
848 
849 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
850                                            MLIRContext *context) {
851   // If inputs are equal, return passing witness
852   patterns.add<CstrEqEqOps>(context);
853 }
854 
855 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
856   if (llvm::all_of(operands,
857                    [&](Attribute a) { return a && a == operands[0]; }))
858     return BoolAttr::get(getContext(), true);
859 
860   // Because a failing witness result here represents an eventual assertion
861   // failure, we do not try to replace it with a constant witness. Similarly, we
862   // cannot if there are any non-const inputs.
863   return nullptr;
864 }
865 
866 //===----------------------------------------------------------------------===//
867 // ConstSizeOp
868 //===----------------------------------------------------------------------===//
869 
870 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
871                         int64_t value) {
872   build(builder, result, builder.getIndexAttr(value));
873 }
874 
875 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
876 
877 void ConstSizeOp::getAsmResultNames(
878     llvm::function_ref<void(Value, StringRef)> setNameFn) {
879   SmallString<4> buffer;
880   llvm::raw_svector_ostream os(buffer);
881   os << "c" << value();
882   setNameFn(getResult(), os.str());
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // ConstWitnessOp
887 //===----------------------------------------------------------------------===//
888 
889 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
890 
891 //===----------------------------------------------------------------------===//
892 // CstrRequireOp
893 //===----------------------------------------------------------------------===//
894 
895 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
896   return operands[0];
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // DivOp
901 //===----------------------------------------------------------------------===//
902 
903 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
904   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
905   if (!lhs)
906     return nullptr;
907   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
908   if (!rhs)
909     return nullptr;
910 
911   // Division in APInt does not follow floor(lhs, rhs) when the result is
912   // negative. Rather, APInt rounds toward zero.
913   APInt quotient, remainder;
914   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
915   if (quotient.isNegative() && !remainder.isNullValue()) {
916     quotient -= 1;
917   }
918 
919   Type indexTy = IndexType::get(getContext());
920   return IntegerAttr::get(indexTy, quotient);
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // ShapeEqOp
925 //===----------------------------------------------------------------------===//
926 
927 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
928   bool allSame = true;
929   if (!operands.empty() && !operands[0])
930     return {};
931   for (Attribute operand : operands.drop_front(1)) {
932     if (!operand)
933       return {};
934     allSame = allSame && operand == operands[0];
935   }
936   return BoolAttr::get(getContext(), allSame);
937 }
938 
939 //===----------------------------------------------------------------------===//
940 // IndexToSizeOp
941 //===----------------------------------------------------------------------===//
942 
943 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
944   // Constant values of both types, `shape.size` and `index`, are represented as
945   // `IntegerAttr`s which makes constant folding simple.
946   if (Attribute arg = operands[0])
947     return arg;
948   return {};
949 }
950 
951 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
952                                                 MLIRContext *context) {
953   patterns.add<SizeToIndexToSizeCanonicalization>(context);
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // FromExtentsOp
958 //===----------------------------------------------------------------------===//
959 
960 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
961   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
962     return nullptr;
963   SmallVector<int64_t, 6> extents;
964   for (auto attr : operands)
965     extents.push_back(attr.cast<IntegerAttr>().getInt());
966   Builder builder(getContext());
967   return builder.getIndexTensorAttr(extents);
968 }
969 
970 //===----------------------------------------------------------------------===//
971 // FunctionLibraryOp
972 //===----------------------------------------------------------------------===//
973 
974 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
975                               StringRef name) {
976   result.attributes.push_back(builder.getNamedAttr(
977       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
978 }
979 
980 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
981   auto attr = mapping()
982                   .get(op->getName().getIdentifier())
983                   .dyn_cast_or_null<FlatSymbolRefAttr>();
984   if (!attr)
985     return nullptr;
986   return lookupSymbol<FuncOp>(attr);
987 }
988 
989 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
990                                    OperationState &result) {
991   // Parse the op name.
992   StringAttr nameAttr;
993   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
994                              result.attributes))
995     return failure();
996 
997   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
998     return failure();
999 
1000   auto *bodyRegion = result.addRegion();
1001   if (parser.parseRegion(*bodyRegion))
1002     return failure();
1003 
1004   if (parser.parseKeyword("mapping"))
1005     return failure();
1006 
1007   DictionaryAttr mappingAttr;
1008   if (parser.parseAttribute(mappingAttr,
1009                             parser.getBuilder().getType<NoneType>(), "mapping",
1010                             result.attributes))
1011     return failure();
1012   return success();
1013 }
1014 
// Prints a FunctionLibraryOp as:
//   shape.function_library @name <attrs> { <body> } mapping { ... }
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and the mapping are printed explicitly here, so elide
  // them from the generic attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
1025 
1026 //===----------------------------------------------------------------------===//
1027 // GetExtentOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 Optional<int64_t> GetExtentOp::getConstantDim() {
1031   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
1032     return constSizeOp.value().getLimitedValue();
1033   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
1034     return constantOp.value().cast<IntegerAttr>().getInt();
1035   return llvm::None;
1036 }
1037 
1038 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1039   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1040   if (!elements)
1041     return nullptr;
1042   Optional<int64_t> dim = getConstantDim();
1043   if (!dim.hasValue())
1044     return nullptr;
1045   if (dim.getValue() >= elements.getNumElements())
1046     return nullptr;
1047   return elements.getValue({(uint64_t)dim.getValue()});
1048 }
1049 
1050 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1051                         int64_t dim) {
1052   auto loc = result.location;
1053   auto dimAttr = builder.getIndexAttr(dim);
1054   if (shape.getType().isa<ShapeType>()) {
1055     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1056     build(builder, result, builder.getType<SizeType>(), shape, dim);
1057   } else {
1058     Value dim =
1059         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
1060     build(builder, result, builder.getIndexType(), shape, dim);
1061   }
1062 }
1063 
1064 //===----------------------------------------------------------------------===//
1065 // IsBroadcastableOp
1066 //===----------------------------------------------------------------------===//
1067 
1068 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1069                                                     MLIRContext *context) {
1070   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1071 }
1072 
1073 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1074   // Can always broadcast fewer than two shapes.
1075   if (operands.size() < 2) {
1076     return BoolAttr::get(getContext(), true);
1077   }
1078 
1079   return nullptr;
1080 }
1081 
1082 //===----------------------------------------------------------------------===//
1083 // RankOp
1084 //===----------------------------------------------------------------------===//
1085 
1086 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1087   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1088   if (!shape)
1089     return {};
1090   int64_t rank = shape.getNumElements();
1091   Builder builder(getContext());
1092   return builder.getIndexAttr(rank);
1093 }
1094 
1095 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1096 /// Constant folding fails in cases where only the rank is constant, not the
1097 /// shape itself.
1098 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1099 ///
1100 /// Example:
1101 ///
1102 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1103 /// %rank = shape.rank %shape
1104 ///
1105 /// becomes
1106 ///
1107 /// %rank = shape.const_size 3
1108 
1109 namespace {
1110 struct RankShapeOfCanonicalizationPattern
1111     : public OpRewritePattern<shape::RankOp> {
1112   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1113 
1114   LogicalResult matchAndRewrite(shape::RankOp op,
1115                                 PatternRewriter &rewriter) const override {
1116     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
1117     if (!shapeOfOp)
1118       return failure();
1119     auto rankedTensorType =
1120         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1121     if (!rankedTensorType)
1122       return failure();
1123     int64_t rank = rankedTensorType.getRank();
1124     if (op.getType().isa<IndexType>()) {
1125       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
1126     } else if (op.getType().isa<shape::SizeType>()) {
1127       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1128     } else {
1129       return failure();
1130     }
1131     return success();
1132   }
1133 };
1134 } // namespace
1135 
1136 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1137                                                 MLIRContext *context) {
1138   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1139 }
1140 
1141 //===----------------------------------------------------------------------===//
1142 // NumElementsOp
1143 //===----------------------------------------------------------------------===//
1144 
1145 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1146 
1147   // Fold only when argument constant.
1148   Attribute shape = operands[0];
1149   if (!shape)
1150     return {};
1151 
1152   APInt product(64, 1);
1153   for (auto value : shape.cast<DenseIntElementsAttr>())
1154     product *= value;
1155   Builder builder(getContext());
1156   return builder.getIndexAttr(product.getLimitedValue());
1157 }
1158 
1159 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
1160                           Value shape) {
1161   if (shape.getType().isa<ShapedType>()) {
1162     auto type = builder.getIndexType();
1163     return build(builder, result, type, shape);
1164   }
1165   auto type = SizeType::get(builder.getContext());
1166   return build(builder, result, type, shape);
1167 }
1168 
1169 //===----------------------------------------------------------------------===//
1170 // MaxOp
1171 //===----------------------------------------------------------------------===//
1172 
1173 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1174   // If operands are equal, just propagate one.
1175   if (lhs() == rhs())
1176     return lhs();
1177   return nullptr;
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // MinOp
1182 //===----------------------------------------------------------------------===//
1183 
1184 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1185   // If operands are equal, just propagate one.
1186   if (lhs() == rhs())
1187     return lhs();
1188   return nullptr;
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // MulOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1196   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1197   if (!lhs)
1198     return nullptr;
1199   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1200   if (!rhs)
1201     return nullptr;
1202   APInt folded = lhs.getValue() * rhs.getValue();
1203   Type indexTy = IndexType::get(getContext());
1204   return IntegerAttr::get(indexTy, folded);
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // ShapeOfOp
1209 //===----------------------------------------------------------------------===//
1210 
1211 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1212   auto type = getOperand().getType().dyn_cast<ShapedType>();
1213   if (!type || !type.hasStaticShape())
1214     return nullptr;
1215   Builder builder(getContext());
1216   return builder.getIndexTensorAttr(type.getShape());
1217 }
1218 
1219 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
1220   if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
1221     int64_t rank =
1222         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1223     Type indexTy = builder.getIndexType();
1224     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1225     return ShapeOfOp::build(builder, result, extentTensorTy, arg);
1226   }
1227   Type shapeTy = builder.getType<ShapeType>();
1228   return ShapeOfOp::build(builder, result, shapeTy, arg);
1229 }
1230 
1231 namespace {
1232 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1233   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1234 
1235   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1236                                 PatternRewriter &rewriter) const override {
1237     if (!op.arg().getType().isa<ShapedType>())
1238       return failure();
1239     if (op.getType().isa<ShapedType>())
1240       return failure();
1241 
1242     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
1243     return success();
1244   }
1245 };
1246 
1247 // Canonicalize
1248 // ```
1249 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1250 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1251 // ```
1252 // to
1253 // ```
1254 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1255 // ```
1256 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1257   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1258 
1259   LogicalResult matchAndRewrite(tensor::CastOp op,
1260                                 PatternRewriter &rewriter) const override {
1261     auto ty = op.getType().dyn_cast<RankedTensorType>();
1262     if (!ty || ty.getRank() != 1)
1263       return failure();
1264 
1265     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1266     if (!shapeOfOp)
1267       return failure();
1268 
1269     // Argument type must be ranked and must not conflict.
1270     auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1271     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1272       return failure();
1273 
1274     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
1275     return success();
1276   }
1277 };
1278 } // namespace
1279 
1280 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1281                                             MLIRContext *context) {
1282   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
1283 }
1284 
1285 //===----------------------------------------------------------------------===//
1286 // SizeToIndexOp
1287 //===----------------------------------------------------------------------===//
1288 
1289 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1290   // Constant values of both types, `shape.size` and `index`, are represented as
1291   // `IntegerAttr`s which makes constant folding simple.
1292   if (Attribute arg = operands[0])
1293     return arg;
1294   return impl::foldCastOp(*this);
1295 }
1296 
1297 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1298                                                 MLIRContext *context) {
1299   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // YieldOp
1304 //===----------------------------------------------------------------------===//
1305 
1306 static LogicalResult verify(shape::YieldOp op) {
1307   auto *parentOp = op->getParentOp();
1308   auto results = parentOp->getResults();
1309   auto operands = op.getOperands();
1310 
1311   if (parentOp->getNumResults() != op.getNumOperands())
1312     return op.emitOpError() << "number of operands does not match number of "
1313                                "results of its parent";
1314   for (auto e : llvm::zip(results, operands))
1315     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1316       return op.emitOpError()
1317              << "types mismatch between yield op and its parent";
1318 
1319   return success();
1320 }
1321 
1322 //===----------------------------------------------------------------------===//
1323 // SplitAtOp
1324 //===----------------------------------------------------------------------===//
1325 
1326 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1327                               SmallVectorImpl<OpFoldResult> &results) {
1328   if (!operands[0] || !operands[1])
1329     return failure();
1330   auto shapeVec = llvm::to_vector<6>(
1331       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1332   auto shape = llvm::makeArrayRef(shapeVec);
1333   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1334   // Verify that the split point is in the correct range.
1335   // TODO: Constant fold to an "error".
1336   int64_t rank = shape.size();
1337   if (!(-rank <= splitPoint && splitPoint <= rank))
1338     return failure();
1339   if (splitPoint < 0)
1340     splitPoint += shape.size();
1341   Builder builder(operands[0].getContext());
1342   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1343   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1344   return success();
1345 }
1346 
1347 //===----------------------------------------------------------------------===//
1348 // ToExtentTensorOp
1349 //===----------------------------------------------------------------------===//
1350 
1351 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1352   if (!operands[0])
1353     return impl::foldCastOp(*this);
1354   Builder builder(getContext());
1355   auto shape = llvm::to_vector<6>(
1356       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1357   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1358                                     builder.getIndexType());
1359   return DenseIntElementsAttr::get(type, shape);
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // ReduceOp
1364 //===----------------------------------------------------------------------===//
1365 
// Builds a ReduceOp over `shape` with the given initial aggregation values.
// A body block is created whose arguments are: the iteration index, the
// current extent, and one argument per initial value.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  // The body block's first argument is the iteration index, always `index`.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  bodyBlock.addArgument(builder.getIndexType());

  // The second argument is the extent: the tensor's element type when
  // reducing over an extent tensor, `!shape.size` when reducing over a
  // `!shape.shape`.
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // Each initial value adds both a block argument and an op result of the
  // same type.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
1388 
// Verifies the ReduceOp body block signature: (index, extent, init values...),
// where the extent type depends on whether a `!shape.shape` or an extent
// tensor is being reduced.
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // The remaining block arguments must match the initial values one-to-one.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1426 
// Parses a ReduceOp of the form:
//   shape.reduce(%shape, %inits...) : shape-type [-> result-types] region attrs
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The first operand is the shape being reduced; the
  // remaining operands are the initial values, whose types must match the
  // parsed result types.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1456 
// Prints a ReduceOp; the inverse of parseReduceOp.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}
1464 
1465 #define GET_OP_CLASSES
1466 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1467