1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 #include "ShapeCanonicalization.inc"
28 }
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// Both hooks unconditionally allow inlining: shape dialect ops have no
/// region- or operation-level restrictions on where they may be placed.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
89 
void ShapeDialect::initialize() {
  // Register all ops generated from the ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
104                                              Attribute value, Type type,
105                                              Location loc) {
106   if (type.isa<ShapeType>() ||
107       type == getExtentTensorType(builder.getContext()))
108     return builder.create<ConstShapeOp>(loc, type,
109                                         value.cast<DenseIntElementsAttr>());
110   if (type.isa<SizeType>())
111     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
112   if (type.isa<WitnessType>())
113     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
114   if (ConstantOp::isBuildableWith(value, type))
115     return builder.create<ConstantOp>(loc, type, value);
116   return nullptr;
117 }
118 
119 /// Parse a type registered to this dialect.
120 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
121   StringRef keyword;
122   if (parser.parseKeyword(&keyword))
123     return Type();
124 
125   if (keyword == "shape")
126     return ShapeType::get(getContext());
127   if (keyword == "size")
128     return SizeType::get(getContext());
129   if (keyword == "value_shape")
130     return ValueShapeType::get(getContext());
131   if (keyword == "witness")
132     return WitnessType::get(getContext());
133 
134   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
135   return Type();
136 }
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // AnyOp
150 //===----------------------------------------------------------------------===//
151 
152 // TODO: Canonicalization should be implemented for shapes that can be
153 // determined through mixtures of the known dimensions of the inputs.
154 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
155   // Only the last operand is checked because AnyOp is commutative.
156   if (operands.back())
157     return operands.back();
158 
159   return nullptr;
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // AssumingOp
164 //===----------------------------------------------------------------------===//
165 
/// Parses `shape.assuming`:
///   `shape.assuming` witness (`->` type-list)? region attr-dict?
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // Parse the witness operand; it is always of `!shape.witness` type.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
192 
193 static void print(OpAsmPrinter &p, AssumingOp op) {
194   bool yieldsResults = !op.results().empty();
195 
196   p << AssumingOp::getOperationName() << " " << op.witness();
197   if (yieldsResults) {
198     p << " -> (" << op.getResultTypes() << ")";
199   }
200   p.printRegion(op.doRegion(),
201                 /*printEntryBlockArgs=*/false,
202                 /*printBlockTerminators=*/yieldsResults);
203   p.printOptionalAttrDict(op.getAttrs());
204 }
205 
206 namespace {
207 // Removes AssumingOp with a passing witness and inlines the region.
208 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
209   using OpRewritePattern<AssumingOp>::OpRewritePattern;
210 
211   LogicalResult matchAndRewrite(AssumingOp op,
212                                 PatternRewriter &rewriter) const override {
213     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
214     if (!witness || !witness.passingAttr())
215       return failure();
216 
217     AssumingOp::inlineRegionIntoParent(op, rewriter);
218     return success();
219   }
220 };
221 } // namespace
222 
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *context) {
  // If taking a passing witness, inline region.
  patterns.insert<AssumingWithTrue>(context);
}
228 
229 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
230 void AssumingOp::getSuccessorRegions(
231     Optional<unsigned> index, ArrayRef<Attribute> operands,
232     SmallVectorImpl<RegionSuccessor> &regions) {
233   // AssumingOp has unconditional control flow into the region and back to the
234   // parent, so return the correct RegionSuccessor purely based on the index
235   // being None or 0.
236   if (index.hasValue()) {
237     regions.push_back(RegionSuccessor(getResults()));
238     return;
239   }
240 
241   regions.push_back(RegionSuccessor(&doRegion()));
242 }
243 
/// Inlines the body region of `op` into its parent block, replacing the op's
/// results with the operands of its terminating yield. The parent block is
/// split at the op so the region's contents land between the two halves, then
/// all three blocks are merged back together.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
263 
264 //===----------------------------------------------------------------------===//
265 // AssumingAllOp
266 //===----------------------------------------------------------------------===//
267 
void AssumingAllOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // An `assuming_all` with a single operand is the operand itself (pattern
  // generated from ShapeCanonicalization.td).
  patterns.insert<AssumingAllOneOp>(context);
}
272 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE(review): this mutates the op from within fold(); any constant
    // operands visited before an early return above stay erased even though
    // no fold result is produced — confirm this interplay with the folder.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(true, getContext());
}
293 
294 static LogicalResult verify(AssumingAllOp op) {
295   // Ensure that AssumingAllOp contains at least one operand
296   if (op.getNumOperands() == 0)
297     return op.emitOpError("no operands specified");
298 
299   return success();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // BroadcastOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
307   if (!operands[1])
308     return nullptr;
309 
310   auto rhsShape = llvm::to_vector<6>(
311       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
312   if (rhsShape.empty())
313     return lhs();
314 
315   if (!operands[0])
316     return nullptr;
317 
318   auto lhsShape = llvm::to_vector<6>(
319       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
320   if (lhsShape.empty())
321     return rhs();
322 
323   SmallVector<int64_t, 6> resultShape;
324   // If the shapes are not compatible, we can't fold it.
325   // TODO: Fold to an "error".
326   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
327     return nullptr;
328   Builder builder(getContext());
329   return builder.getIndexTensorAttr(resultShape);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // ConcatOp
334 //===----------------------------------------------------------------------===//
335 
336 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
337   if (!operands[0] || !operands[1])
338     return nullptr;
339   auto lhsShape = llvm::to_vector<6>(
340       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
341   auto rhsShape = llvm::to_vector<6>(
342       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
343   SmallVector<int64_t, 6> resultShape;
344   resultShape.append(lhsShape.begin(), lhsShape.end());
345   resultShape.append(rhsShape.begin(), rhsShape.end());
346   Builder builder(getContext());
347   return builder.getIndexTensorAttr(resultShape);
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // ConstShapeOp
352 //===----------------------------------------------------------------------===//
353 
354 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
355   p << "shape.const_shape ";
356   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
357   p << "[";
358   interleaveComma(op.shape().getValues<int64_t>(), p,
359                   [&](int64_t i) { p << i; });
360   p << "] : ";
361   p.printType(op.getType());
362 }
363 
/// Parses `shape.const_shape` of the form:
///   `shape.const_shape` attr-dict `[` integer-list `]` `:` type
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents in the canonical "shape" attribute form.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
393 
// A constant shape always folds to its "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
395 
void ConstShapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Fold tensor casts of constant shapes (pattern generated from
  // ShapeCanonicalization.td).
  patterns.insert<TensorCastConstShape>(context);
}
400 
401 //===----------------------------------------------------------------------===//
402 // CstrBroadcastableOp
403 //===----------------------------------------------------------------------===//
404 
405 namespace {
406 // Given an input shape Value, try to obtain the shape's values.
407 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
408   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
409     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
410     if (!type.hasRank())
411       return failure();
412     shapeValues = llvm::to_vector<6>(type.getShape());
413     return success();
414   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
415     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
416     return success();
417   } else {
418     return failure();
419   }
420 }
421 } // namespace
422 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.insert<CstrBroadcastableEqOps>(context);
}
430 
431 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
432   // Both operands are not needed if one is a scalar.
433   if (operands[0] &&
434       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
435     return BoolAttr::get(true, getContext());
436   if (operands[1] &&
437       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
438     return BoolAttr::get(true, getContext());
439 
440   if (operands[0] && operands[1]) {
441     auto lhsShape = llvm::to_vector<6>(
442         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
443     auto rhsShape = llvm::to_vector<6>(
444         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
445     SmallVector<int64_t, 6> resultShape;
446     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
447       return BoolAttr::get(true, getContext());
448   }
449 
450   // Lastly, see if folding can be completed based on what constraints are known
451   // on the input shapes.
452   SmallVector<int64_t, 6> lhsShape, rhsShape;
453   if (failed(getShapeVec(lhs(), lhsShape)))
454     return nullptr;
455   if (failed(getShapeVec(rhs(), rhsShape)))
456     return nullptr;
457 
458   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
459     return BoolAttr::get(true, getContext());
460 
461   // Because a failing witness result here represents an eventual assertion
462   // failure, we do not replace it with a constant witness.
463   return nullptr;
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // CstrEqOp
468 //===----------------------------------------------------------------------===//
469 
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.insert<CstrEqEqOps>(context);
}
475 
476 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
477   if (llvm::all_of(operands,
478                    [&](Attribute a) { return a && a == operands[0]; }))
479     return BoolAttr::get(true, getContext());
480 
481   // Because a failing witness result here represents an eventual assertion
482   // failure, we do not try to replace it with a constant witness. Similarly, we
483   // cannot if there are any non-const inputs.
484   return nullptr;
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // ConstSizeOp
489 //===----------------------------------------------------------------------===//
490 
491 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
492                         int64_t value) {
493   build(builder, result, builder.getIndexAttr(value));
494 }
495 
// A constant size always folds to its "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
497 
498 void ConstSizeOp::getAsmResultNames(
499     llvm::function_ref<void(Value, StringRef)> setNameFn) {
500   SmallString<4> buffer;
501   llvm::raw_svector_ostream os(buffer);
502   os << "c" << value();
503   setNameFn(getResult(), os.str());
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // ConstWitnessOp
508 //===----------------------------------------------------------------------===//
509 
// A constant witness always folds to its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
511 
512 //===----------------------------------------------------------------------===//
513 // CstrRequireOp
514 //===----------------------------------------------------------------------===//
515 
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  // Fold to the boolean operand itself when it is constant; a null attribute
  // means "not constant" and aborts folding.
  return operands[0];
}
519 
520 //===----------------------------------------------------------------------===//
521 // ShapeEqOp
522 //===----------------------------------------------------------------------===//
523 
524 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
525   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
526   if (lhs == nullptr)
527     return {};
528   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
529   if (rhs == nullptr)
530     return {};
531   return BoolAttr::get(lhs == rhs, getContext());
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // IndexToSizeOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
539   // Constant values of both types, `shape.size` and `index`, are represented as
540   // `IntegerAttr`s which makes constant folding simple.
541   if (Attribute arg = operands[0])
542     return arg;
543   return {};
544 }
545 
void IndexToSizeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Cancel out size_to_index(index_to_size(x)) round-trips (pattern generated
  // from ShapeCanonicalization.td).
  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
}
550 
551 //===----------------------------------------------------------------------===//
552 // FromExtentsOp
553 //===----------------------------------------------------------------------===//
554 
555 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
556   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
557     return nullptr;
558   SmallVector<int64_t, 6> extents;
559   for (auto attr : operands)
560     extents.push_back(attr.cast<IntegerAttr>().getInt());
561   Builder builder(getContext());
562   return builder.getIndexTensorAttr(extents);
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // FunctionLibraryOp
567 //===----------------------------------------------------------------------===//
568 
/// Builds a `function_library` with the given symbol `name` and an empty,
/// properly terminated body region.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  ensureTerminator(*result.addRegion(), builder, result.location);
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
575 
576 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
577   auto attr = mapping()
578                   .get(op->getName().getIdentifier())
579                   .dyn_cast_or_null<FlatSymbolRefAttr>();
580   if (!attr)
581     return nullptr;
582   return lookupSymbol<FuncOp>(attr);
583 }
584 
/// Parses `shape.function_library`:
///   symbol-name attr-dict-with-keyword? region `mapping` dictionary-attr
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // Add the terminator if it was elided from the printed form.
  FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                      result.location);
  if (parser.parseKeyword("mapping"))
    return failure();

  // The op-name-to-shape-function dictionary printed after `mapping`.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
612 
/// Prints `shape.function_library`; the symbol name and `mapping` attributes
/// are elided from the attr-dict because they are printed explicitly.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(
      op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
623 
624 //===----------------------------------------------------------------------===//
625 // GetExtentOp
626 //===----------------------------------------------------------------------===//
627 
628 Optional<int64_t> GetExtentOp::getConstantDim() {
629   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
630     return constSizeOp.value().getLimitedValue();
631   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
632     return constantOp.value().cast<IntegerAttr>().getInt();
633   return llvm::None;
634 }
635 
636 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
637   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
638   if (!elements)
639     return nullptr;
640   Optional<int64_t> dim = getConstantDim();
641   if (!dim.hasValue())
642     return nullptr;
643   if (dim.getValue() >= elements.getNumElements())
644     return nullptr;
645   return elements.getValue({(uint64_t)dim.getValue()});
646 }
647 
648 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
649                         int64_t dim) {
650   auto loc = result.location;
651   auto dimAttr = builder.getIndexAttr(dim);
652   if (shape.getType().isa<ShapeType>()) {
653     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
654     build(builder, result, builder.getType<SizeType>(), shape, dim);
655   } else {
656     Value dim =
657         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
658     build(builder, result, builder.getIndexType(), shape, dim);
659   }
660 }
661 
662 //===----------------------------------------------------------------------===//
663 // RankOp
664 //===----------------------------------------------------------------------===//
665 
666 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
667   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
668   if (!shape)
669     return {};
670   int64_t rank = shape.getNumElements();
671   Builder builder(getContext());
672   return builder.getIndexAttr(rank);
673 }
674 
675 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
676 /// Constant folding fails in cases where only the rank is constant, not the
677 /// shape itself.
678 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
679 ///
680 /// Example:
681 ///
682 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
683 /// %rank = shape.rank %shape
684 ///
685 /// becomes
686 ///
687 /// %rank = shape.const_size 3
688 
689 namespace {
690 struct RankShapeOfCanonicalizationPattern
691     : public OpRewritePattern<shape::RankOp> {
692   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
693 
694   LogicalResult matchAndRewrite(shape::RankOp op,
695                                 PatternRewriter &rewriter) const override {
696     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
697     if (!shapeOfOp)
698       return failure();
699     auto rankedTensorType =
700         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
701     if (!rankedTensorType)
702       return failure();
703     int64_t rank = rankedTensorType.getRank();
704     if (op.getType().isa<IndexType>()) {
705       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
706     } else if (op.getType().isa<shape::SizeType>()) {
707       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
708     } else {
709       return failure();
710     }
711     return success();
712   }
713 };
714 } // namespace
715 
void shape::RankOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // Evaluate `rank(shape_of(%ranked_tensor))` statically; see the pattern's
  // documentation above.
  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}
720 
721 //===----------------------------------------------------------------------===//
722 // NumElementsOp
723 //===----------------------------------------------------------------------===//
724 
725 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
726 
727   // Fold only when argument constant.
728   Attribute shape = operands[0];
729   if (!shape)
730     return {};
731 
732   APInt product(64, 1);
733   for (auto value : shape.cast<DenseIntElementsAttr>())
734     product *= value;
735   Builder builder(getContext());
736   return builder.getIndexAttr(product.getLimitedValue());
737 }
738 
739 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
740                           Value shape) {
741   if (shape.getType().isa<ShapedType>()) {
742     auto type = builder.getIndexType();
743     return build(builder, result, type, shape);
744   }
745   auto type = SizeType::get(builder.getContext());
746   return build(builder, result, type, shape);
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // MulOp
751 //===----------------------------------------------------------------------===//
752 
753 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
754   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
755   if (!lhs)
756     return nullptr;
757   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
758   if (!rhs)
759     return nullptr;
760   APInt folded = lhs.getValue() * rhs.getValue();
761   Type indexTy = IndexType::get(getContext());
762   return IntegerAttr::get(indexTy, folded);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // ShapeOfOp
767 //===----------------------------------------------------------------------===//
768 
769 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
770   auto type = getOperand().getType().dyn_cast<ShapedType>();
771   if (!type || !type.hasStaticShape())
772     return nullptr;
773   Builder builder(getContext());
774   return builder.getIndexTensorAttr(type.getShape());
775 }
776 
777 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
778   Type type = arg.getType().isa<ShapedType>()
779                   ? (Type)getExtentTensorType(builder.getContext())
780                   : (Type)builder.getType<ShapeType>();
781   return ShapeOfOp::build(builder, result, type, arg);
782 }
783 
784 namespace {
785 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
786   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
787 
788   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
789                                 PatternRewriter &rewriter) const override {
790     if (!op.arg().getType().isa<ShapedType>())
791       return failure();
792     if (op.getType().isa<ShapedType>())
793       return failure();
794 
795     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
796     return success();
797   }
798 };
799 } // namespace
800 
// Registers canonicalization patterns for `shape.shape_of`.
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *context) {
  patterns.insert<ShapeOfWithTensor>(context);
}
805 
806 //===----------------------------------------------------------------------===//
807 // SizeToIndexOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
811   // Constant values of both types, `shape.size` and `index`, are represented as
812   // `IntegerAttr`s which makes constant folding simple.
813   if (Attribute arg = operands[0])
814     return arg;
815   return impl::foldCastOp(*this);
816 }
817 
// Registers canonicalization patterns for `shape.size_to_index`; the pattern
// itself is generated from ShapeCanonicalization.td (included above).
void SizeToIndexOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
}
822 
823 //===----------------------------------------------------------------------===//
824 // YieldOp
825 //===----------------------------------------------------------------------===//
826 
827 static LogicalResult verify(shape::YieldOp op) {
828   auto *parentOp = op->getParentOp();
829   auto results = parentOp->getResults();
830   auto operands = op.getOperands();
831 
832   if (parentOp->getNumResults() != op.getNumOperands())
833     return op.emitOpError() << "number of operands does not match number of "
834                                "results of its parent";
835   for (auto e : llvm::zip(results, operands))
836     if (std::get<0>(e).getType() != std::get<1>(e).getType())
837       return op.emitOpError()
838              << "types mismatch between yield op and its parent";
839 
840   return success();
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // SplitAtOp
845 //===----------------------------------------------------------------------===//
846 
847 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
848                               SmallVectorImpl<OpFoldResult> &results) {
849   if (!operands[0] || !operands[1])
850     return failure();
851   auto shapeVec = llvm::to_vector<6>(
852       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
853   auto shape = llvm::makeArrayRef(shapeVec);
854   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
855   // Verify that the split point is in the correct range.
856   // TODO: Constant fold to an "error".
857   int64_t rank = shape.size();
858   if (!(-rank <= splitPoint && splitPoint <= rank))
859     return failure();
860   if (splitPoint < 0)
861     splitPoint += shape.size();
862   Builder builder(operands[0].getContext());
863   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
864   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
865   return success();
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // ToExtentTensorOp
870 //===----------------------------------------------------------------------===//
871 
872 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
873   if (!operands[0])
874     return impl::foldCastOp(*this);
875   Builder builder(getContext());
876   auto shape = llvm::to_vector<6>(
877       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
878   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
879                                     builder.getIndexType());
880   return DenseIntElementsAttr::get(type, shape);
881 }
882 
883 //===----------------------------------------------------------------------===//
884 // ReduceOp
885 //===----------------------------------------------------------------------===//
886 
// Builds a `shape.reduce` op, creating its body region with an entry block
// whose arguments are: the iteration index, the current extent, and one
// accumulator per initial value. The op's result types mirror the init
// value types.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  // Create the body region with a single empty block; the region takes
  // ownership of the block.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // Block argument 0: the iteration index, always of `index` type.
  bodyBlock.addArgument(builder.getIndexType());

  // Block argument 1: the extent. Its type depends on what is being reduced:
  // the tensor element type for an extent tensor, `!shape.size` for a
  // `!shape.shape` operand.
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // Remaining block arguments and the op results both follow the init value
  // types one-to-one.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
909 
910 static LogicalResult verify(ReduceOp op) {
911   // Verify block arg types.
912   Block &block = op.region().front();
913 
914   // The block takes index, extent, and aggregated values as arguments.
915   auto blockArgsCount = op.initVals().size() + 2;
916   if (block.getNumArguments() != blockArgsCount)
917     return op.emitOpError() << "ReduceOp body is expected to have "
918                             << blockArgsCount << " arguments";
919 
920   // The first block argument is the index and must always be of type `index`.
921   if (!block.getArgument(0).getType().isa<IndexType>())
922     return op.emitOpError(
923         "argument 0 of ReduceOp body is expected to be of IndexType");
924 
925   // The second block argument is the extent and must be of type `size` or
926   // `index`, depending on whether the reduce operation is applied to a shape or
927   // to an extent tensor.
928   Type extentTy = block.getArgument(1).getType();
929   if (op.shape().getType().isa<ShapeType>()) {
930     if (!extentTy.isa<SizeType>())
931       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
932                             "SizeType if the ReduceOp operates on a ShapeType");
933   } else {
934     if (!extentTy.isa<IndexType>())
935       return op.emitOpError(
936           "argument 1 of ReduceOp body is expected to be of IndexType if the "
937           "ReduceOp operates on an extent tensor");
938   }
939 
940   for (auto type : llvm::enumerate(op.initVals()))
941     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
942       return op.emitOpError()
943              << "type mismatch between argument " << type.index() + 2
944              << " of ReduceOp body and initial value " << type.index();
945   return success();
946 }
947 
// Parses a `shape.reduce` op of the form:
//   shape.reduce(%shape, %init...) : shape-or-extent-tensor-type
//     (-> result-types)? region attr-dict?
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands: a parenthesized list whose first element is the shape
  // (or extent tensor) and whose remainder are the init values.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the first against the parsed shape type, the init
  // values against the result types (they match one-to-one by design).
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body region; its entry-block arguments are spelled inside the
  // region, so none are supplied here.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse the trailing optional attribute dictionary (mirrors the printer,
  // which emits attributes after the region).
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
977 
978 static void print(OpAsmPrinter &p, ReduceOp op) {
979   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
980     << ") : " << op.shape().getType();
981   p.printOptionalArrowTypeList(op.getResultTypes());
982   p.printRegion(op.region());
983   p.printOptionalAttrDict(op.getAttrs());
984 }
985 
986 #define GET_OP_CLASSES
987 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
988