1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 using namespace mlir::shape;
22 
23 namespace {
24 #include "ShapeCanonicalization.inc"
25 }
26 
27 static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
28   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
29 }
30 
31 static bool isErrorPropagationPossible(TypeRange operandTypes) {
32   for (Type ty : operandTypes)
33     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
34       return true;
35   return false;
36 }
37 
38 static LogicalResult verifySizeOrIndexOp(Operation *op) {
39   assert(op != nullptr && op->getNumResults() == 1);
40   Type resultTy = op->getResultTypes().front();
41   if (isErrorPropagationPossible(op->getOperandTypes())) {
42     if (!resultTy.isa<SizeType>())
43       return op->emitOpError()
44              << "if at least one of the operands can hold error values then "
45                 "the result must be of type `size` to propagate them";
46   }
47   return success();
48 }
49 
50 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
51   assert(op != nullptr && op->getNumResults() == 1);
52   Type resultTy = op->getResultTypes().front();
53   if (isErrorPropagationPossible(op->getOperandTypes())) {
54     if (!resultTy.isa<ShapeType>())
55       return op->emitOpError()
56              << "if at least one of the operands can hold error values then "
57                 "the result must be of type `shape` to propagate them";
58   }
59   return success();
60 }
61 
/// Constructs the shape dialect in `context`, registering the operations
/// listed in the generated ShapeOps.cpp.inc and the dialect's six custom
/// types.
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
75 
76 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
77                                              Attribute value, Type type,
78                                              Location loc) {
79   if (type.isa<ShapeType>() ||
80       type == getExtentTensorType(builder.getContext()))
81     return builder.create<ConstShapeOp>(loc, type,
82                                         value.cast<DenseIntElementsAttr>());
83   if (type.isa<SizeType>())
84     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
85   if (type.isa<WitnessType>())
86     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
87   if (type.isa<IndexType>())
88     return builder.create<ConstantOp>(loc, type, value);
89   return nullptr;
90 }
91 
92 /// Parse a type registered to this dialect.
93 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
94   StringRef keyword;
95   if (parser.parseKeyword(&keyword))
96     return Type();
97 
98   if (keyword == "component")
99     return ComponentType::get(getContext());
100   if (keyword == "element")
101     return ElementType::get(getContext());
102   if (keyword == "shape")
103     return ShapeType::get(getContext());
104   if (keyword == "size")
105     return SizeType::get(getContext());
106   if (keyword == "value_shape")
107     return ValueShapeType::get(getContext());
108   if (keyword == "witness")
109     return WitnessType::get(getContext());
110 
111   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
112   return Type();
113 }
114 
115 /// Print a type registered to this dialect.
116 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
117   switch (type.getKind()) {
118   case ShapeTypes::Component:
119     os << "component";
120     return;
121   case ShapeTypes::Element:
122     os << "element";
123     return;
124   case ShapeTypes::Size:
125     os << "size";
126     return;
127   case ShapeTypes::Shape:
128     os << "shape";
129     return;
130   case ShapeTypes::ValueShape:
131     os << "value_shape";
132     return;
133   case ShapeTypes::Witness:
134     os << "witness";
135     return;
136   default:
137     llvm_unreachable("unexpected 'shape' type kind");
138   }
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // AnyOp
143 //===----------------------------------------------------------------------===//
144 
145 // TODO: Canonicalization should be implemented for shapes that can be
146 // determined through mixtures of the known dimensions of the inputs.
147 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
148   // Only the last operand is checked because AnyOp is commutative.
149   if (operands.back())
150     return operands.back();
151 
152   return nullptr;
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // AssumingOp
157 //===----------------------------------------------------------------------===//
158 
159 static ParseResult parseAssumingOp(OpAsmParser &parser,
160                                    OperationState &result) {
161   result.regions.reserve(1);
162   Region *doRegion = result.addRegion();
163 
164   auto &builder = parser.getBuilder();
165   OpAsmParser::OperandType cond;
166   if (parser.parseOperand(cond) ||
167       parser.resolveOperand(cond, builder.getType<WitnessType>(),
168                             result.operands))
169     return failure();
170 
171   // Parse optional results type list.
172   if (parser.parseOptionalArrowTypeList(result.types))
173     return failure();
174 
175   // Parse the region and add a terminator if elided.
176   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
177     return failure();
178   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
179 
180   // Parse the optional attribute list.
181   if (parser.parseOptionalAttrDict(result.attributes))
182     return failure();
183   return success();
184 }
185 
186 static void print(OpAsmPrinter &p, AssumingOp op) {
187   bool yieldsResults = !op.results().empty();
188 
189   p << AssumingOp::getOperationName() << " " << op.witness();
190   if (yieldsResults) {
191     p << " -> (" << op.getResultTypes() << ")";
192   }
193   p.printRegion(op.doRegion(),
194                 /*printEntryBlockArgs=*/false,
195                 /*printBlockTerminators=*/yieldsResults);
196   p.printOptionalAttrDict(op.getAttrs());
197 }
198 
199 namespace {
200 // Removes AssumingOp with a passing witness and inlines the region.
201 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
202   using OpRewritePattern<AssumingOp>::OpRewritePattern;
203 
204   LogicalResult matchAndRewrite(AssumingOp op,
205                                 PatternRewriter &rewriter) const override {
206     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
207     if (!witness || !witness.passingAttr())
208       return failure();
209 
210     AssumingOp::inlineRegionIntoParent(op, rewriter);
211     return success();
212   }
213 };
214 } // namespace
215 
216 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
217                                              MLIRContext *context) {
218   // If taking a passing witness, inline region.
219   patterns.insert<AssumingWithTrue>(context);
220 }
221 
/// Replaces `op` with the contents of its body region, spliced into the
/// enclosing block, and replaces the op's results with the operands of the
/// body's terminator (the AssumingYieldOp), which is then erased.
///
/// NOTE(review): this splits the block at the rewriter's *current* insertion
/// point, so it assumes that point is positioned at `op` (as when called from
/// a pattern's matchAndRewrite) — confirm at any new call site. It also only
/// merges `op.getBody()`, i.e. it presumes a single-block body region.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the current block at the insertion point; everything from that
  // point onward moves into `blockAfterAssuming`.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  // Move the body block in between the two halves of the split block.
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  // Uses of the AssumingOp's results now refer to the yielded values.
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
241 
242 //===----------------------------------------------------------------------===//
243 // AssumingAllOp
244 //===----------------------------------------------------------------------===//
245 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
246   // Iterate in reverse to first handle all constant operands. They are
247   // guaranteed to be the tail of the inputs because this is commutative.
248   for (int idx = operands.size() - 1; idx >= 0; idx--) {
249     Attribute a = operands[idx];
250     // Cannot fold if any inputs are not constant;
251     if (!a)
252       return nullptr;
253 
254     // We do not need to keep statically known values after handling them in
255     // this method.
256     getOperation()->eraseOperand(idx);
257 
258     // Always false if any input is statically known false
259     if (!a.cast<BoolAttr>().getValue())
260       return a;
261   }
262   // If this is reached, all inputs were statically known passing.
263   return BoolAttr::get(true, getContext());
264 }
265 
266 static LogicalResult verify(AssumingAllOp op) {
267   // Ensure that AssumingAllOp contains at least one operand
268   if (op.getNumOperands() == 0)
269     return op.emitOpError("no operands specified");
270 
271   return success();
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // BroadcastOp
276 //===----------------------------------------------------------------------===//
277 
278 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
279   if (!operands[1])
280     return nullptr;
281 
282   auto rhsShape = llvm::to_vector<6>(
283       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
284   if (rhsShape.empty())
285     return lhs();
286 
287   if (!operands[0])
288     return nullptr;
289 
290   auto lhsShape = llvm::to_vector<6>(
291       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
292   if (lhsShape.empty())
293     return rhs();
294 
295   SmallVector<int64_t, 6> resultShape;
296   // If the shapes are not compatible, we can't fold it.
297   // TODO: Fold to an "error".
298   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
299     return nullptr;
300   Builder builder(getContext());
301   return builder.getIndexTensorAttr(resultShape);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ConcatOp
306 //===----------------------------------------------------------------------===//
307 
308 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
309   if (!operands[0] || !operands[1])
310     return nullptr;
311   auto lhsShape = llvm::to_vector<6>(
312       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
313   auto rhsShape = llvm::to_vector<6>(
314       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
315   SmallVector<int64_t, 6> resultShape;
316   resultShape.append(lhsShape.begin(), lhsShape.end());
317   resultShape.append(rhsShape.begin(), rhsShape.end());
318   Builder builder(getContext());
319   return builder.getIndexTensorAttr(resultShape);
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // ConstShapeOp
324 //===----------------------------------------------------------------------===//
325 
326 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
327   p << "shape.const_shape ";
328   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
329   p << "[";
330   interleaveComma(op.shape().getValues<int64_t>(), p,
331                   [&](int64_t i) { p << i; });
332   p << "] : ";
333   p.printType(op.getType());
334 }
335 
336 static ParseResult parseConstShapeOp(OpAsmParser &parser,
337                                      OperationState &result) {
338   if (parser.parseOptionalAttrDict(result.attributes))
339     return failure();
340   // We piggy-back on ArrayAttr parsing, though we don't internally store the
341   // shape as an ArrayAttr.
342   // TODO: Implement custom parser and maybe make syntax a bit more concise.
343   Attribute extentsRaw;
344   NamedAttrList dummy;
345   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
346     return failure();
347   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
348   if (!extentsArray)
349     return failure();
350   SmallVector<int64_t, 6> ints;
351   for (Attribute extent : extentsArray) {
352     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
353     if (!attr)
354       return failure();
355     ints.push_back(attr.getInt());
356   }
357   Builder &builder = parser.getBuilder();
358   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
359   Type resultTy;
360   if (parser.parseColonType(resultTy))
361     return failure();
362   result.types.push_back(resultTy);
363   return success();
364 }
365 
366 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
367 
368 //===----------------------------------------------------------------------===//
369 // CstrBroadcastableOp
370 //===----------------------------------------------------------------------===//
371 
372 namespace {
373 // Given an input shape Value, try to obtain the shape's values.
374 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
375   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
376     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
377     if (!type.hasRank())
378       return failure();
379     shapeValues = llvm::to_vector<6>(type.getShape());
380     return success();
381   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
382     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
383     return success();
384   } else {
385     return failure();
386   }
387 }
388 
389 // For shapes that were created by some operations, we can obtain partial
390 // information on the shapes and sometimes determine if they will be
391 // broadcastable with that.
392 struct CstrBroadcastablePartialInfo
393     : public OpRewritePattern<CstrBroadcastableOp> {
394   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
395 
396   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
397                                 PatternRewriter &rewriter) const override {
398     SmallVector<int64_t, 6> lhsShape, rhsShape;
399     if (failed(getShapeVec(op.lhs(), lhsShape)))
400       return failure();
401     if (failed(getShapeVec(op.rhs(), rhsShape)))
402       return failure();
403     if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
404       return failure();
405 
406     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
407     return success();
408   }
409 };
410 
411 // Scalars are always broadcastable.
412 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
413   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
414 
415   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
416                                 PatternRewriter &rewriter) const override {
417     SmallVector<int64_t, 6> shape;
418     if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
419       return failure();
420     if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
421       return failure();
422 
423     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
424     return success();
425   }
426 };
427 
428 } // namespace
429 
430 void CstrBroadcastableOp::getCanonicalizationPatterns(
431     OwningRewritePatternList &patterns, MLIRContext *context) {
432   // Canonicalization patterns have overlap with the considerations during
433   // folding in case additional shape information is inferred at some point that
434   // does not result in folding.
435   patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
436                   CstrBroadcastableScalar>(context);
437 }
438 
439 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
440   // Both operands are not needed if one is a scalar.
441   if (operands[0] &&
442       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
443     return BoolAttr::get(true, getContext());
444   if (operands[1] &&
445       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
446     return BoolAttr::get(true, getContext());
447 
448   if (operands[0] && operands[1]) {
449     auto lhsShape = llvm::to_vector<6>(
450         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
451     auto rhsShape = llvm::to_vector<6>(
452         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
453     SmallVector<int64_t, 6> resultShape;
454     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
455       return BoolAttr::get(true, getContext());
456   }
457 
458   // Lastly, see if folding can be completed based on what constraints are known
459   // on the input shapes.
460   SmallVector<int64_t, 6> lhsShape, rhsShape;
461   if (failed(getShapeVec(lhs(), lhsShape)))
462     return nullptr;
463   if (failed(getShapeVec(rhs(), rhsShape)))
464     return nullptr;
465 
466   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
467     return BoolAttr::get(true, getContext());
468 
469   // Because a failing witness result here represents an eventual assertion
470   // failure, we do not replace it with a constant witness.
471   return nullptr;
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // CstrEqOp
476 //===----------------------------------------------------------------------===//
477 
478 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
479                                            MLIRContext *context) {
480   // If inputs are equal, return passing witness
481   patterns.insert<CstrEqEqOps>(context);
482 }
483 
484 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
485   if (llvm::all_of(operands,
486                    [&](Attribute a) { return a && a == operands[0]; }))
487     return BoolAttr::get(true, getContext());
488 
489   // Because a failing witness result here represents an eventual assertion
490   // failure, we do not try to replace it with a constant witness. Similarly, we
491   // cannot if there are any non-const inputs.
492   return nullptr;
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // ConstSizeOp
497 //===----------------------------------------------------------------------===//
498 
499 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
500                         int64_t value) {
501   build(builder, result, builder.getIndexAttr(value));
502 }
503 
504 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
505 
506 void ConstSizeOp::getAsmResultNames(
507     llvm::function_ref<void(Value, StringRef)> setNameFn) {
508   SmallString<4> buffer;
509   llvm::raw_svector_ostream os(buffer);
510   os << "c" << value();
511   setNameFn(getResult(), os.str());
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // ConstWitnessOp
516 //===----------------------------------------------------------------------===//
517 
518 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
519 
520 //===----------------------------------------------------------------------===//
521 // ShapeEqOp
522 //===----------------------------------------------------------------------===//
523 
524 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
525   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
526   if (lhs == nullptr)
527     return {};
528   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
529   if (rhs == nullptr)
530     return {};
531   return BoolAttr::get(lhs == rhs, getContext());
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // IndexToSizeOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
539   // Constant values of both types, `shape.size` and `index`, are represented as
540   // `IntegerAttr`s which makes constant folding simple.
541   if (Attribute arg = operands[0])
542     return arg;
543   return {};
544 }
545 
546 void IndexToSizeOp::getCanonicalizationPatterns(
547     OwningRewritePatternList &patterns, MLIRContext *context) {
548   patterns.insert<SizeToIndexToSizeCanonicalization>(context);
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // FromExtentsOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
556   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
557     return nullptr;
558   SmallVector<int64_t, 6> extents;
559   for (auto attr : operands)
560     extents.push_back(attr.cast<IntegerAttr>().getInt());
561   Builder builder(getContext());
562   return builder.getIndexTensorAttr(extents);
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // GetExtentOp
567 //===----------------------------------------------------------------------===//
568 
569 Optional<int64_t> GetExtentOp::getConstantDim() {
570   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
571     return constSizeOp.value().getLimitedValue();
572   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
573     return constantOp.value().cast<IntegerAttr>().getInt();
574   return llvm::None;
575 }
576 
577 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
578   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
579   if (!elements)
580     return nullptr;
581   Optional<int64_t> dim = getConstantDim();
582   if (!dim.hasValue())
583     return nullptr;
584   if (dim.getValue() >= elements.getNumElements())
585     return nullptr;
586   return elements.getValue({(uint64_t)dim.getValue()});
587 }
588 
589 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
590                         int64_t dim) {
591   auto loc = result.location;
592   auto dimAttr = builder.getIndexAttr(dim);
593   if (shape.getType().isa<ShapeType>()) {
594     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
595     build(builder, result, builder.getType<SizeType>(), shape, dim);
596   } else {
597     Value dim =
598         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
599     build(builder, result, builder.getIndexType(), shape, dim);
600   }
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // RankOp
605 //===----------------------------------------------------------------------===//
606 
607 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
608   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
609   if (!shape)
610     return {};
611   int64_t rank = shape.getNumElements();
612   Builder builder(getContext());
613   return builder.getIndexAttr(rank);
614 }
615 
616 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
617 /// Constant folding fails in cases where only the rank is constant, not the
618 /// shape itself.
619 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
620 ///
621 /// Example:
622 ///
623 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
624 /// %rank = shape.rank %shape
625 ///
626 /// becomes
627 ///
628 /// %rank = shape.const_size 3
629 
630 namespace {
631 struct RankShapeOfCanonicalizationPattern
632     : public OpRewritePattern<shape::RankOp> {
633   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
634 
635   LogicalResult matchAndRewrite(shape::RankOp op,
636                                 PatternRewriter &rewriter) const override {
637     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
638     if (!shapeOfOp)
639       return failure();
640     auto rankedTensorType =
641         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
642     if (!rankedTensorType)
643       return failure();
644     assert(op.getType().isa<IndexType>() &&
645            "expected `rank(shape_of( ... )]` based on a shaped argument to "
646            "yield an index type");
647     int64_t rank = rankedTensorType.getRank();
648     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
649     return success();
650   }
651 };
652 } // namespace
653 
654 void shape::RankOp::getCanonicalizationPatterns(
655     OwningRewritePatternList &patterns, MLIRContext *context) {
656   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // NumElementsOp
661 //===----------------------------------------------------------------------===//
662 
663 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
664 
665   // Fold only when argument constant.
666   Attribute shape = operands[0];
667   if (!shape)
668     return {};
669 
670   APInt product(64, 1);
671   for (auto value : shape.cast<DenseIntElementsAttr>())
672     product *= value;
673   Builder builder(getContext());
674   return builder.getIndexAttr(product.getLimitedValue());
675 }
676 
677 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
678                           Value shape) {
679   if (shape.getType().isa<ShapedType>()) {
680     auto type = builder.getIndexType();
681     return build(builder, result, type, shape);
682   }
683   auto type = SizeType::get(builder.getContext());
684   return build(builder, result, type, shape);
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // MulOp
689 //===----------------------------------------------------------------------===//
690 
691 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
692   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
693   if (!lhs)
694     return nullptr;
695   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
696   if (!rhs)
697     return nullptr;
698   APInt folded = lhs.getValue() * rhs.getValue();
699   Type indexTy = IndexType::get(getContext());
700   return IntegerAttr::get(indexTy, folded);
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // ShapeOfOp
705 //===----------------------------------------------------------------------===//
706 
707 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
708   auto type = getOperand().getType().dyn_cast<ShapedType>();
709   if (!type || !type.hasStaticShape())
710     return nullptr;
711   Builder builder(getContext());
712   return builder.getIndexTensorAttr(type.getShape());
713 }
714 
715 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
716   if (arg.getType().isa<ShapedType>()) {
717     auto type = RankedTensorType::get({ShapedType::kDynamicSize},
718                                       builder.getIndexType());
719     return ShapeOfOp::build(builder, result, type, arg);
720   }
721   auto type = ShapeType::get(builder.getContext());
722   return ShapeOfOp::build(builder, result, type, arg);
723 }
724 
725 namespace {
726 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
727   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
728 
729   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
730                                 PatternRewriter &rewriter) const override {
731     if (!op.arg().getType().isa<ShapedType>())
732       return failure();
733     if (op.getType().isa<ShapedType>())
734       return failure();
735 
736     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
737     return success();
738   }
739 };
740 } // namespace
741 
742 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
743                                             MLIRContext *context) {
744   patterns.insert<ShapeOfWithTensor>(context);
745 }
746 
747 //===----------------------------------------------------------------------===//
748 // SizeToIndexOp
749 //===----------------------------------------------------------------------===//
750 
751 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
752   // Constant values of both types, `shape.size` and `index`, are represented as
753   // `IntegerAttr`s which makes constant folding simple.
754   if (Attribute arg = operands[0])
755     return arg;
756   return impl::foldCastOp(*this);
757 }
758 
759 void SizeToIndexOp::getCanonicalizationPatterns(
760     OwningRewritePatternList &patterns, MLIRContext *context) {
761   patterns.insert<IndexToSizeToIndexCanonicalization>(context);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // YieldOp
766 //===----------------------------------------------------------------------===//
767 
768 static LogicalResult verify(YieldOp op) {
769   auto *parentOp = op.getParentOp();
770   auto results = parentOp->getResults();
771   auto operands = op.getOperands();
772 
773   if (parentOp->getNumResults() != op.getNumOperands())
774     return op.emitOpError() << "number of operands does not match number of "
775                                "results of its parent";
776   for (auto e : llvm::zip(results, operands))
777     if (std::get<0>(e).getType() != std::get<1>(e).getType())
778       return op.emitOpError()
779              << "types mismatch between yield op and its parent";
780 
781   return success();
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // SplitAtOp
786 //===----------------------------------------------------------------------===//
787 
788 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
789                               SmallVectorImpl<OpFoldResult> &results) {
790   if (!operands[0] || !operands[1])
791     return failure();
792   auto shapeVec = llvm::to_vector<6>(
793       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
794   auto shape = llvm::makeArrayRef(shapeVec);
795   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
796   // Verify that the split point is in the correct range.
797   // TODO: Constant fold to an "error".
798   int64_t rank = shape.size();
799   if (!(-rank <= splitPoint && splitPoint <= rank))
800     return failure();
801   if (splitPoint < 0)
802     splitPoint += shape.size();
803   Builder builder(operands[0].getContext());
804   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
805   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
806   return success();
807 }
808 
809 //===----------------------------------------------------------------------===//
810 // ToExtentTensorOp
811 //===----------------------------------------------------------------------===//
812 
813 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
814   if (!operands[0])
815     return impl::foldCastOp(*this);
816   Builder builder(getContext());
817   auto shape = llvm::to_vector<6>(
818       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
819   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
820                                     builder.getIndexType());
821   return DenseIntElementsAttr::get(type, shape);
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // ReduceOp
826 //===----------------------------------------------------------------------===//
827 
828 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
829                      ValueRange initVals) {
830   result.addOperands(shape);
831   result.addOperands(initVals);
832 
833   Region *bodyRegion = result.addRegion();
834   bodyRegion->push_back(new Block);
835   Block &bodyBlock = bodyRegion->front();
836   bodyBlock.addArgument(builder.getIndexType());
837 
838   Type elementType;
839   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
840     elementType = tensorType.getElementType();
841   else
842     elementType = SizeType::get(builder.getContext());
843   bodyBlock.addArgument(elementType);
844 
845   for (Type initValType : initVals.getTypes()) {
846     bodyBlock.addArgument(initValType);
847     result.addTypes(initValType);
848   }
849 }
850 
851 static LogicalResult verify(ReduceOp op) {
852   // Verify block arg types.
853   Block &block = op.region().front();
854 
855   // The block takes index, extent, and aggregated values as arguments.
856   auto blockArgsCount = op.initVals().size() + 2;
857   if (block.getNumArguments() != blockArgsCount)
858     return op.emitOpError() << "ReduceOp body is expected to have "
859                             << blockArgsCount << " arguments";
860 
861   // The first block argument is the index and must always be of type `index`.
862   if (!block.getArgument(0).getType().isa<IndexType>())
863     return op.emitOpError(
864         "argument 0 of ReduceOp body is expected to be of IndexType");
865 
866   // The second block argument is the extent and must be of type `size` or
867   // `index`, depending on whether the reduce operation is applied to a shape or
868   // to an extent tensor.
869   Type extentTy = block.getArgument(1).getType();
870   if (op.shape().getType().isa<ShapeType>()) {
871     if (!extentTy.isa<SizeType>())
872       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
873                             "SizeType if the ReduceOp operates on a ShapeType");
874   } else {
875     if (!extentTy.isa<IndexType>())
876       return op.emitOpError(
877           "argument 1 of ReduceOp body is expected to be of IndexType if the "
878           "ReduceOp operates on an extent tensor");
879   }
880 
881   for (auto type : llvm::enumerate(op.initVals()))
882     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
883       return op.emitOpError()
884              << "type mismatch between argument " << type.index() + 2
885              << " of ReduceOp body and initial value " << type.index();
886   return success();
887 }
888 
889 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
890   // Parse operands.
891   SmallVector<OpAsmParser::OperandType, 3> operands;
892   Type shapeOrExtentTensorType;
893   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
894                               OpAsmParser::Delimiter::Paren) ||
895       parser.parseColonType(shapeOrExtentTensorType) ||
896       parser.parseOptionalArrowTypeList(result.types))
897     return failure();
898 
899   // Resolve operands.
900   auto initVals = llvm::makeArrayRef(operands).drop_front();
901   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
902                             result.operands) ||
903       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
904                              result.operands))
905     return failure();
906 
907   // Parse the body.
908   Region *body = result.addRegion();
909   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
910     return failure();
911 
912   // Parse attributes.
913   if (parser.parseOptionalAttrDict(result.attributes))
914     return failure();
915 
916   return success();
917 }
918 
919 static void print(OpAsmPrinter &p, ReduceOp op) {
920   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
921     << ") : " << op.shape().getType();
922   p.printOptionalArrowTypeList(op.getResultTypes());
923   p.printRegion(op.region());
924   p.printOptionalAttrDict(op.getAttrs());
925 }
926 
namespace mlir {
namespace shape {

// Expand the ODS-generated op class definitions for this dialect.
// `GET_OP_CLASSES` selects the class-definition section of the generated
// ShapeOps.cpp.inc. That code refers to the static `verify`/`print`/`parse*Op`
// hooks defined above in this file, so this include must stay after them.
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

} // namespace shape
} // namespace mlir
935