1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 
25 namespace {
26 #include "ShapeCanonicalization.inc"
27 }
28 
/// Returns the canonical extent tensor type: a 1-D tensor of `index` values
/// with a dynamic number of elements, i.e. `tensor<?xindex>`.
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}
32 
33 static bool isErrorPropagationPossible(TypeRange operandTypes) {
34   for (Type ty : operandTypes)
35     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
36       return true;
37   return false;
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// Both hooks unconditionally return true: shape dialect regions and ops are
/// always considered legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
89 
/// Registers the dialect's operations, types, and interfaces.
void ShapeDialect::initialize() {
  // Operations come from the tablegen-generated list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
103 
/// Materializes a single constant operation from a folded attribute, so that
/// fold results can be turned back into IR. Returns nullptr for types this
/// dialect cannot materialize.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and the extent tensor type are backed by dense index
  // elements and use the same constant op.
  if (type.isa<ShapeType>() ||
      type == getExtentTensorType(builder.getContext()))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Plain `index` constants are handled by the standard dialect's constant op.
  if (type.isa<IndexType>())
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
119 
120 /// Parse a type registered to this dialect.
121 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
122   StringRef keyword;
123   if (parser.parseKeyword(&keyword))
124     return Type();
125 
126   if (keyword == "component")
127     return ComponentType::get(getContext());
128   if (keyword == "element")
129     return ElementType::get(getContext());
130   if (keyword == "shape")
131     return ShapeType::get(getContext());
132   if (keyword == "size")
133     return SizeType::get(getContext());
134   if (keyword == "value_shape")
135     return ValueShapeType::get(getContext());
136   if (keyword == "witness")
137     return WitnessType::get(getContext());
138 
139   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
140   return Type();
141 }
142 
143 /// Print a type registered to this dialect.
144 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
145   TypeSwitch<Type>(type)
146       .Case<ComponentType>([&](Type) { os << "component"; })
147       .Case<ElementType>([&](Type) { os << "element"; })
148       .Case<ShapeType>([&](Type) { os << "shape"; })
149       .Case<SizeType>([&](Type) { os << "size"; })
150       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
151       .Case<WitnessType>([&](Type) { os << "witness"; })
152       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // AnyOp
157 //===----------------------------------------------------------------------===//
158 
159 // TODO: Canonicalization should be implemented for shapes that can be
160 // determined through mixtures of the known dimensions of the inputs.
161 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
162   // Only the last operand is checked because AnyOp is commutative.
163   if (operands.back())
164     return operands.back();
165 
166   return nullptr;
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // AssumingOp
171 //===----------------------------------------------------------------------===//
172 
/// Parses the custom assembly form of `shape.assuming`:
///   shape.assuming %witness [-> (result-types)] region [attr-dict]
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  // The single operand is the witness; its type is always `!shape.witness`
  // and is therefore not spelled out in the assembly.
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
199 
200 static void print(OpAsmPrinter &p, AssumingOp op) {
201   bool yieldsResults = !op.results().empty();
202 
203   p << AssumingOp::getOperationName() << " " << op.witness();
204   if (yieldsResults) {
205     p << " -> (" << op.getResultTypes() << ")";
206   }
207   p.printRegion(op.doRegion(),
208                 /*printEntryBlockArgs=*/false,
209                 /*printBlockTerminators=*/yieldsResults);
210   p.printOptionalAttrDict(op.getAttrs());
211 }
212 
namespace {
// Removes AssumingOp with a passing witness and inlines the region.
// The witness must be produced by a `shape.const_witness true`.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the witness is a constant `true`.
    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.passingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
} // namespace
229 
/// Registers canonicalization patterns for `shape.assuming`.
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
235 
/// Replaces `op` with the contents of its single-block region, rewiring the
/// region's yielded values to the op's result uses. The sequence of rewriter
/// calls below is order-sensitive: the yield's operands must be captured
/// before the op is replaced and erased.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block at the op so the trailing code can be re-attached
  // after the inlined region below.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
255 
256 //===----------------------------------------------------------------------===//
257 // AssumingAllOp
258 //===----------------------------------------------------------------------===//
/// Folds `assuming_all`: returns false if any input is statically false,
/// true if all inputs are statically true, and nullptr otherwise.
/// NOTE(review): this fold also erases the constant operands from the op as a
/// side effect, which is unusual for a fold method but intentional here.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
279 
280 static LogicalResult verify(AssumingAllOp op) {
281   // Ensure that AssumingAllOp contains at least one operand
282   if (op.getNumOperands() == 0)
283     return op.emitOpError("no operands specified");
284 
285   return success();
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // BroadcastOp
290 //===----------------------------------------------------------------------===//
291 
/// Folds `shape.broadcast` of two constant shapes, plus the special cases
/// where one side is a known scalar (rank-0) shape.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (!operands[1])
    return nullptr;

  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // A scalar rhs broadcasts to whatever lhs is, even when lhs is not constant.
  if (rhsShape.empty())
    return lhs();

  if (!operands[0])
    return nullptr;

  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  // Symmetrically, a scalar lhs folds to rhs.
  if (lhsShape.empty())
    return rhs();

  SmallVector<int64_t, 6> resultShape;
  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
317 
318 //===----------------------------------------------------------------------===//
319 // ConcatOp
320 //===----------------------------------------------------------------------===//
321 
322 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
323   if (!operands[0] || !operands[1])
324     return nullptr;
325   auto lhsShape = llvm::to_vector<6>(
326       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
327   auto rhsShape = llvm::to_vector<6>(
328       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
329   SmallVector<int64_t, 6> resultShape;
330   resultShape.append(lhsShape.begin(), lhsShape.end());
331   resultShape.append(rhsShape.begin(), rhsShape.end());
332   Builder builder(getContext());
333   return builder.getIndexTensorAttr(resultShape);
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ConstShapeOp
338 //===----------------------------------------------------------------------===//
339 
340 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
341   p << "shape.const_shape ";
342   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
343   p << "[";
344   interleaveComma(op.shape().getValues<int64_t>(), p,
345                   [&](int64_t i) { p << i; });
346   p << "] : ";
347   p.printType(op.getType());
348 }
349 
/// Parses the custom assembly form of `shape.const_shape`:
///   shape.const_shape [attr-dict] `[` extents `]` `:` type
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every element of the array must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Re-encode the extents as the dense index tensor the op actually stores.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
379 
// A constant shape always folds to its own shape attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
381 
382 //===----------------------------------------------------------------------===//
383 // CstrBroadcastableOp
384 //===----------------------------------------------------------------------===//
385 
386 namespace {
387 // Given an input shape Value, try to obtain the shape's values.
388 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
389   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
390     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
391     if (!type.hasRank())
392       return failure();
393     shapeValues = llvm::to_vector<6>(type.getShape());
394     return success();
395   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
396     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
397     return success();
398   } else {
399     return failure();
400   }
401 }
402 
// For shapes that were created by some operations, we can obtain partial
// information on the shapes and sometimes determine if they will be
// broadcastable with that.
struct CstrBroadcastablePartialInfo
    : public OpRewritePattern<CstrBroadcastableOp> {
  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    // Both input shapes must be statically derivable from their defining ops.
    SmallVector<int64_t, 6> lhsShape, rhsShape;
    if (failed(getShapeVec(op.lhs(), lhsShape)))
      return failure();
    if (failed(getShapeVec(op.rhs(), rhsShape)))
      return failure();
    if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
      return failure();

    // Broadcastability is proven: replace the constraint with a passing
    // witness.
    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
    return success();
  }
};
424 
425 // Scalars are always broadcastable.
426 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
427   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
428 
429   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
430                                 PatternRewriter &rewriter) const override {
431     SmallVector<int64_t, 6> shape;
432     if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
433       return failure();
434     if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
435       return failure();
436 
437     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
438     return success();
439   }
440 };
441 
442 } // namespace
443 
/// Registers canonicalization patterns for `shape.cstr_broadcastable`.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
                  CstrBroadcastableScalar>(context);
}
452 
453 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
454   // Both operands are not needed if one is a scalar.
455   if (operands[0] &&
456       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
457     return BoolAttr::get(true, getContext());
458   if (operands[1] &&
459       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
460     return BoolAttr::get(true, getContext());
461 
462   if (operands[0] && operands[1]) {
463     auto lhsShape = llvm::to_vector<6>(
464         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
465     auto rhsShape = llvm::to_vector<6>(
466         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
467     SmallVector<int64_t, 6> resultShape;
468     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
469       return BoolAttr::get(true, getContext());
470   }
471 
472   // Lastly, see if folding can be completed based on what constraints are known
473   // on the input shapes.
474   SmallVector<int64_t, 6> lhsShape, rhsShape;
475   if (failed(getShapeVec(lhs(), lhsShape)))
476     return nullptr;
477   if (failed(getShapeVec(rhs(), rhsShape)))
478     return nullptr;
479 
480   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
481     return BoolAttr::get(true, getContext());
482 
483   // Because a failing witness result here represents an eventual assertion
484   // failure, we do not replace it with a constant witness.
485   return nullptr;
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // CstrEqOp
490 //===----------------------------------------------------------------------===//
491 
/// Registers canonicalization patterns for `shape.cstr_eq`.
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.insert<CstrEqEqOps>(context);
}
497 
498 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
499   if (llvm::all_of(operands,
500                    [&](Attribute a) { return a && a == operands[0]; }))
501     return BoolAttr::get(true, getContext());
502 
503   // Because a failing witness result here represents an eventual assertion
504   // failure, we do not try to replace it with a constant witness. Similarly, we
505   // cannot if there are any non-const inputs.
506   return nullptr;
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // ConstSizeOp
511 //===----------------------------------------------------------------------===//
512 
/// Convenience builder that wraps a raw integer into an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
517 
// A constant size always folds to its own value attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
519 
520 void ConstSizeOp::getAsmResultNames(
521     llvm::function_ref<void(Value, StringRef)> setNameFn) {
522   SmallString<4> buffer;
523   llvm::raw_svector_ostream os(buffer);
524   os << "c" << value();
525   setNameFn(getResult(), os.str());
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // ConstWitnessOp
530 //===----------------------------------------------------------------------===//
531 
// A constant witness always folds to its own passing attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
533 
534 //===----------------------------------------------------------------------===//
535 // ShapeEqOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
539   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
540   if (lhs == nullptr)
541     return {};
542   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
543   if (rhs == nullptr)
544     return {};
545   return BoolAttr::get(lhs == rhs, getContext());
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // IndexToSizeOp
550 //===----------------------------------------------------------------------===//
551 
552 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
553   // Constant values of both types, `shape.size` and `index`, are represented as
554   // `IntegerAttr`s which makes constant folding simple.
555   if (Attribute arg = operands[0])
556     return arg;
557   return {};
558 }
559 
/// Registers canonicalization patterns for `shape.index_to_size`.
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Fold size_to_index(index_to_size(x)) round-trips away.
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
564 
565 //===----------------------------------------------------------------------===//
566 // FromExtentsOp
567 //===----------------------------------------------------------------------===//
568 
569 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
570   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
571     return nullptr;
572   SmallVector<int64_t, 6> extents;
573   for (auto attr : operands)
574     extents.push_back(attr.cast<IntegerAttr>().getInt());
575   Builder builder(getContext());
576   return builder.getIndexTensorAttr(extents);
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // GetExtentOp
581 //===----------------------------------------------------------------------===//
582 
/// Returns the constant value of the `dim` operand if it is defined by a
/// `shape.const_size` or a standard `constant` op, otherwise `llvm::None`.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}
590 
/// Folds `shape.get_extent` when both the shape and the dimension are
/// constant and the dimension is in bounds.
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements)
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.hasValue())
    return nullptr;
  // Out-of-bounds accesses are not folded; they keep their runtime semantics.
  if (dim.getValue() >= elements.getNumElements())
    return nullptr;
  return elements.getValue({(uint64_t)dim.getValue()});
}
602 
603 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
604                         int64_t dim) {
605   auto loc = result.location;
606   auto dimAttr = builder.getIndexAttr(dim);
607   if (shape.getType().isa<ShapeType>()) {
608     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
609     build(builder, result, builder.getType<SizeType>(), shape, dim);
610   } else {
611     Value dim =
612         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
613     build(builder, result, builder.getIndexType(), shape, dim);
614   }
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // RankOp
619 //===----------------------------------------------------------------------===//
620 
621 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
622   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
623   if (!shape)
624     return {};
625   int64_t rank = shape.getNumElements();
626   Builder builder(getContext());
627   return builder.getIndexAttr(rank);
628 }
629 
630 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
631 /// Constant folding fails in cases where only the rank is constant, not the
632 /// shape itself.
633 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
634 ///
635 /// Example:
636 ///
637 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
638 /// %rank = shape.rank %shape
639 ///
640 /// becomes
641 ///
642 /// %rank = shape.const_size 3
643 
namespace {
/// Matches `shape.rank(shape.shape_of(%ranked_tensor))` and replaces it with
/// a constant equal to the tensor's static rank (see the example above).
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // The rank is only known statically for ranked tensor arguments.
    auto rankedTensorType =
        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant in whichever type the rank op produces.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
670 
/// Registers canonicalization patterns for `shape.rank`.
void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
675 
676 //===----------------------------------------------------------------------===//
677 // NumElementsOp
678 //===----------------------------------------------------------------------===//
679 
680 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
681 
682   // Fold only when argument constant.
683   Attribute shape = operands[0];
684   if (!shape)
685     return {};
686 
687   APInt product(64, 1);
688   for (auto value : shape.cast<DenseIntElementsAttr>())
689     product *= value;
690   Builder builder(getContext());
691   return builder.getIndexAttr(product.getLimitedValue());
692 }
693 
694 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
695                           Value shape) {
696   if (shape.getType().isa<ShapedType>()) {
697     auto type = builder.getIndexType();
698     return build(builder, result, type, shape);
699   }
700   auto type = SizeType::get(builder.getContext());
701   return build(builder, result, type, shape);
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // MulOp
706 //===----------------------------------------------------------------------===//
707 
708 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
709   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
710   if (!lhs)
711     return nullptr;
712   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
713   if (!rhs)
714     return nullptr;
715   APInt folded = lhs.getValue() * rhs.getValue();
716   Type indexTy = IndexType::get(getContext());
717   return IntegerAttr::get(indexTy, folded);
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // ShapeOfOp
722 //===----------------------------------------------------------------------===//
723 
/// Folds `shape.shape_of` when the operand's type has a fully static shape.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  // Dynamic dimensions are unknown at compile time, so require a static shape.
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}
731 
/// Convenience builder that infers the result type from the argument: shaped
/// (tensor-like) arguments produce an extent tensor, anything else (e.g.
/// `!shape.value_shape`) produces a `!shape.shape`.
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  Type type = arg.getType().isa<ShapedType>()
                  ? (Type)getExtentTensorType(builder.getContext())
                  : (Type)builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, type, arg);
}
738 
namespace {
/// Rebuilds a `shape_of` on a shaped argument whose result is not itself a
/// shaped type, so that the type-inferring builder produces the extent tensor
/// result type instead. NOTE(review): the replacement's result type differs
/// from the original's — presumably the rewriter/uses tolerate this; confirm.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.arg().getType().isa<ShapedType>())
      return failure();
    // Already produces a shaped (extent tensor) result; nothing to do.
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
    return success();
  }
};
} // namespace
755 
/// Registers canonicalization patterns for `shape.shape_of`.
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}
760 
761 //===----------------------------------------------------------------------===//
762 // SizeToIndexOp
763 //===----------------------------------------------------------------------===//
764 
OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  // Without a constant operand, fall back to the generic cast folder (folds
  // identity-like casts).
  return impl::foldCastOp(*this);
}
772 
/// Registers canonicalization patterns for `shape.size_to_index`.
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Fold index_to_size(size_to_index(x)) round-trips away.
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
777 
778 //===----------------------------------------------------------------------===//
779 // YieldOp
780 //===----------------------------------------------------------------------===//
781 
782 static LogicalResult verify(shape::YieldOp op) {
783   auto *parentOp = op.getParentOp();
784   auto results = parentOp->getResults();
785   auto operands = op.getOperands();
786 
787   if (parentOp->getNumResults() != op.getNumOperands())
788     return op.emitOpError() << "number of operands does not match number of "
789                                "results of its parent";
790   for (auto e : llvm::zip(results, operands))
791     if (std::get<0>(e).getType() != std::get<1>(e).getType())
792       return op.emitOpError()
793              << "types mismatch between yield op and its parent";
794 
795   return success();
796 }
797 
798 //===----------------------------------------------------------------------===//
799 // SplitAtOp
800 //===----------------------------------------------------------------------===//
801 
802 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
803                               SmallVectorImpl<OpFoldResult> &results) {
804   if (!operands[0] || !operands[1])
805     return failure();
806   auto shapeVec = llvm::to_vector<6>(
807       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
808   auto shape = llvm::makeArrayRef(shapeVec);
809   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
810   // Verify that the split point is in the correct range.
811   // TODO: Constant fold to an "error".
812   int64_t rank = shape.size();
813   if (!(-rank <= splitPoint && splitPoint <= rank))
814     return failure();
815   if (splitPoint < 0)
816     splitPoint += shape.size();
817   Builder builder(operands[0].getContext());
818   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
819   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
820   return success();
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // ToExtentTensorOp
825 //===----------------------------------------------------------------------===//
826 
827 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
828   if (!operands[0])
829     return impl::foldCastOp(*this);
830   Builder builder(getContext());
831   auto shape = llvm::to_vector<6>(
832       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
833   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
834                                     builder.getIndexType());
835   return DenseIntElementsAttr::get(type, shape);
836 }
837 
838 //===----------------------------------------------------------------------===//
839 // ReduceOp
840 //===----------------------------------------------------------------------===//
841 
842 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
843                      ValueRange initVals) {
844   result.addOperands(shape);
845   result.addOperands(initVals);
846 
847   Region *bodyRegion = result.addRegion();
848   bodyRegion->push_back(new Block);
849   Block &bodyBlock = bodyRegion->front();
850   bodyBlock.addArgument(builder.getIndexType());
851 
852   Type elementType;
853   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
854     elementType = tensorType.getElementType();
855   else
856     elementType = SizeType::get(builder.getContext());
857   bodyBlock.addArgument(elementType);
858 
859   for (Type initValType : initVals.getTypes()) {
860     bodyBlock.addArgument(initValType);
861     result.addTypes(initValType);
862   }
863 }
864 
865 static LogicalResult verify(ReduceOp op) {
866   // Verify block arg types.
867   Block &block = op.region().front();
868 
869   // The block takes index, extent, and aggregated values as arguments.
870   auto blockArgsCount = op.initVals().size() + 2;
871   if (block.getNumArguments() != blockArgsCount)
872     return op.emitOpError() << "ReduceOp body is expected to have "
873                             << blockArgsCount << " arguments";
874 
875   // The first block argument is the index and must always be of type `index`.
876   if (!block.getArgument(0).getType().isa<IndexType>())
877     return op.emitOpError(
878         "argument 0 of ReduceOp body is expected to be of IndexType");
879 
880   // The second block argument is the extent and must be of type `size` or
881   // `index`, depending on whether the reduce operation is applied to a shape or
882   // to an extent tensor.
883   Type extentTy = block.getArgument(1).getType();
884   if (op.shape().getType().isa<ShapeType>()) {
885     if (!extentTy.isa<SizeType>())
886       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
887                             "SizeType if the ReduceOp operates on a ShapeType");
888   } else {
889     if (!extentTy.isa<IndexType>())
890       return op.emitOpError(
891           "argument 1 of ReduceOp body is expected to be of IndexType if the "
892           "ReduceOp operates on an extent tensor");
893   }
894 
895   for (auto type : llvm::enumerate(op.initVals()))
896     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
897       return op.emitOpError()
898              << "type mismatch between argument " << type.index() + 2
899              << " of ReduceOp body and initial value " << type.index();
900   return success();
901 }
902 
903 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
904   // Parse operands.
905   SmallVector<OpAsmParser::OperandType, 3> operands;
906   Type shapeOrExtentTensorType;
907   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
908                               OpAsmParser::Delimiter::Paren) ||
909       parser.parseColonType(shapeOrExtentTensorType) ||
910       parser.parseOptionalArrowTypeList(result.types))
911     return failure();
912 
913   // Resolve operands.
914   auto initVals = llvm::makeArrayRef(operands).drop_front();
915   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
916                             result.operands) ||
917       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
918                              result.operands))
919     return failure();
920 
921   // Parse the body.
922   Region *body = result.addRegion();
923   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
924     return failure();
925 
926   // Parse attributes.
927   if (parser.parseOptionalAttrDict(result.attributes))
928     return failure();
929 
930   return success();
931 }
932 
933 static void print(OpAsmPrinter &p, ReduceOp op) {
934   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
935     << ") : " << op.shape().getType();
936   p.printOptionalArrowTypeList(op.getResultTypes());
937   p.printRegion(op.region());
938   p.printOptionalAttrDict(op.getAttrs());
939 }
940 
941 namespace mlir {
942 namespace shape {
943 
944 #define GET_OP_CLASSES
945 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
946 
947 } // namespace shape
948 } // namespace mlir
949