1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 using namespace mlir::shape;
23 
24 namespace {
25 #include "ShapeCanonicalization.inc"
26 }
27 
/// Returns the extent tensor type `tensor<?xindex>`, the ranked-tensor
/// representation of a shape used throughout this dialect.
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
31 
32 static bool isErrorPropagationPossible(TypeRange operandTypes) {
33   for (Type ty : operandTypes)
34     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
35       return true;
36   return false;
37 }
38 
39 static LogicalResult verifySizeOrIndexOp(Operation *op) {
40   assert(op != nullptr && op->getNumResults() == 1);
41   Type resultTy = op->getResultTypes().front();
42   if (isErrorPropagationPossible(op->getOperandTypes())) {
43     if (!resultTy.isa<SizeType>())
44       return op->emitOpError()
45              << "if at least one of the operands can hold error values then "
46                 "the result must be of type `size` to propagate them";
47   }
48   return success();
49 }
50 
51 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
52   assert(op != nullptr && op->getNumResults() == 1);
53   Type resultTy = op->getResultTypes().front();
54   if (isErrorPropagationPossible(op->getOperandTypes())) {
55     if (!resultTy.isa<ShapeType>())
56       return op->emitOpError()
57              << "if at least one of the operands can hold error values then "
58                 "the result must be of type `shape` to propagate them";
59   }
60   return success();
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // InlinerInterface
65 //===----------------------------------------------------------------------===//
66 
67 namespace {
/// This class defines the interface for inlining shape dialect ops.
///
/// All shape ops are side-effect-free computations on shapes/sizes, so both
/// hooks unconditionally permit inlining.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
87 } // namespace
88 
/// Registers the dialect's ops (from the ODS-generated list), types, and the
/// inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
/// Materializes a constant op for the given attribute/type pair so the folder
/// can turn fold results back into IR. Returns nullptr for unsupported types.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and the extent tensor type materialize as const_shape.
  if (type.isa<ShapeType>() ||
      type == getExtentTensorType(builder.getContext()))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Plain `index` constants are handled by the standard dialect's ConstantOp.
  if (type.isa<IndexType>())
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
118 
/// Parse a type registered to this dialect.
///
/// Each shape type is spelled as a bare keyword (e.g. `!shape.size`); emits a
/// diagnostic and returns a null type on an unknown keyword.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "component")
    return ComponentType::get(getContext());
  if (keyword == "element")
    return ElementType::get(getContext());
  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}
141 
/// Print a type registered to this dialect.
///
/// Mirrors parseType: each type prints as its bare keyword. Dispatches on the
/// legacy type-kind enumeration.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  switch (type.getKind()) {
  case ShapeTypes::Component:
    os << "component";
    return;
  case ShapeTypes::Element:
    os << "element";
    return;
  case ShapeTypes::Size:
    os << "size";
    return;
  case ShapeTypes::Shape:
    os << "shape";
    return;
  case ShapeTypes::ValueShape:
    os << "value_shape";
    return;
  case ShapeTypes::Witness:
    os << "witness";
    return;
  default:
    llvm_unreachable("unexpected 'shape' type kind");
  }
}
167 
168 //===----------------------------------------------------------------------===//
169 // AnyOp
170 //===----------------------------------------------------------------------===//
171 
172 // TODO: Canonicalization should be implemented for shapes that can be
173 // determined through mixtures of the known dimensions of the inputs.
174 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
175   // Only the last operand is checked because AnyOp is commutative.
176   if (operands.back())
177     return operands.back();
178 
179   return nullptr;
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // AssumingOp
184 //===----------------------------------------------------------------------===//
185 
/// Parses an `assuming` op: a witness operand, an optional `->` result type
/// list, the body region, and an optional attribute dictionary.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  // The single operand is always witness-typed; resolve it directly.
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
212 
/// Prints an `assuming` op; the inverse of parseAssumingOp. The block
/// terminator is elided when the op yields no results.
static void print(OpAsmPrinter &p, AssumingOp op) {
  bool yieldsResults = !op.results().empty();

  p << AssumingOp::getOperationName() << " " << op.witness();
  if (yieldsResults) {
    p << " -> (" << op.getResultTypes() << ")";
  }
  p.printRegion(op.doRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict(op.getAttrs());
}
225 
226 namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Only fires when the witness is a const_witness known to pass.
    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.passingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
241 } // namespace
242 
/// Registers canonicalization patterns for `assuming`.
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
248 
/// Splices the body of `op` into its parent block in place of the op itself.
/// The op's results are replaced by the yielded values, and the surrounding
/// block is stitched back together since no branching remains.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split off everything after the op so the body can be inserted in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
268 
269 //===----------------------------------------------------------------------===//
270 // AssumingAllOp
271 //===----------------------------------------------------------------------===//
/// Folds `assuming_all` over its constant witness operands, erasing each
/// constant operand as it is consumed. Note this mutates the op during
/// folding, which is why the iteration order matters.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
292 
293 static LogicalResult verify(AssumingAllOp op) {
294   // Ensure that AssumingAllOp contains at least one operand
295   if (op.getNumOperands() == 0)
296     return op.emitOpError("no operands specified");
297 
298   return success();
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // BroadcastOp
303 //===----------------------------------------------------------------------===//
304 
/// Folds `broadcast` when at least the rhs shape is constant. Broadcasting
/// with a rank-0 (scalar) shape is the identity on the other operand.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[1])
    return nullptr;

  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Scalar rhs: result is just the lhs shape, even if it is not constant.
  if (rhsShape.empty())
    return lhs();

  if (!operands[0])
    return nullptr;

  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Scalar lhs: result is just the rhs shape.
  if (lhsShape.empty())
    return rhs();

  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
330 
331 //===----------------------------------------------------------------------===//
332 // ConcatOp
333 //===----------------------------------------------------------------------===//
334 
335 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
336   if (!operands[0] || !operands[1])
337     return nullptr;
338   auto lhsShape = llvm::to_vector<6>(
339       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
340   auto rhsShape = llvm::to_vector<6>(
341       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
342   SmallVector<int64_t, 6> resultShape;
343   resultShape.append(lhsShape.begin(), lhsShape.end());
344   resultShape.append(rhsShape.begin(), rhsShape.end());
345   Builder builder(getContext());
346   return builder.getIndexTensorAttr(resultShape);
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // ConstShapeOp
351 //===----------------------------------------------------------------------===//
352 
/// Prints `shape.const_shape` as a bracketed extent list followed by the
/// result type; the `shape` attribute is printed inline, not in the dict.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << "shape.const_shape ";
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.shape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
362 
/// Parses `shape.const_shape`: an optional attribute dict, a bracketed list of
/// integer extents, and a colon-separated result type.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  // Store the extents canonically as an index tensor attribute.
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
392 
/// Folds to the constant `shape` attribute the op already carries.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
394 
395 //===----------------------------------------------------------------------===//
396 // CstrBroadcastableOp
397 //===----------------------------------------------------------------------===//
398 
399 namespace {
400 // Given an input shape Value, try to obtain the shape's values.
401 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
402   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
403     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
404     if (!type.hasRank())
405       return failure();
406     shapeValues = llvm::to_vector<6>(type.getShape());
407     return success();
408   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
409     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
410     return success();
411   } else {
412     return failure();
413   }
414 }
415 
// For shapes that were created by some operations, we can obtain partial
// information on the shapes and sometimes determine if they will be
// broadcastable with that.
struct CstrBroadcastablePartialInfo
    : public OpRewritePattern<CstrBroadcastableOp> {
  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // Both shapes must be statically derivable (via shape_of/const_shape).
    SmallVector<int64_t, 6> lhsShape, rhsShape;
    if (failed(getShapeVec(op.lhs(), lhsShape)))
      return failure();
    if (failed(getShapeVec(op.rhs(), rhsShape)))
      return failure();
    if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
      return failure();

    // The constraint is statically satisfied: replace with a passing witness.
    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
    return success();
  }
};
437 
438 // Scalars are always broadcastable.
439 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
440   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
441 
442   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
443                                 PatternRewriter &rewriter) const override {
444     SmallVector<int64_t, 6> shape;
445     if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
446       return failure();
447     if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
448       return failure();
449 
450     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
451     return success();
452   }
453 };
454 
455 } // namespace
456 
/// Registers canonicalization patterns for `cstr_broadcastable`.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
                  CstrBroadcastableScalar>(context);
}
465 
466 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
467   // Both operands are not needed if one is a scalar.
468   if (operands[0] &&
469       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
470     return BoolAttr::get(true, getContext());
471   if (operands[1] &&
472       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
473     return BoolAttr::get(true, getContext());
474 
475   if (operands[0] && operands[1]) {
476     auto lhsShape = llvm::to_vector<6>(
477         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
478     auto rhsShape = llvm::to_vector<6>(
479         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
480     SmallVector<int64_t, 6> resultShape;
481     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
482       return BoolAttr::get(true, getContext());
483   }
484 
485   // Lastly, see if folding can be completed based on what constraints are known
486   // on the input shapes.
487   SmallVector<int64_t, 6> lhsShape, rhsShape;
488   if (failed(getShapeVec(lhs(), lhsShape)))
489     return nullptr;
490   if (failed(getShapeVec(rhs(), rhsShape)))
491     return nullptr;
492 
493   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
494     return BoolAttr::get(true, getContext());
495 
496   // Because a failing witness result here represents an eventual assertion
497   // failure, we do not replace it with a constant witness.
498   return nullptr;
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // CstrEqOp
503 //===----------------------------------------------------------------------===//
504 
/// Registers canonicalization patterns for `cstr_eq`.
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.insert<CstrEqEqOps>(context);
}
510 
511 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
512   if (llvm::all_of(operands,
513                    [&](Attribute a) { return a && a == operands[0]; }))
514     return BoolAttr::get(true, getContext());
515 
516   // Because a failing witness result here represents an eventual assertion
517   // failure, we do not try to replace it with a constant witness. Similarly, we
518   // cannot if there are any non-const inputs.
519   return nullptr;
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // ConstSizeOp
524 //===----------------------------------------------------------------------===//
525 
/// Convenience builder: wraps the raw integer in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
530 
/// Folds to the constant `value` attribute the op already carries.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
532 
533 void ConstSizeOp::getAsmResultNames(
534     llvm::function_ref<void(Value, StringRef)> setNameFn) {
535   SmallString<4> buffer;
536   llvm::raw_svector_ostream os(buffer);
537   os << "c" << value();
538   setNameFn(getResult(), os.str());
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // ConstWitnessOp
543 //===----------------------------------------------------------------------===//
544 
/// Folds to the constant `passing` attribute the op already carries.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
546 
547 //===----------------------------------------------------------------------===//
548 // ShapeEqOp
549 //===----------------------------------------------------------------------===//
550 
551 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
552   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
553   if (lhs == nullptr)
554     return {};
555   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
556   if (rhs == nullptr)
557     return {};
558   return BoolAttr::get(lhs == rhs, getContext());
559 }
560 
561 //===----------------------------------------------------------------------===//
562 // IndexToSizeOp
563 //===----------------------------------------------------------------------===//
564 
565 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
566   // Constant values of both types, `shape.size` and `index`, are represented as
567   // `IntegerAttr`s which makes constant folding simple.
568   if (Attribute arg = operands[0])
569     return arg;
570   return {};
571 }
572 
/// Registers canonicalization patterns for `index_to_size` (folds away
/// size->index->size round trips).
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
577 
578 //===----------------------------------------------------------------------===//
579 // FromExtentsOp
580 //===----------------------------------------------------------------------===//
581 
582 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
583   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
584     return nullptr;
585   SmallVector<int64_t, 6> extents;
586   for (auto attr : operands)
587     extents.push_back(attr.cast<IntegerAttr>().getInt());
588   Builder builder(getContext());
589   return builder.getIndexTensorAttr(extents);
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // GetExtentOp
594 //===----------------------------------------------------------------------===//
595 
/// Returns the dimension operand as a constant integer if it is defined by a
/// `shape.const_size` or a standard `constant` op; llvm::None otherwise.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}
603 
604 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
605   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
606   if (!elements)
607     return nullptr;
608   Optional<int64_t> dim = getConstantDim();
609   if (!dim.hasValue())
610     return nullptr;
611   if (dim.getValue() >= elements.getNumElements())
612     return nullptr;
613   return elements.getValue({(uint64_t)dim.getValue()});
614 }
615 
616 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
617                         int64_t dim) {
618   auto loc = result.location;
619   auto dimAttr = builder.getIndexAttr(dim);
620   if (shape.getType().isa<ShapeType>()) {
621     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
622     build(builder, result, builder.getType<SizeType>(), shape, dim);
623   } else {
624     Value dim =
625         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
626     build(builder, result, builder.getIndexType(), shape, dim);
627   }
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // RankOp
632 //===----------------------------------------------------------------------===//
633 
634 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
635   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
636   if (!shape)
637     return {};
638   int64_t rank = shape.getNumElements();
639   Builder builder(getContext());
640   return builder.getIndexAttr(rank);
641 }
642 
643 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
644 /// Constant folding fails in cases where only the rank is constant, not the
645 /// shape itself.
646 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
647 ///
648 /// Example:
649 ///
650 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
651 /// %rank = shape.rank %shape
652 ///
653 /// becomes
654 ///
655 /// %rank = shape.const_size 3
656 
657 namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only matches rank(shape_of(%t)) where %t has a ranked tensor type.
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    // Materialize the rank as a constant matching the op's result type:
    // std.constant for `index`, shape.const_size for `!shape.size`.
    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
682 } // namespace
683 
/// Registers canonicalization patterns for `rank`.
void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
688 
689 //===----------------------------------------------------------------------===//
690 // NumElementsOp
691 //===----------------------------------------------------------------------===//
692 
693 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
694 
695   // Fold only when argument constant.
696   Attribute shape = operands[0];
697   if (!shape)
698     return {};
699 
700   APInt product(64, 1);
701   for (auto value : shape.cast<DenseIntElementsAttr>())
702     product *= value;
703   Builder builder(getContext());
704   return builder.getIndexAttr(product.getLimitedValue());
705 }
706 
707 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
708                           Value shape) {
709   if (shape.getType().isa<ShapedType>()) {
710     auto type = builder.getIndexType();
711     return build(builder, result, type, shape);
712   }
713   auto type = SizeType::get(builder.getContext());
714   return build(builder, result, type, shape);
715 }
716 
717 //===----------------------------------------------------------------------===//
718 // MulOp
719 //===----------------------------------------------------------------------===//
720 
721 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
722   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
723   if (!lhs)
724     return nullptr;
725   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
726   if (!rhs)
727     return nullptr;
728   APInt folded = lhs.getValue() * rhs.getValue();
729   Type indexTy = IndexType::get(getContext());
730   return IntegerAttr::get(indexTy, folded);
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // ShapeOfOp
735 //===----------------------------------------------------------------------===//
736 
737 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
738   auto type = getOperand().getType().dyn_cast<ShapedType>();
739   if (!type || !type.hasStaticShape())
740     return nullptr;
741   Builder builder(getContext());
742   return builder.getIndexTensorAttr(type.getShape());
743 }
744 
/// Convenience builder inferring the result type: shaped (tensor-like)
/// arguments yield an extent tensor, all other arguments yield `!shape.shape`.
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  Type type = arg.getType().isa<ShapedType>()
                  ? (Type)getExtentTensorType(builder.getContext())
                  : (Type)builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, type, arg);
}
751 
752 namespace {
/// Rewrites a `shape_of` on a shaped argument whose result is not already an
/// extent tensor into one built via the type-inferring builder (which picks
/// the extent tensor type for shaped arguments).
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    // Already produces a tensor result; nothing to do.
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};
767 } // namespace
768 
/// Registers canonicalization patterns for `shape_of`.
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}
773 
774 //===----------------------------------------------------------------------===//
775 // SizeToIndexOp
776 //===----------------------------------------------------------------------===//
777 
OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  // Otherwise, try to fold away an index->size->index round trip as a cast.
  return impl::foldCastOp(*this);
}
785 
/// Registers canonicalization patterns for `size_to_index` (folds away
/// index->size->index round trips).
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
790 
791 //===----------------------------------------------------------------------===//
792 // YieldOp
793 //===----------------------------------------------------------------------===//
794 
795 static LogicalResult verify(YieldOp op) {
796   auto *parentOp = op.getParentOp();
797   auto results = parentOp->getResults();
798   auto operands = op.getOperands();
799 
800   if (parentOp->getNumResults() != op.getNumOperands())
801     return op.emitOpError() << "number of operands does not match number of "
802                                "results of its parent";
803   for (auto e : llvm::zip(results, operands))
804     if (std::get<0>(e).getType() != std::get<1>(e).getType())
805       return op.emitOpError()
806              << "types mismatch between yield op and its parent";
807 
808   return success();
809 }
810 
811 //===----------------------------------------------------------------------===//
812 // SplitAtOp
813 //===----------------------------------------------------------------------===//
814 
815 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
816                               SmallVectorImpl<OpFoldResult> &results) {
817   if (!operands[0] || !operands[1])
818     return failure();
819   auto shapeVec = llvm::to_vector<6>(
820       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
821   auto shape = llvm::makeArrayRef(shapeVec);
822   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
823   // Verify that the split point is in the correct range.
824   // TODO: Constant fold to an "error".
825   int64_t rank = shape.size();
826   if (!(-rank <= splitPoint && splitPoint <= rank))
827     return failure();
828   if (splitPoint < 0)
829     splitPoint += shape.size();
830   Builder builder(operands[0].getContext());
831   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
832   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
833   return success();
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // ToExtentTensorOp
838 //===----------------------------------------------------------------------===//
839 
840 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
841   if (!operands[0])
842     return impl::foldCastOp(*this);
843   Builder builder(getContext());
844   auto shape = llvm::to_vector<6>(
845       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
846   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
847                                     builder.getIndexType());
848   return DenseIntElementsAttr::get(type, shape);
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // ReduceOp
853 //===----------------------------------------------------------------------===//
854 
855 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
856                      ValueRange initVals) {
857   result.addOperands(shape);
858   result.addOperands(initVals);
859 
860   Region *bodyRegion = result.addRegion();
861   bodyRegion->push_back(new Block);
862   Block &bodyBlock = bodyRegion->front();
863   bodyBlock.addArgument(builder.getIndexType());
864 
865   Type elementType;
866   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
867     elementType = tensorType.getElementType();
868   else
869     elementType = SizeType::get(builder.getContext());
870   bodyBlock.addArgument(elementType);
871 
872   for (Type initValType : initVals.getTypes()) {
873     bodyBlock.addArgument(initValType);
874     result.addTypes(initValType);
875   }
876 }
877 
878 static LogicalResult verify(ReduceOp op) {
879   // Verify block arg types.
880   Block &block = op.region().front();
881 
882   // The block takes index, extent, and aggregated values as arguments.
883   auto blockArgsCount = op.initVals().size() + 2;
884   if (block.getNumArguments() != blockArgsCount)
885     return op.emitOpError() << "ReduceOp body is expected to have "
886                             << blockArgsCount << " arguments";
887 
888   // The first block argument is the index and must always be of type `index`.
889   if (!block.getArgument(0).getType().isa<IndexType>())
890     return op.emitOpError(
891         "argument 0 of ReduceOp body is expected to be of IndexType");
892 
893   // The second block argument is the extent and must be of type `size` or
894   // `index`, depending on whether the reduce operation is applied to a shape or
895   // to an extent tensor.
896   Type extentTy = block.getArgument(1).getType();
897   if (op.shape().getType().isa<ShapeType>()) {
898     if (!extentTy.isa<SizeType>())
899       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
900                             "SizeType if the ReduceOp operates on a ShapeType");
901   } else {
902     if (!extentTy.isa<IndexType>())
903       return op.emitOpError(
904           "argument 1 of ReduceOp body is expected to be of IndexType if the "
905           "ReduceOp operates on an extent tensor");
906   }
907 
908   for (auto type : llvm::enumerate(op.initVals()))
909     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
910       return op.emitOpError()
911              << "type mismatch between argument " << type.index() + 2
912              << " of ReduceOp body and initial value " << type.index();
913   return success();
914 }
915 
916 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
917   // Parse operands.
918   SmallVector<OpAsmParser::OperandType, 3> operands;
919   Type shapeOrExtentTensorType;
920   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
921                               OpAsmParser::Delimiter::Paren) ||
922       parser.parseColonType(shapeOrExtentTensorType) ||
923       parser.parseOptionalArrowTypeList(result.types))
924     return failure();
925 
926   // Resolve operands.
927   auto initVals = llvm::makeArrayRef(operands).drop_front();
928   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
929                             result.operands) ||
930       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
931                              result.operands))
932     return failure();
933 
934   // Parse the body.
935   Region *body = result.addRegion();
936   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
937     return failure();
938 
939   // Parse attributes.
940   if (parser.parseOptionalAttrDict(result.attributes))
941     return failure();
942 
943   return success();
944 }
945 
946 static void print(OpAsmPrinter &p, ReduceOp op) {
947   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
948     << ") : " << op.shape().getType();
949   p.printOptionalArrowTypeList(op.getResultTypes());
950   p.printRegion(op.region());
951   p.printOptionalAttrDict(op.getAttrs());
952 }
953 
954 namespace mlir {
955 namespace shape {
956 
957 #define GET_OP_CLASSES
958 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
959 
960 } // namespace shape
961 } // namespace mlir
962