1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 using namespace mlir;
23 using namespace mlir::shape;
24 
// Canonicalization patterns generated into ShapeCanonicalization.inc
// (presumably TableGen/DRR output — confirm against the build). Pattern
// classes referenced below but not defined in this file (e.g.
// AssumingAllOneOp, TensorCastConstShape, CstrEqEqOps) are expected to come
// from here. The anonymous namespace gives them internal linkage.
namespace {
#include "ShapeCanonicalization.inc"
} // namespace
28 
29 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
30   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
31 }
32 
33 static bool isErrorPropagationPossible(TypeRange operandTypes) {
34   return llvm::any_of(operandTypes, [](Type ty) {
35     return ty.isa<SizeType, ShapeType, ValueShapeType>();
36   });
37 }
38 
39 static LogicalResult verifySizeOrIndexOp(Operation *op) {
40   assert(op != nullptr && op->getNumResults() == 1);
41   Type resultTy = op->getResultTypes().front();
42   if (isErrorPropagationPossible(op->getOperandTypes())) {
43     if (!resultTy.isa<SizeType>())
44       return op->emitOpError()
45              << "if at least one of the operands can hold error values then "
46                 "the result must be of type `size` to propagate them";
47   }
48   return success();
49 }
50 
51 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
52   assert(op != nullptr && op->getNumResults() == 1);
53   Type resultTy = op->getResultTypes().front();
54   if (isErrorPropagationPossible(op->getOperandTypes())) {
55     if (!resultTy.isa<ShapeType>())
56       return op->emitOpError()
57              << "if at least one of the operands can hold error values then "
58                 "the result must be of type `shape` to propagate them";
59   }
60   return success();
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // InlinerInterface
65 //===----------------------------------------------------------------------===//
66 
67 namespace {
68 /// This class defines the interface for inlining shape dialect ops.
69 struct ShapeInlinerInterface : public DialectInlinerInterface {
70   using DialectInlinerInterface::DialectInlinerInterface;
71 
72   // Returns true if the given region 'src' can be inlined into the region
73   // 'dest' that is attached to an operation registered to the current dialect.
74   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
75                        BlockAndValueMapping &) const final {
76     return true;
77   }
78 
79   // Returns true if the given operation 'op', that is registered to this
80   // dialect, can be inlined into the region 'dest' that is attached to an
81   // operation registered to the current dialect.
82   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
83                        BlockAndValueMapping &) const final {
84     return true;
85   }
86 };
87 } // namespace
88 
/// Registers the shape dialect's operations, types and inliner interface, and
/// opts the dialect into accepting unregistered operations.
void ShapeDialect::initialize() {
  // The GET_OP_LIST section of ShapeOps.cpp.inc expands to the comma-separated
  // list of generated op classes used as the template arguments here.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
101 
102 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
103                                              Attribute value, Type type,
104                                              Location loc) {
105   if (type.isa<ShapeType>() ||
106       type == getExtentTensorType(builder.getContext()))
107     return builder.create<ConstShapeOp>(loc, type,
108                                         value.cast<DenseIntElementsAttr>());
109   if (type.isa<SizeType>())
110     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
111   if (type.isa<WitnessType>())
112     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
113   if (ConstantOp::isBuildableWith(value, type))
114     return builder.create<ConstantOp>(loc, type, value);
115   return nullptr;
116 }
117 
118 /// Parse a type registered to this dialect.
119 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
120   StringRef keyword;
121   if (parser.parseKeyword(&keyword))
122     return Type();
123 
124   if (keyword == "shape")
125     return ShapeType::get(getContext());
126   if (keyword == "size")
127     return SizeType::get(getContext());
128   if (keyword == "value_shape")
129     return ValueShapeType::get(getContext());
130   if (keyword == "witness")
131     return WitnessType::get(getContext());
132 
133   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
134   return Type();
135 }
136 
137 /// Print a type registered to this dialect.
138 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
139   TypeSwitch<Type>(type)
140       .Case<ShapeType>([&](Type) { os << "shape"; })
141       .Case<SizeType>([&](Type) { os << "size"; })
142       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
143       .Case<WitnessType>([&](Type) { os << "witness"; })
144       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // AnyOp
149 //===----------------------------------------------------------------------===//
150 
151 // TODO: Canonicalization should be implemented for shapes that can be
152 // determined through mixtures of the known dimensions of the inputs.
153 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
154   // Only the last operand is checked because AnyOp is commutative.
155   if (operands.back())
156     return operands.back();
157 
158   return nullptr;
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // AssumingOp
163 //===----------------------------------------------------------------------===//
164 
165 static ParseResult parseAssumingOp(OpAsmParser &parser,
166                                    OperationState &result) {
167   result.regions.reserve(1);
168   Region *doRegion = result.addRegion();
169 
170   auto &builder = parser.getBuilder();
171   OpAsmParser::OperandType cond;
172   if (parser.parseOperand(cond) ||
173       parser.resolveOperand(cond, builder.getType<WitnessType>(),
174                             result.operands))
175     return failure();
176 
177   // Parse optional results type list.
178   if (parser.parseOptionalArrowTypeList(result.types))
179     return failure();
180 
181   // Parse the region and add a terminator if elided.
182   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
183     return failure();
184   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
185 
186   // Parse the optional attribute list.
187   if (parser.parseOptionalAttrDict(result.attributes))
188     return failure();
189   return success();
190 }
191 
192 static void print(OpAsmPrinter &p, AssumingOp op) {
193   bool yieldsResults = !op.results().empty();
194 
195   p << AssumingOp::getOperationName() << " " << op.witness();
196   if (yieldsResults) {
197     p << " -> (" << op.getResultTypes() << ")";
198   }
199   p.printRegion(op.doRegion(),
200                 /*printEntryBlockArgs=*/false,
201                 /*printBlockTerminators=*/yieldsResults);
202   p.printOptionalAttrDict(op.getAttrs());
203 }
204 
205 namespace {
206 // Removes AssumingOp with a passing witness and inlines the region.
207 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
208   using OpRewritePattern<AssumingOp>::OpRewritePattern;
209 
210   LogicalResult matchAndRewrite(AssumingOp op,
211                                 PatternRewriter &rewriter) const override {
212     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
213     if (!witness || !witness.passingAttr())
214       return failure();
215 
216     AssumingOp::inlineRegionIntoParent(op, rewriter);
217     return success();
218   }
219 };
220 } // namespace
221 
222 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
223                                              MLIRContext *context) {
224   // If taking a passing witness, inline region.
225   patterns.insert<AssumingWithTrue>(context);
226 }
227 
228 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
229 void AssumingOp::getSuccessorRegions(
230     Optional<unsigned> index, ArrayRef<Attribute> operands,
231     SmallVectorImpl<RegionSuccessor> &regions) {
232   // AssumingOp has unconditional control flow into the region and back to the
233   // parent, so return the correct RegionSuccessor purely based on the index
234   // being None or 0.
235   if (index.hasValue()) {
236     regions.push_back(RegionSuccessor(getResults()));
237     return;
238   }
239 
240   regions.push_back(RegionSuccessor(&doRegion()));
241 }
242 
/// Inlines the body block of `op` (op.getBody()) into the block holding the
/// rewriter's current insertion point. The results of `op` are replaced with
/// the operands of the body's last operation (its yield), and both `op` and
/// that yield are erased.
///
/// NOTE(review): the block is split at the rewriter's insertion point, not at
/// `op` itself; this presumably relies on the caller (e.g. a rewrite pattern)
/// having the insertion point positioned at `op` — confirm for other callers.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Operations from the insertion point onwards move into a new trailing
  // block, so the body can be spliced in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // The body's last operation is treated as the yield whose operands become
  // the replacement values for the op's results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp. The result is a single block:
  // [ops before] ++ [body ops] ++ [ops after].
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
262 
263 //===----------------------------------------------------------------------===//
264 // AssumingAllOp
265 //===----------------------------------------------------------------------===//
266 
267 void AssumingAllOp::getCanonicalizationPatterns(
268     OwningRewritePatternList &patterns, MLIRContext *context) {
269   patterns.insert<AssumingAllOneOp>(context);
270 }
271 
272 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
273   // Iterate in reverse to first handle all constant operands. They are
274   // guaranteed to be the tail of the inputs because this is commutative.
275   for (int idx = operands.size() - 1; idx >= 0; idx--) {
276     Attribute a = operands[idx];
277     // Cannot fold if any inputs are not constant;
278     if (!a)
279       return nullptr;
280 
281     // We do not need to keep statically known values after handling them in
282     // this method.
283     getOperation()->eraseOperand(idx);
284 
285     // Always false if any input is statically known false
286     if (!a.cast<BoolAttr>().getValue())
287       return a;
288   }
289   // If this is reached, all inputs were statically known passing.
290   return BoolAttr::get(true, getContext());
291 }
292 
293 static LogicalResult verify(AssumingAllOp op) {
294   // Ensure that AssumingAllOp contains at least one operand
295   if (op.getNumOperands() == 0)
296     return op.emitOpError("no operands specified");
297 
298   return success();
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // BroadcastOp
303 //===----------------------------------------------------------------------===//
304 
305 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
306   if (!operands[1])
307     return nullptr;
308 
309   auto rhsShape = llvm::to_vector<6>(
310       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
311   if (rhsShape.empty())
312     return lhs();
313 
314   if (!operands[0])
315     return nullptr;
316 
317   auto lhsShape = llvm::to_vector<6>(
318       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
319   if (lhsShape.empty())
320     return rhs();
321 
322   SmallVector<int64_t, 6> resultShape;
323   // If the shapes are not compatible, we can't fold it.
324   // TODO: Fold to an "error".
325   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
326     return nullptr;
327   Builder builder(getContext());
328   return builder.getIndexTensorAttr(resultShape);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // ConcatOp
333 //===----------------------------------------------------------------------===//
334 
335 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
336   if (!operands[0] || !operands[1])
337     return nullptr;
338   auto lhsShape = llvm::to_vector<6>(
339       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
340   auto rhsShape = llvm::to_vector<6>(
341       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
342   SmallVector<int64_t, 6> resultShape;
343   resultShape.append(lhsShape.begin(), lhsShape.end());
344   resultShape.append(rhsShape.begin(), rhsShape.end());
345   Builder builder(getContext());
346   return builder.getIndexTensorAttr(resultShape);
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // ConstShapeOp
351 //===----------------------------------------------------------------------===//
352 
353 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
354   p << "shape.const_shape ";
355   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
356   p << "[";
357   interleaveComma(op.shape().getValues<int64_t>(), p,
358                   [&](int64_t i) { p << i; });
359   p << "] : ";
360   p.printType(op.getType());
361 }
362 
363 static ParseResult parseConstShapeOp(OpAsmParser &parser,
364                                      OperationState &result) {
365   if (parser.parseOptionalAttrDict(result.attributes))
366     return failure();
367   // We piggy-back on ArrayAttr parsing, though we don't internally store the
368   // shape as an ArrayAttr.
369   // TODO: Implement custom parser and maybe make syntax a bit more concise.
370   Attribute extentsRaw;
371   NamedAttrList dummy;
372   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
373     return failure();
374   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
375   if (!extentsArray)
376     return failure();
377   SmallVector<int64_t, 6> ints;
378   for (Attribute extent : extentsArray) {
379     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
380     if (!attr)
381       return failure();
382     ints.push_back(attr.getInt());
383   }
384   Builder &builder = parser.getBuilder();
385   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
386   Type resultTy;
387   if (parser.parseColonType(resultTy))
388     return failure();
389   result.types.push_back(resultTy);
390   return success();
391 }
392 
393 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
394 
395 void ConstShapeOp::getCanonicalizationPatterns(
396     OwningRewritePatternList &patterns, MLIRContext *context) {
397   patterns.insert<TensorCastConstShape>(context);
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // CstrBroadcastableOp
402 //===----------------------------------------------------------------------===//
403 
404 namespace {
405 // Given an input shape Value, try to obtain the shape's values.
406 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
407   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
408     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
409     if (!type.hasRank())
410       return failure();
411     shapeValues = llvm::to_vector<6>(type.getShape());
412     return success();
413   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
414     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
415     return success();
416   } else {
417     return failure();
418   }
419 }
420 } // namespace
421 
422 void CstrBroadcastableOp::getCanonicalizationPatterns(
423     OwningRewritePatternList &patterns, MLIRContext *context) {
424   // Canonicalization patterns have overlap with the considerations during
425   // folding in case additional shape information is inferred at some point that
426   // does not result in folding.
427   patterns.insert<CstrBroadcastableEqOps>(context);
428 }
429 
430 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
431   // Both operands are not needed if one is a scalar.
432   if (operands[0] &&
433       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
434     return BoolAttr::get(true, getContext());
435   if (operands[1] &&
436       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
437     return BoolAttr::get(true, getContext());
438 
439   if (operands[0] && operands[1]) {
440     auto lhsShape = llvm::to_vector<6>(
441         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
442     auto rhsShape = llvm::to_vector<6>(
443         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
444     SmallVector<int64_t, 6> resultShape;
445     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
446       return BoolAttr::get(true, getContext());
447   }
448 
449   // Lastly, see if folding can be completed based on what constraints are known
450   // on the input shapes.
451   SmallVector<int64_t, 6> lhsShape, rhsShape;
452   if (failed(getShapeVec(lhs(), lhsShape)))
453     return nullptr;
454   if (failed(getShapeVec(rhs(), rhsShape)))
455     return nullptr;
456 
457   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
458     return BoolAttr::get(true, getContext());
459 
460   // Because a failing witness result here represents an eventual assertion
461   // failure, we do not replace it with a constant witness.
462   return nullptr;
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // CstrEqOp
467 //===----------------------------------------------------------------------===//
468 
469 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
470                                            MLIRContext *context) {
471   // If inputs are equal, return passing witness
472   patterns.insert<CstrEqEqOps>(context);
473 }
474 
475 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
476   if (llvm::all_of(operands,
477                    [&](Attribute a) { return a && a == operands[0]; }))
478     return BoolAttr::get(true, getContext());
479 
480   // Because a failing witness result here represents an eventual assertion
481   // failure, we do not try to replace it with a constant witness. Similarly, we
482   // cannot if there are any non-const inputs.
483   return nullptr;
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // ConstSizeOp
488 //===----------------------------------------------------------------------===//
489 
490 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
491                         int64_t value) {
492   build(builder, result, builder.getIndexAttr(value));
493 }
494 
495 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
496 
497 void ConstSizeOp::getAsmResultNames(
498     llvm::function_ref<void(Value, StringRef)> setNameFn) {
499   SmallString<4> buffer;
500   llvm::raw_svector_ostream os(buffer);
501   os << "c" << value();
502   setNameFn(getResult(), os.str());
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // ConstWitnessOp
507 //===----------------------------------------------------------------------===//
508 
509 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
510 
511 //===----------------------------------------------------------------------===//
512 // CstrRequireOp
513 //===----------------------------------------------------------------------===//
514 
515 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
516   return operands[0];
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // ShapeEqOp
521 //===----------------------------------------------------------------------===//
522 
523 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
524   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
525   if (lhs == nullptr)
526     return {};
527   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
528   if (rhs == nullptr)
529     return {};
530   return BoolAttr::get(lhs == rhs, getContext());
531 }
532 
533 //===----------------------------------------------------------------------===//
534 // IndexToSizeOp
535 //===----------------------------------------------------------------------===//
536 
537 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
538   // Constant values of both types, `shape.size` and `index`, are represented as
539   // `IntegerAttr`s which makes constant folding simple.
540   if (Attribute arg = operands[0])
541     return arg;
542   return {};
543 }
544 
545 void IndexToSizeOp::getCanonicalizationPatterns(
546     OwningRewritePatternList &patterns, MLIRContext *context) {
547   patterns.insert<SizeToIndexToSizeCanonicalization>(context);
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // FromExtentsOp
552 //===----------------------------------------------------------------------===//
553 
554 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
555   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
556     return nullptr;
557   SmallVector<int64_t, 6> extents;
558   for (auto attr : operands)
559     extents.push_back(attr.cast<IntegerAttr>().getInt());
560   Builder builder(getContext());
561   return builder.getIndexTensorAttr(extents);
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // FunctionLibraryOp
566 //===----------------------------------------------------------------------===//
567 
568 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
569                               StringRef name) {
570   ensureTerminator(*result.addRegion(), builder, result.location);
571   result.attributes.push_back(builder.getNamedAttr(
572       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
573 }
574 
575 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
576   auto attr = mapping()
577                   .get(op->getName().getIdentifier())
578                   .dyn_cast_or_null<FlatSymbolRefAttr>();
579   if (!attr)
580     return nullptr;
581   return lookupSymbol<FuncOp>(attr);
582 }
583 
584 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
585                                    OperationState &result) {
586   // Parse the op name.
587   StringAttr nameAttr;
588   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
589                              result.attributes))
590     return failure();
591 
592   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
593     return failure();
594 
595   auto *bodyRegion = result.addRegion();
596   if (parser.parseRegion(*bodyRegion))
597     return failure();
598 
599   FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
600                                       result.location);
601   if (parser.parseKeyword("mapping"))
602     return failure();
603 
604   DictionaryAttr mappingAttr;
605   if (parser.parseAttribute(mappingAttr,
606                             parser.getBuilder().getType<NoneType>(), "mapping",
607                             result.attributes))
608     return failure();
609   return success();
610 }
611 
612 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
613   p << op.getOperationName() << ' ';
614   p.printSymbolName(op.getName());
615   p.printOptionalAttrDictWithKeyword(
616       op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
617   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
618                 /*printBlockTerminators=*/false);
619   p << " mapping ";
620   p.printAttributeWithoutType(op.mappingAttr());
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // GetExtentOp
625 //===----------------------------------------------------------------------===//
626 
627 Optional<int64_t> GetExtentOp::getConstantDim() {
628   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
629     return constSizeOp.value().getLimitedValue();
630   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
631     return constantOp.value().cast<IntegerAttr>().getInt();
632   return llvm::None;
633 }
634 
635 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
636   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
637   if (!elements)
638     return nullptr;
639   Optional<int64_t> dim = getConstantDim();
640   if (!dim.hasValue())
641     return nullptr;
642   if (dim.getValue() >= elements.getNumElements())
643     return nullptr;
644   return elements.getValue({(uint64_t)dim.getValue()});
645 }
646 
647 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
648                         int64_t dim) {
649   auto loc = result.location;
650   auto dimAttr = builder.getIndexAttr(dim);
651   if (shape.getType().isa<ShapeType>()) {
652     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
653     build(builder, result, builder.getType<SizeType>(), shape, dim);
654   } else {
655     Value dim =
656         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
657     build(builder, result, builder.getIndexType(), shape, dim);
658   }
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // RankOp
663 //===----------------------------------------------------------------------===//
664 
665 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
666   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
667   if (!shape)
668     return {};
669   int64_t rank = shape.getNumElements();
670   Builder builder(getContext());
671   return builder.getIndexAttr(rank);
672 }
673 
674 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
675 /// Constant folding fails in cases where only the rank is constant, not the
676 /// shape itself.
677 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
678 ///
679 /// Example:
680 ///
681 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
682 /// %rank = shape.rank %shape
683 ///
684 /// becomes
685 ///
686 /// %rank = shape.const_size 3
687 
688 namespace {
689 struct RankShapeOfCanonicalizationPattern
690     : public OpRewritePattern<shape::RankOp> {
691   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
692 
693   LogicalResult matchAndRewrite(shape::RankOp op,
694                                 PatternRewriter &rewriter) const override {
695     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
696     if (!shapeOfOp)
697       return failure();
698     auto rankedTensorType =
699         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
700     if (!rankedTensorType)
701       return failure();
702     int64_t rank = rankedTensorType.getRank();
703     if (op.getType().isa<IndexType>()) {
704       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
705     } else if (op.getType().isa<shape::SizeType>()) {
706       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
707     } else {
708       return failure();
709     }
710     return success();
711   }
712 };
713 } // namespace
714 
715 void shape::RankOp::getCanonicalizationPatterns(
716     OwningRewritePatternList &patterns, MLIRContext *context) {
717   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // NumElementsOp
722 //===----------------------------------------------------------------------===//
723 
724 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
725 
726   // Fold only when argument constant.
727   Attribute shape = operands[0];
728   if (!shape)
729     return {};
730 
731   APInt product(64, 1);
732   for (auto value : shape.cast<DenseIntElementsAttr>())
733     product *= value;
734   Builder builder(getContext());
735   return builder.getIndexAttr(product.getLimitedValue());
736 }
737 
738 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
739                           Value shape) {
740   if (shape.getType().isa<ShapedType>()) {
741     auto type = builder.getIndexType();
742     return build(builder, result, type, shape);
743   }
744   auto type = SizeType::get(builder.getContext());
745   return build(builder, result, type, shape);
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // MulOp
750 //===----------------------------------------------------------------------===//
751 
752 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
753   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
754   if (!lhs)
755     return nullptr;
756   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
757   if (!rhs)
758     return nullptr;
759   APInt folded = lhs.getValue() * rhs.getValue();
760   Type indexTy = IndexType::get(getContext());
761   return IntegerAttr::get(indexTy, folded);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // ShapeOfOp
766 //===----------------------------------------------------------------------===//
767 
768 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
769   auto type = getOperand().getType().dyn_cast<ShapedType>();
770   if (!type || !type.hasStaticShape())
771     return nullptr;
772   Builder builder(getContext());
773   return builder.getIndexTensorAttr(type.getShape());
774 }
775 
776 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
777   Type type = arg.getType().isa<ShapedType>()
778                   ? (Type)getExtentTensorType(builder.getContext())
779                   : (Type)builder.getType<ShapeType>();
780   return ShapeOfOp::build(builder, result, type, arg);
781 }
782 
783 namespace {
784 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
785   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
786 
787   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
788                                 PatternRewriter &rewriter) const override {
789     if (!op.arg().getType().isa<ShapedType>())
790       return failure();
791     if (op.getType().isa<ShapedType>())
792       return failure();
793 
794     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
795     return success();
796   }
797 };
798 } // namespace
799 
800 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
801                                             MLIRContext *context) {
802   patterns.insert<ShapeOfWithTensor>(context);
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // SizeToIndexOp
807 //===----------------------------------------------------------------------===//
808 
809 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
810   // Constant values of both types, `shape.size` and `index`, are represented as
811   // `IntegerAttr`s which makes constant folding simple.
812   if (Attribute arg = operands[0])
813     return arg;
814   return impl::foldCastOp(*this);
815 }
816 
817 void SizeToIndexOp::getCanonicalizationPatterns(
818     OwningRewritePatternList &patterns, MLIRContext *context) {
819   patterns.insert<IndexToSizeToIndexCanonicalization>(context);
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // YieldOp
824 //===----------------------------------------------------------------------===//
825 
826 static LogicalResult verify(shape::YieldOp op) {
827   auto *parentOp = op->getParentOp();
828   auto results = parentOp->getResults();
829   auto operands = op.getOperands();
830 
831   if (parentOp->getNumResults() != op.getNumOperands())
832     return op.emitOpError() << "number of operands does not match number of "
833                                "results of its parent";
834   for (auto e : llvm::zip(results, operands))
835     if (std::get<0>(e).getType() != std::get<1>(e).getType())
836       return op.emitOpError()
837              << "types mismatch between yield op and its parent";
838 
839   return success();
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // SplitAtOp
844 //===----------------------------------------------------------------------===//
845 
846 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
847                               SmallVectorImpl<OpFoldResult> &results) {
848   if (!operands[0] || !operands[1])
849     return failure();
850   auto shapeVec = llvm::to_vector<6>(
851       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
852   auto shape = llvm::makeArrayRef(shapeVec);
853   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
854   // Verify that the split point is in the correct range.
855   // TODO: Constant fold to an "error".
856   int64_t rank = shape.size();
857   if (!(-rank <= splitPoint && splitPoint <= rank))
858     return failure();
859   if (splitPoint < 0)
860     splitPoint += shape.size();
861   Builder builder(operands[0].getContext());
862   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
863   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
864   return success();
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // ToExtentTensorOp
869 //===----------------------------------------------------------------------===//
870 
871 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
872   if (!operands[0])
873     return impl::foldCastOp(*this);
874   Builder builder(getContext());
875   auto shape = llvm::to_vector<6>(
876       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
877   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
878                                     builder.getIndexType());
879   return DenseIntElementsAttr::get(type, shape);
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // ReduceOp
884 //===----------------------------------------------------------------------===//
885 
886 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
887                      ValueRange initVals) {
888   result.addOperands(shape);
889   result.addOperands(initVals);
890 
891   Region *bodyRegion = result.addRegion();
892   bodyRegion->push_back(new Block);
893   Block &bodyBlock = bodyRegion->front();
894   bodyBlock.addArgument(builder.getIndexType());
895 
896   Type elementType;
897   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
898     elementType = tensorType.getElementType();
899   else
900     elementType = SizeType::get(builder.getContext());
901   bodyBlock.addArgument(elementType);
902 
903   for (Type initValType : initVals.getTypes()) {
904     bodyBlock.addArgument(initValType);
905     result.addTypes(initValType);
906   }
907 }
908 
909 static LogicalResult verify(ReduceOp op) {
910   // Verify block arg types.
911   Block &block = op.region().front();
912 
913   // The block takes index, extent, and aggregated values as arguments.
914   auto blockArgsCount = op.initVals().size() + 2;
915   if (block.getNumArguments() != blockArgsCount)
916     return op.emitOpError() << "ReduceOp body is expected to have "
917                             << blockArgsCount << " arguments";
918 
919   // The first block argument is the index and must always be of type `index`.
920   if (!block.getArgument(0).getType().isa<IndexType>())
921     return op.emitOpError(
922         "argument 0 of ReduceOp body is expected to be of IndexType");
923 
924   // The second block argument is the extent and must be of type `size` or
925   // `index`, depending on whether the reduce operation is applied to a shape or
926   // to an extent tensor.
927   Type extentTy = block.getArgument(1).getType();
928   if (op.shape().getType().isa<ShapeType>()) {
929     if (!extentTy.isa<SizeType>())
930       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
931                             "SizeType if the ReduceOp operates on a ShapeType");
932   } else {
933     if (!extentTy.isa<IndexType>())
934       return op.emitOpError(
935           "argument 1 of ReduceOp body is expected to be of IndexType if the "
936           "ReduceOp operates on an extent tensor");
937   }
938 
939   for (auto type : llvm::enumerate(op.initVals()))
940     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
941       return op.emitOpError()
942              << "type mismatch between argument " << type.index() + 2
943              << " of ReduceOp body and initial value " << type.index();
944   return success();
945 }
946 
947 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
948   // Parse operands.
949   SmallVector<OpAsmParser::OperandType, 3> operands;
950   Type shapeOrExtentTensorType;
951   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
952                               OpAsmParser::Delimiter::Paren) ||
953       parser.parseColonType(shapeOrExtentTensorType) ||
954       parser.parseOptionalArrowTypeList(result.types))
955     return failure();
956 
957   // Resolve operands.
958   auto initVals = llvm::makeArrayRef(operands).drop_front();
959   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
960                             result.operands) ||
961       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
962                              result.operands))
963     return failure();
964 
965   // Parse the body.
966   Region *body = result.addRegion();
967   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
968     return failure();
969 
970   // Parse attributes.
971   if (parser.parseOptionalAttrDict(result.attributes))
972     return failure();
973 
974   return success();
975 }
976 
977 static void print(OpAsmPrinter &p, ReduceOp op) {
978   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
979     << ") : " << op.shape().getType();
980   p.printOptionalArrowTypeList(op.getResultTypes());
981   p.printRegion(op.region());
982   p.printOptionalAttrDict(op.getAttrs());
983 }
984 
985 #define GET_OP_CLASSES
986 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
987