1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 using namespace mlir::shape;
22 
23 namespace {
24 #include "ShapeCanonicalization.inc"
25 }
26 
27 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
28   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
29 }
30 
31 static bool isErrorPropagationPossible(TypeRange operandTypes) {
32   for (Type ty : operandTypes)
33     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
34       return true;
35   return false;
36 }
37 
38 static LogicalResult verifySizeOrIndexOp(Operation *op) {
39   assert(op != nullptr && op->getNumResults() == 1);
40   Type resultTy = op->getResultTypes().front();
41   if (isErrorPropagationPossible(op->getOperandTypes())) {
42     if (!resultTy.isa<SizeType>())
43       return op->emitOpError()
44              << "if at least one of the operands can hold error values then "
45                 "the result must be of type `size` to propagate them";
46   }
47   return success();
48 }
49 
50 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
51   assert(op != nullptr && op->getNumResults() == 1);
52   Type resultTy = op->getResultTypes().front();
53   if (isErrorPropagationPossible(op->getOperandTypes())) {
54     if (!resultTy.isa<ShapeType>())
55       return op->emitOpError()
56              << "if at least one of the operands can hold error values then "
57                 "the result must be of type `shape` to propagate them";
58   }
59   return success();
60 }
61 
/// Registers the shape dialect's operations and types.
void ShapeDialect::initialize() {
  // Register all ops generated from the TableGen definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
           WitnessType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
74 
/// Materializes a single constant operation for `value` of the given `type`,
/// or returns nullptr if this dialect cannot materialize it.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Constant shapes are materialized as `const_shape`, both for `!shape.shape`
  // and for the `tensor<?xindex>` extent tensor representation.
  if (type.isa<ShapeType>() ||
      type == getExtentTensorType(builder.getContext()))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Plain index constants are delegated to the standard dialect constant op.
  if (type.isa<IndexType>())
    return builder.create<ConstantOp>(loc, type, value);
  return nullptr;
}
90 
91 /// Parse a type registered to this dialect.
92 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
93   StringRef keyword;
94   if (parser.parseKeyword(&keyword))
95     return Type();
96 
97   if (keyword == "component")
98     return ComponentType::get(getContext());
99   if (keyword == "element")
100     return ElementType::get(getContext());
101   if (keyword == "shape")
102     return ShapeType::get(getContext());
103   if (keyword == "size")
104     return SizeType::get(getContext());
105   if (keyword == "value_shape")
106     return ValueShapeType::get(getContext());
107   if (keyword == "witness")
108     return WitnessType::get(getContext());
109 
110   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
111   return Type();
112 }
113 
114 /// Print a type registered to this dialect.
115 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
116   switch (type.getKind()) {
117   case ShapeTypes::Component:
118     os << "component";
119     return;
120   case ShapeTypes::Element:
121     os << "element";
122     return;
123   case ShapeTypes::Size:
124     os << "size";
125     return;
126   case ShapeTypes::Shape:
127     os << "shape";
128     return;
129   case ShapeTypes::ValueShape:
130     os << "value_shape";
131     return;
132   case ShapeTypes::Witness:
133     os << "witness";
134     return;
135   default:
136     llvm_unreachable("unexpected 'shape' type kind");
137   }
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // AnyOp
142 //===----------------------------------------------------------------------===//
143 
144 // TODO: Canonicalization should be implemented for shapes that can be
145 // determined through mixtures of the known dimensions of the inputs.
146 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
147   // Only the last operand is checked because AnyOp is commutative.
148   if (operands.back())
149     return operands.back();
150 
151   return nullptr;
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // AssumingOp
156 //===----------------------------------------------------------------------===//
157 
/// Parses an assuming op with the syntax:
///   `shape.assuming` witness (`->` type-list)? region attr-dict?
/// The region's terminator is inserted automatically when elided.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is the witness; it is always of witness type.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
184 
185 static void print(OpAsmPrinter &p, AssumingOp op) {
186   bool yieldsResults = !op.results().empty();
187 
188   p << AssumingOp::getOperationName() << " " << op.witness();
189   if (yieldsResults) {
190     p << " -> (" << op.getResultTypes() << ")";
191   }
192   p.printRegion(op.doRegion(),
193                 /*printEntryBlockArgs=*/false,
194                 /*printBlockTerminators=*/yieldsResults);
195   p.printOptionalAttrDict(op.getAttrs());
196 }
197 
198 namespace {
199 // Removes AssumingOp with a passing witness and inlines the region.
200 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
201   using OpRewritePattern<AssumingOp>::OpRewritePattern;
202 
203   LogicalResult matchAndRewrite(AssumingOp op,
204                                 PatternRewriter &rewriter) const override {
205     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
206     if (!witness || !witness.passingAttr())
207       return failure();
208 
209     AssumingOp::inlineRegionIntoParent(op, rewriter);
210     return success();
211   }
212 };
213 } // namespace
214 
215 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
216                                              MLIRContext *context) {
217   // If taking a passing witness, inline region.
218   patterns.insert<AssumingWithTrue>(context);
219 }
220 
/// Replaces `op` by the contents of its region (used by canonicalization when
/// the witness is known to pass). The parent block is split at the op, the
/// region's body is spliced in between, and all blocks are merged back into
/// one.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp; the op's results are replaced
  // by the yield's operands.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
240 
241 //===----------------------------------------------------------------------===//
242 // AssumingAllOp
243 //===----------------------------------------------------------------------===//
/// Folds `assuming_all` over its statically known witness operands. Note that
/// this fold mutates the op: constant operands are erased as they are
/// consumed, even when the fold ultimately returns nullptr.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
264 
265 static LogicalResult verify(AssumingAllOp op) {
266   // Ensure that AssumingAllOp contains at least one operand
267   if (op.getNumOperands() == 0)
268     return op.emitOpError("no operands specified");
269 
270   return success();
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // BroadcastOp
275 //===----------------------------------------------------------------------===//
276 
277 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
278   if (!operands[1])
279     return nullptr;
280 
281   auto rhsShape = llvm::to_vector<6>(
282       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
283   if (rhsShape.empty())
284     return lhs();
285 
286   if (!operands[0])
287     return nullptr;
288 
289   auto lhsShape = llvm::to_vector<6>(
290       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
291   if (lhsShape.empty())
292     return rhs();
293 
294   SmallVector<int64_t, 6> resultShape;
295   // If the shapes are not compatible, we can't fold it.
296   // TODO: Fold to an "error".
297   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
298     return nullptr;
299   Builder builder(getContext());
300   return builder.getIndexTensorAttr(resultShape);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // ConcatOp
305 //===----------------------------------------------------------------------===//
306 
307 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
308   if (!operands[0] || !operands[1])
309     return nullptr;
310   auto lhsShape = llvm::to_vector<6>(
311       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
312   auto rhsShape = llvm::to_vector<6>(
313       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
314   SmallVector<int64_t, 6> resultShape;
315   resultShape.append(lhsShape.begin(), lhsShape.end());
316   resultShape.append(rhsShape.begin(), rhsShape.end());
317   Builder builder(getContext());
318   return builder.getIndexTensorAttr(resultShape);
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // ConstShapeOp
323 //===----------------------------------------------------------------------===//
324 
325 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
326   p << "shape.const_shape ";
327   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
328   p << "[";
329   interleaveComma(op.shape().getValues<int64_t>(), p,
330                   [&](int64_t i) { p << i; });
331   p << "] : ";
332   p.printType(op.getType());
333 }
334 
/// Parses a const_shape op with the syntax:
///   `shape.const_shape` attr-dict `[` extent-list `]` `:` type
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each element of the array must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // The extents are stored on the op as a dense index tensor attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
364 
/// Folds to the op's "shape" attribute, which already holds the extents.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
366 
367 //===----------------------------------------------------------------------===//
368 // CstrBroadcastableOp
369 //===----------------------------------------------------------------------===//
370 
371 namespace {
372 // Given an input shape Value, try to obtain the shape's values.
373 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
374   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
375     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
376     if (!type.hasRank())
377       return failure();
378     shapeValues = llvm::to_vector<6>(type.getShape());
379     return success();
380   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
381     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
382     return success();
383   } else {
384     return failure();
385   }
386 }
387 
388 // For shapes that were created by some operations, we can obtain partial
389 // information on the shapes and sometimes determine if they will be
390 // broadcastable with that.
391 struct CstrBroadcastablePartialInfo
392     : public OpRewritePattern<CstrBroadcastableOp> {
393   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
394 
395   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
396                                 PatternRewriter &rewriter) const override {
397     SmallVector<int64_t, 6> lhsShape, rhsShape;
398     if (failed(getShapeVec(op.lhs(), lhsShape)))
399       return failure();
400     if (failed(getShapeVec(op.rhs(), rhsShape)))
401       return failure();
402     if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
403       return failure();
404 
405     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
406     return success();
407   }
408 };
409 
410 // Scalars are always broadcastable.
411 struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
412   using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
413 
414   LogicalResult matchAndRewrite(CstrBroadcastableOp op,
415                                 PatternRewriter &rewriter) const override {
416     SmallVector<int64_t, 6> shape;
417     if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
418       return failure();
419     if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
420       return failure();
421 
422     rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
423     return success();
424   }
425 };
426 
427 } // namespace
428 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
                  CstrBroadcastableScalar>(context);
}
437 
438 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
439   // Both operands are not needed if one is a scalar.
440   if (operands[0] &&
441       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
442     return BoolAttr::get(true, getContext());
443   if (operands[1] &&
444       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
445     return BoolAttr::get(true, getContext());
446 
447   if (operands[0] && operands[1]) {
448     auto lhsShape = llvm::to_vector<6>(
449         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
450     auto rhsShape = llvm::to_vector<6>(
451         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
452     SmallVector<int64_t, 6> resultShape;
453     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
454       return BoolAttr::get(true, getContext());
455   }
456 
457   // Lastly, see if folding can be completed based on what constraints are known
458   // on the input shapes.
459   SmallVector<int64_t, 6> lhsShape, rhsShape;
460   if (failed(getShapeVec(lhs(), lhsShape)))
461     return nullptr;
462   if (failed(getShapeVec(rhs(), rhsShape)))
463     return nullptr;
464 
465   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
466     return BoolAttr::get(true, getContext());
467 
468   // Because a failing witness result here represents an eventual assertion
469   // failure, we do not replace it with a constant witness.
470   return nullptr;
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // CstrEqOp
475 //===----------------------------------------------------------------------===//
476 
477 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
478                                            MLIRContext *context) {
479   // If inputs are equal, return passing witness
480   patterns.insert<CstrEqEqOps>(context);
481 }
482 
483 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
484   if (llvm::all_of(operands,
485                    [&](Attribute a) { return a && a == operands[0]; }))
486     return BoolAttr::get(true, getContext());
487 
488   // Because a failing witness result here represents an eventual assertion
489   // failure, we do not try to replace it with a constant witness. Similarly, we
490   // cannot if there are any non-const inputs.
491   return nullptr;
492 }
493 
494 //===----------------------------------------------------------------------===//
495 // ConstSizeOp
496 //===----------------------------------------------------------------------===//
497 
/// Builds a const_size op from a raw integer, wrapping it in an index attr.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
502 
/// Folds to the op's constant "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
504 
505 void ConstSizeOp::getAsmResultNames(
506     llvm::function_ref<void(Value, StringRef)> setNameFn) {
507   SmallString<4> buffer;
508   llvm::raw_svector_ostream os(buffer);
509   os << "c" << value();
510   setNameFn(getResult(), os.str());
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // ConstWitnessOp
515 //===----------------------------------------------------------------------===//
516 
/// Folds to the op's constant "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
518 
519 //===----------------------------------------------------------------------===//
520 // ShapeEqOp
521 //===----------------------------------------------------------------------===//
522 
523 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
524   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
525   if (lhs == nullptr)
526     return {};
527   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
528   if (rhs == nullptr)
529     return {};
530   return BoolAttr::get(lhs == rhs, getContext());
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // IndexToSizeOp
535 //===----------------------------------------------------------------------===//
536 
537 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
538   // Constant values of both types, `shape.size` and `index`, are represented as
539   // `IntegerAttr`s which makes constant folding simple.
540   if (Attribute arg = operands[0])
541     return arg;
542   return {};
543 }
544 
545 void IndexToSizeOp::getCanonicalizationPatterns(
546     OwningRewritePatternList &patterns, MLIRContext *context) {
547   patterns.insert<SizeToIndexToSizeCanonicalization>(context);
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // FromExtentsOp
552 //===----------------------------------------------------------------------===//
553 
554 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
555   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
556     return nullptr;
557   SmallVector<int64_t, 6> extents;
558   for (auto attr : operands)
559     extents.push_back(attr.cast<IntegerAttr>().getInt());
560   Builder builder(getContext());
561   return builder.getIndexTensorAttr(extents);
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // GetExtentOp
566 //===----------------------------------------------------------------------===//
567 
568 Optional<int64_t> GetExtentOp::getConstantDim() {
569   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
570     return constSizeOp.value().getLimitedValue();
571   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
572     return constantOp.value().cast<IntegerAttr>().getInt();
573   return llvm::None;
574 }
575 
576 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
577   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
578   if (!elements)
579     return nullptr;
580   Optional<int64_t> dim = getConstantDim();
581   if (!dim.hasValue())
582     return nullptr;
583   if (dim.getValue() >= elements.getNumElements())
584     return nullptr;
585   return elements.getValue({(uint64_t)dim.getValue()});
586 }
587 
588 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
589                         int64_t dim) {
590   auto loc = result.location;
591   auto dimAttr = builder.getIndexAttr(dim);
592   if (shape.getType().isa<ShapeType>()) {
593     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
594     build(builder, result, builder.getType<SizeType>(), shape, dim);
595   } else {
596     Value dim =
597         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
598     build(builder, result, builder.getIndexType(), shape, dim);
599   }
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // RankOp
604 //===----------------------------------------------------------------------===//
605 
606 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
607   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
608   if (!shape)
609     return {};
610   int64_t rank = shape.getNumElements();
611   Builder builder(getContext());
612   return builder.getIndexAttr(rank);
613 }
614 
615 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
616 /// Constant folding fails in cases where only the rank is constant, not the
617 /// shape itself.
618 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
619 ///
620 /// Example:
621 ///
622 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
623 /// %rank = shape.rank %shape
624 ///
625 /// becomes
626 ///
627 /// %rank = shape.const_size 3
628 
629 namespace {
630 struct RankShapeOfCanonicalizationPattern
631     : public OpRewritePattern<shape::RankOp> {
632   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
633 
634   LogicalResult matchAndRewrite(shape::RankOp op,
635                                 PatternRewriter &rewriter) const override {
636     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
637     if (!shapeOfOp)
638       return failure();
639     auto rankedTensorType =
640         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
641     if (!rankedTensorType)
642       return failure();
643     assert(op.getType().isa<IndexType>() &&
644            "expected `rank(shape_of( ... )]` based on a shaped argument to "
645            "yield an index type");
646     int64_t rank = rankedTensorType.getRank();
647     rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
648     return success();
649   }
650 };
651 } // namespace
652 
653 void shape::RankOp::getCanonicalizationPatterns(
654     OwningRewritePatternList &patterns, MLIRContext *context) {
655   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // NumElementsOp
660 //===----------------------------------------------------------------------===//
661 
662 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
663 
664   // Fold only when argument constant.
665   Attribute shape = operands[0];
666   if (!shape)
667     return {};
668 
669   APInt product(64, 1);
670   for (auto value : shape.cast<DenseIntElementsAttr>())
671     product *= value;
672   Builder builder(getContext());
673   return builder.getIndexAttr(product.getLimitedValue());
674 }
675 
676 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
677                           Value shape) {
678   if (shape.getType().isa<ShapedType>()) {
679     auto type = builder.getIndexType();
680     return build(builder, result, type, shape);
681   }
682   auto type = SizeType::get(builder.getContext());
683   return build(builder, result, type, shape);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // MulOp
688 //===----------------------------------------------------------------------===//
689 
690 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
691   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
692   if (!lhs)
693     return nullptr;
694   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
695   if (!rhs)
696     return nullptr;
697   APInt folded = lhs.getValue() * rhs.getValue();
698   Type indexTy = IndexType::get(getContext());
699   return IntegerAttr::get(indexTy, folded);
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // ShapeOfOp
704 //===----------------------------------------------------------------------===//
705 
706 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
707   auto type = getOperand().getType().dyn_cast<ShapedType>();
708   if (!type || !type.hasStaticShape())
709     return nullptr;
710   Builder builder(getContext());
711   return builder.getIndexTensorAttr(type.getShape());
712 }
713 
714 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
715   Type type = arg.getType().isa<ShapedType>()
716                   ? (Type)getExtentTensorType(builder.getContext())
717                   : (Type)builder.getType<ShapeType>();
718   return ShapeOfOp::build(builder, result, type, arg);
719 }
720 
721 namespace {
722 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
723   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
724 
725   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
726                                 PatternRewriter &rewriter) const override {
727     if (!op.arg().getType().isa<ShapedType>())
728       return failure();
729     if (op.getType().isa<ShapedType>())
730       return failure();
731 
732     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
733     return success();
734   }
735 };
736 } // namespace
737 
738 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
739                                             MLIRContext *context) {
740   patterns.insert<ShapeOfWithTensor>(context);
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // SizeToIndexOp
745 //===----------------------------------------------------------------------===//
746 
747 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
748   // Constant values of both types, `shape.size` and `index`, are represented as
749   // `IntegerAttr`s which makes constant folding simple.
750   if (Attribute arg = operands[0])
751     return arg;
752   return impl::foldCastOp(*this);
753 }
754 
755 void SizeToIndexOp::getCanonicalizationPatterns(
756     OwningRewritePatternList &patterns, MLIRContext *context) {
757   patterns.insert<IndexToSizeToIndexCanonicalization>(context);
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // YieldOp
762 //===----------------------------------------------------------------------===//
763 
764 static LogicalResult verify(YieldOp op) {
765   auto *parentOp = op.getParentOp();
766   auto results = parentOp->getResults();
767   auto operands = op.getOperands();
768 
769   if (parentOp->getNumResults() != op.getNumOperands())
770     return op.emitOpError() << "number of operands does not match number of "
771                                "results of its parent";
772   for (auto e : llvm::zip(results, operands))
773     if (std::get<0>(e).getType() != std::get<1>(e).getType())
774       return op.emitOpError()
775              << "types mismatch between yield op and its parent";
776 
777   return success();
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // SplitAtOp
782 //===----------------------------------------------------------------------===//
783 
784 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
785                               SmallVectorImpl<OpFoldResult> &results) {
786   if (!operands[0] || !operands[1])
787     return failure();
788   auto shapeVec = llvm::to_vector<6>(
789       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
790   auto shape = llvm::makeArrayRef(shapeVec);
791   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
792   // Verify that the split point is in the correct range.
793   // TODO: Constant fold to an "error".
794   int64_t rank = shape.size();
795   if (!(-rank <= splitPoint && splitPoint <= rank))
796     return failure();
797   if (splitPoint < 0)
798     splitPoint += shape.size();
799   Builder builder(operands[0].getContext());
800   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
801   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
802   return success();
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // ToExtentTensorOp
807 //===----------------------------------------------------------------------===//
808 
809 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
810   if (!operands[0])
811     return impl::foldCastOp(*this);
812   Builder builder(getContext());
813   auto shape = llvm::to_vector<6>(
814       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
815   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
816                                     builder.getIndexType());
817   return DenseIntElementsAttr::get(type, shape);
818 }
819 
820 //===----------------------------------------------------------------------===//
821 // ReduceOp
822 //===----------------------------------------------------------------------===//
823 
824 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
825                      ValueRange initVals) {
826   result.addOperands(shape);
827   result.addOperands(initVals);
828 
829   Region *bodyRegion = result.addRegion();
830   bodyRegion->push_back(new Block);
831   Block &bodyBlock = bodyRegion->front();
832   bodyBlock.addArgument(builder.getIndexType());
833 
834   Type elementType;
835   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
836     elementType = tensorType.getElementType();
837   else
838     elementType = SizeType::get(builder.getContext());
839   bodyBlock.addArgument(elementType);
840 
841   for (Type initValType : initVals.getTypes()) {
842     bodyBlock.addArgument(initValType);
843     result.addTypes(initValType);
844   }
845 }
846 
847 static LogicalResult verify(ReduceOp op) {
848   // Verify block arg types.
849   Block &block = op.region().front();
850 
851   // The block takes index, extent, and aggregated values as arguments.
852   auto blockArgsCount = op.initVals().size() + 2;
853   if (block.getNumArguments() != blockArgsCount)
854     return op.emitOpError() << "ReduceOp body is expected to have "
855                             << blockArgsCount << " arguments";
856 
857   // The first block argument is the index and must always be of type `index`.
858   if (!block.getArgument(0).getType().isa<IndexType>())
859     return op.emitOpError(
860         "argument 0 of ReduceOp body is expected to be of IndexType");
861 
862   // The second block argument is the extent and must be of type `size` or
863   // `index`, depending on whether the reduce operation is applied to a shape or
864   // to an extent tensor.
865   Type extentTy = block.getArgument(1).getType();
866   if (op.shape().getType().isa<ShapeType>()) {
867     if (!extentTy.isa<SizeType>())
868       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
869                             "SizeType if the ReduceOp operates on a ShapeType");
870   } else {
871     if (!extentTy.isa<IndexType>())
872       return op.emitOpError(
873           "argument 1 of ReduceOp body is expected to be of IndexType if the "
874           "ReduceOp operates on an extent tensor");
875   }
876 
877   for (auto type : llvm::enumerate(op.initVals()))
878     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
879       return op.emitOpError()
880              << "type mismatch between argument " << type.index() + 2
881              << " of ReduceOp body and initial value " << type.index();
882   return success();
883 }
884 
885 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
886   // Parse operands.
887   SmallVector<OpAsmParser::OperandType, 3> operands;
888   Type shapeOrExtentTensorType;
889   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
890                               OpAsmParser::Delimiter::Paren) ||
891       parser.parseColonType(shapeOrExtentTensorType) ||
892       parser.parseOptionalArrowTypeList(result.types))
893     return failure();
894 
895   // Resolve operands.
896   auto initVals = llvm::makeArrayRef(operands).drop_front();
897   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
898                             result.operands) ||
899       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
900                              result.operands))
901     return failure();
902 
903   // Parse the body.
904   Region *body = result.addRegion();
905   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
906     return failure();
907 
908   // Parse attributes.
909   if (parser.parseOptionalAttrDict(result.attributes))
910     return failure();
911 
912   return success();
913 }
914 
915 static void print(OpAsmPrinter &p, ReduceOp op) {
916   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
917     << ") : " << op.shape().getType();
918   p.printOptionalArrowTypeList(op.getResultTypes());
919   p.printRegion(op.region());
920   p.printOptionalAttrDict(op.getAttrs());
921 }
922 
923 namespace mlir {
924 namespace shape {
925 
926 #define GET_OP_CLASSES
927 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
928 
929 } // namespace shape
930 } // namespace mlir
931