1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
namespace {
// Pull in the generated rewrite patterns (presumably DRR-generated; e.g. the
// AssumingAllOneOp, TensorCastConstShape, CstrBroadcastableEqOps, CstrEqEqOps
// and SizeToIndexToSizeCanonicalization patterns referenced below). Wrapping
// the include in an anonymous namespace keeps them local to this file.
#include "ShapeCanonicalization.inc"
} // namespace
29 
30 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
31   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
32 }
33 
34 static bool isErrorPropagationPossible(TypeRange operandTypes) {
35   return llvm::any_of(operandTypes, [](Type ty) {
36     return ty.isa<SizeType, ShapeType, ValueShapeType>();
37   });
38 }
39 
40 static LogicalResult verifySizeOrIndexOp(Operation *op) {
41   assert(op != nullptr && op->getNumResults() == 1);
42   Type resultTy = op->getResultTypes().front();
43   if (isErrorPropagationPossible(op->getOperandTypes())) {
44     if (!resultTy.isa<SizeType>())
45       return op->emitOpError()
46              << "if at least one of the operands can hold error values then "
47                 "the result must be of type `size` to propagate them";
48   }
49   return success();
50 }
51 
52 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
53   assert(op != nullptr && op->getNumResults() == 1);
54   Type resultTy = op->getResultTypes().front();
55   if (isErrorPropagationPossible(op->getOperandTypes())) {
56     if (!resultTy.isa<ShapeType>())
57       return op->emitOpError()
58              << "if at least one of the operands can hold error values then "
59                 "the result must be of type `shape` to propagate them";
60   }
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // InlinerInterface
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 /// This class defines the interface for inlining shape dialect ops.
70 struct ShapeInlinerInterface : public DialectInlinerInterface {
71   using DialectInlinerInterface::DialectInlinerInterface;
72 
73   // Returns true if the given region 'src' can be inlined into the region
74   // 'dest' that is attached to an operation registered to the current dialect.
75   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
76                        BlockAndValueMapping &) const final {
77     return true;
78   }
79 
80   // Returns true if the given operation 'op', that is registered to this
81   // dialect, can be inlined into the region 'dest' that is attached to an
82   // operation registered to the current dialect.
83   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
84                        BlockAndValueMapping &) const final {
85     return true;
86   }
87 };
88 } // namespace
89 
/// Registers the shape dialect's operations (from the generated op list), its
/// types, and its inliner interface, and permits unregistered ops.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
102 
103 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
104                                              Attribute value, Type type,
105                                              Location loc) {
106   if (type.isa<ShapeType>() ||
107       type == getExtentTensorType(builder.getContext()))
108     return builder.create<ConstShapeOp>(loc, type,
109                                         value.cast<DenseIntElementsAttr>());
110   if (type.isa<SizeType>())
111     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
112   if (type.isa<WitnessType>())
113     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
114   if (ConstantOp::isBuildableWith(value, type))
115     return builder.create<ConstantOp>(loc, type, value);
116   return nullptr;
117 }
118 
119 /// Parse a type registered to this dialect.
120 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
121   StringRef keyword;
122   if (parser.parseKeyword(&keyword))
123     return Type();
124 
125   if (keyword == "shape")
126     return ShapeType::get(getContext());
127   if (keyword == "size")
128     return SizeType::get(getContext());
129   if (keyword == "value_shape")
130     return ValueShapeType::get(getContext());
131   if (keyword == "witness")
132     return WitnessType::get(getContext());
133 
134   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
135   return Type();
136 }
137 
138 /// Print a type registered to this dialect.
139 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
140   TypeSwitch<Type>(type)
141       .Case<ShapeType>([&](Type) { os << "shape"; })
142       .Case<SizeType>([&](Type) { os << "size"; })
143       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
144       .Case<WitnessType>([&](Type) { os << "witness"; })
145       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
146 }
147 
148 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
149                                                      NamedAttribute attribute) {
150   // Verify shape.lib attribute.
151   if (attribute.first == "shape.lib") {
152     if (!op->hasTrait<OpTrait::SymbolTable>())
153       return op->emitError(
154           "shape.lib attribute may only be on op implementing SymbolTable");
155 
156     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
157       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
158       if (!symbol)
159         return op->emitError("shape function library ")
160                << symbolRef << " not found";
161       return isa<shape::FunctionLibraryOp>(symbol)
162                  ? success()
163                  : op->emitError()
164                        << symbolRef << " required to be shape function library";
165     }
166 
167     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
168       // Verify all entries are function libraries and mappings in libraries
169       // refer to unique ops.
170       DenseSet<Identifier> key;
171       for (auto it : arr) {
172         if (!it.isa<SymbolRefAttr>())
173           return op->emitError(
174               "only SymbolRefAttr allowed in shape.lib attribute array");
175 
176         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
177             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
178         if (!shapeFnLib)
179           return op->emitError()
180                  << it << " does not refer to FunctionLibraryOp";
181         for (auto mapping : shapeFnLib.mapping()) {
182           if (!key.insert(mapping.first).second) {
183             return op->emitError("only one op to shape mapping allowed, found "
184                                  "multiple for `")
185                    << mapping.first << "`";
186           }
187         }
188       }
189       return success();
190     }
191 
192     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
193                          "allowed as shape.lib attribute");
194   }
195   return success();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // AnyOp
200 //===----------------------------------------------------------------------===//
201 
202 // TODO: Canonicalization should be implemented for shapes that can be
203 // determined through mixtures of the known dimensions of the inputs.
204 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
205   // Only the last operand is checked because AnyOp is commutative.
206   if (operands.back())
207     return operands.back();
208 
209   return nullptr;
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // AssumingOp
214 //===----------------------------------------------------------------------===//
215 
216 static ParseResult parseAssumingOp(OpAsmParser &parser,
217                                    OperationState &result) {
218   result.regions.reserve(1);
219   Region *doRegion = result.addRegion();
220 
221   auto &builder = parser.getBuilder();
222   OpAsmParser::OperandType cond;
223   if (parser.parseOperand(cond) ||
224       parser.resolveOperand(cond, builder.getType<WitnessType>(),
225                             result.operands))
226     return failure();
227 
228   // Parse optional results type list.
229   if (parser.parseOptionalArrowTypeList(result.types))
230     return failure();
231 
232   // Parse the region and add a terminator if elided.
233   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
234     return failure();
235   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
236 
237   // Parse the optional attribute list.
238   if (parser.parseOptionalAttrDict(result.attributes))
239     return failure();
240   return success();
241 }
242 
243 static void print(OpAsmPrinter &p, AssumingOp op) {
244   bool yieldsResults = !op.results().empty();
245 
246   p << AssumingOp::getOperationName() << " " << op.witness();
247   if (yieldsResults) {
248     p << " -> (" << op.getResultTypes() << ")";
249   }
250   p.printRegion(op.doRegion(),
251                 /*printEntryBlockArgs=*/false,
252                 /*printBlockTerminators=*/yieldsResults);
253   p.printOptionalAttrDict(op.getAttrs());
254 }
255 
256 namespace {
257 // Removes AssumingOp with a passing witness and inlines the region.
258 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
259   using OpRewritePattern<AssumingOp>::OpRewritePattern;
260 
261   LogicalResult matchAndRewrite(AssumingOp op,
262                                 PatternRewriter &rewriter) const override {
263     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
264     if (!witness || !witness.passingAttr())
265       return failure();
266 
267     AssumingOp::inlineRegionIntoParent(op, rewriter);
268     return success();
269   }
270 };
271 } // namespace
272 
273 void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
274                                              MLIRContext *context) {
275   // If taking a passing witness, inline region.
276   patterns.insert<AssumingWithTrue>(context);
277 }
278 
279 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
280 void AssumingOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   // AssumingOp has unconditional control flow into the region and back to the
284   // parent, so return the correct RegionSuccessor purely based on the index
285   // being None or 0.
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&doRegion()));
292 }
293 
/// Replaces `op` by the contents of its body block. The body is spliced in at
/// the rewriter's current insertion point, `op`'s results are replaced with the
/// operands of the body's terminator, and the terminator is erased.
/// NOTE(review): this assumes the rewriter's insertion point is positioned at
/// `op` (as a pattern driver normally sets it) and that the body is a single
/// block; confirm at new call sites.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split at the insertion point; everything from there on moves into a new
  // block that will follow the inlined body.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // The terminator's operands are the values yielded by the region and become
  // the replacements for `op`'s results.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
313 
314 //===----------------------------------------------------------------------===//
315 // AssumingAllOp
316 //===----------------------------------------------------------------------===//
317 
318 void AssumingAllOp::getCanonicalizationPatterns(
319     OwningRewritePatternList &patterns, MLIRContext *context) {
320   patterns.insert<AssumingAllOneOp>(context);
321 }
322 
323 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
324   // Iterate in reverse to first handle all constant operands. They are
325   // guaranteed to be the tail of the inputs because this is commutative.
326   for (int idx = operands.size() - 1; idx >= 0; idx--) {
327     Attribute a = operands[idx];
328     // Cannot fold if any inputs are not constant;
329     if (!a)
330       return nullptr;
331 
332     // We do not need to keep statically known values after handling them in
333     // this method.
334     getOperation()->eraseOperand(idx);
335 
336     // Always false if any input is statically known false
337     if (!a.cast<BoolAttr>().getValue())
338       return a;
339   }
340   // If this is reached, all inputs were statically known passing.
341   return BoolAttr::get(getContext(), true);
342 }
343 
344 static LogicalResult verify(AssumingAllOp op) {
345   // Ensure that AssumingAllOp contains at least one operand
346   if (op.getNumOperands() == 0)
347     return op.emitOpError("no operands specified");
348 
349   return success();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // BroadcastOp
354 //===----------------------------------------------------------------------===//
355 
356 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
357   if (!operands[1])
358     return nullptr;
359 
360   // TODO: Support folding with more than 2 input shapes
361   if (operands.size() > 2 && !operands[2].isa<StringAttr>())
362     return nullptr;
363 
364   auto rhsShape = llvm::to_vector<6>(
365       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
366   if (rhsShape.empty())
367     return shapes()[0];
368 
369   if (!operands[0])
370     return nullptr;
371 
372   auto lhsShape = llvm::to_vector<6>(
373       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
374   if (lhsShape.empty())
375     return shapes()[1];
376 
377   SmallVector<int64_t, 6> resultShape;
378   // If the shapes are not compatible, we can't fold it.
379   // TODO: Fold to an "error".
380   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
381     return nullptr;
382   Builder builder(getContext());
383   return builder.getIndexTensorAttr(resultShape);
384 }
385 
386 static LogicalResult verify(BroadcastOp op) {
387   // Ensure that AssumingAllOp contains at least one operand
388   if (op.getNumOperands() < 2)
389     return op.emitOpError("required at least 2 input shapes");
390 
391   return verifyShapeOrExtentTensorOp(op);
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // ConcatOp
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
399   if (!operands[0] || !operands[1])
400     return nullptr;
401   auto lhsShape = llvm::to_vector<6>(
402       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
403   auto rhsShape = llvm::to_vector<6>(
404       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
405   SmallVector<int64_t, 6> resultShape;
406   resultShape.append(lhsShape.begin(), lhsShape.end());
407   resultShape.append(rhsShape.begin(), rhsShape.end());
408   Builder builder(getContext());
409   return builder.getIndexTensorAttr(resultShape);
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // ConstShapeOp
414 //===----------------------------------------------------------------------===//
415 
416 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
417   p << "shape.const_shape ";
418   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
419   p << "[";
420   interleaveComma(op.shape().getValues<int64_t>(), p,
421                   [&](int64_t i) { p << i; });
422   p << "] : ";
423   p.printType(op.getType());
424 }
425 
426 static ParseResult parseConstShapeOp(OpAsmParser &parser,
427                                      OperationState &result) {
428   if (parser.parseOptionalAttrDict(result.attributes))
429     return failure();
430   // We piggy-back on ArrayAttr parsing, though we don't internally store the
431   // shape as an ArrayAttr.
432   // TODO: Implement custom parser and maybe make syntax a bit more concise.
433   Attribute extentsRaw;
434   NamedAttrList dummy;
435   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
436     return failure();
437   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
438   if (!extentsArray)
439     return failure();
440   SmallVector<int64_t, 6> ints;
441   for (Attribute extent : extentsArray) {
442     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
443     if (!attr)
444       return failure();
445     ints.push_back(attr.getInt());
446   }
447   Builder &builder = parser.getBuilder();
448   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
449   Type resultTy;
450   if (parser.parseColonType(resultTy))
451     return failure();
452   result.types.push_back(resultTy);
453   return success();
454 }
455 
456 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
457 
458 void ConstShapeOp::getCanonicalizationPatterns(
459     OwningRewritePatternList &patterns, MLIRContext *context) {
460   patterns.insert<TensorCastConstShape>(context);
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // CstrBroadcastableOp
465 //===----------------------------------------------------------------------===//
466 
467 namespace {
468 // Given an input shape Value, try to obtain the shape's values.
469 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
470   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
471     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
472     if (!type.hasRank())
473       return failure();
474     shapeValues = llvm::to_vector<6>(type.getShape());
475     return success();
476   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
477     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
478     return success();
479   } else {
480     return failure();
481   }
482 }
483 } // namespace
484 
485 void CstrBroadcastableOp::getCanonicalizationPatterns(
486     OwningRewritePatternList &patterns, MLIRContext *context) {
487   // Canonicalization patterns have overlap with the considerations during
488   // folding in case additional shape information is inferred at some point that
489   // does not result in folding.
490   patterns.insert<CstrBroadcastableEqOps>(context);
491 }
492 
493 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
494   // TODO: Add folding for the nary case
495   if (operands.size() != 2)
496     return nullptr;
497 
498   // Both operands are not needed if one is a scalar.
499   if (operands[0] &&
500       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
501     return BoolAttr::get(getContext(), true);
502   if (operands[1] &&
503       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
504     return BoolAttr::get(getContext(), true);
505 
506   if (operands[0] && operands[1]) {
507     auto lhsShape = llvm::to_vector<6>(
508         operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
509     auto rhsShape = llvm::to_vector<6>(
510         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
511     SmallVector<int64_t, 6> resultShape;
512     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
513       return BoolAttr::get(getContext(), true);
514   }
515 
516   // Lastly, see if folding can be completed based on what constraints are known
517   // on the input shapes.
518   SmallVector<int64_t, 6> lhsShape, rhsShape;
519   if (failed(getShapeVec(shapes()[0], lhsShape)))
520     return nullptr;
521   if (failed(getShapeVec(shapes()[1], rhsShape)))
522     return nullptr;
523 
524   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
525     return BoolAttr::get(getContext(), true);
526 
527   // Because a failing witness result here represents an eventual assertion
528   // failure, we do not replace it with a constant witness.
529   return nullptr;
530 }
531 
532 static LogicalResult verify(CstrBroadcastableOp op) {
533   // Ensure that AssumingAllOp contains at least one operand
534   if (op.getNumOperands() < 2)
535     return op.emitOpError("required at least 2 input shapes");
536   return success();
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // CstrEqOp
541 //===----------------------------------------------------------------------===//
542 
543 void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
544                                            MLIRContext *context) {
545   // If inputs are equal, return passing witness
546   patterns.insert<CstrEqEqOps>(context);
547 }
548 
549 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
550   if (llvm::all_of(operands,
551                    [&](Attribute a) { return a && a == operands[0]; }))
552     return BoolAttr::get(getContext(), true);
553 
554   // Because a failing witness result here represents an eventual assertion
555   // failure, we do not try to replace it with a constant witness. Similarly, we
556   // cannot if there are any non-const inputs.
557   return nullptr;
558 }
559 
560 //===----------------------------------------------------------------------===//
561 // ConstSizeOp
562 //===----------------------------------------------------------------------===//
563 
564 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
565                         int64_t value) {
566   build(builder, result, builder.getIndexAttr(value));
567 }
568 
569 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
570 
571 void ConstSizeOp::getAsmResultNames(
572     llvm::function_ref<void(Value, StringRef)> setNameFn) {
573   SmallString<4> buffer;
574   llvm::raw_svector_ostream os(buffer);
575   os << "c" << value();
576   setNameFn(getResult(), os.str());
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // ConstWitnessOp
581 //===----------------------------------------------------------------------===//
582 
583 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
584 
585 //===----------------------------------------------------------------------===//
586 // CstrRequireOp
587 //===----------------------------------------------------------------------===//
588 
589 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
590   return operands[0];
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // ShapeEqOp
595 //===----------------------------------------------------------------------===//
596 
597 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
598   if (lhs() == rhs())
599     return BoolAttr::get(getContext(), true);
600   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
601   if (lhs == nullptr)
602     return {};
603   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
604   if (rhs == nullptr)
605     return {};
606   return BoolAttr::get(getContext(), lhs == rhs);
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // IndexToSizeOp
611 //===----------------------------------------------------------------------===//
612 
613 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
614   // Constant values of both types, `shape.size` and `index`, are represented as
615   // `IntegerAttr`s which makes constant folding simple.
616   if (Attribute arg = operands[0])
617     return arg;
618   return {};
619 }
620 
621 void IndexToSizeOp::getCanonicalizationPatterns(
622     OwningRewritePatternList &patterns, MLIRContext *context) {
623   patterns.insert<SizeToIndexToSizeCanonicalization>(context);
624 }
625 
626 //===----------------------------------------------------------------------===//
627 // FromExtentsOp
628 //===----------------------------------------------------------------------===//
629 
630 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
631   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
632     return nullptr;
633   SmallVector<int64_t, 6> extents;
634   for (auto attr : operands)
635     extents.push_back(attr.cast<IntegerAttr>().getInt());
636   Builder builder(getContext());
637   return builder.getIndexTensorAttr(extents);
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // FunctionLibraryOp
642 //===----------------------------------------------------------------------===//
643 
644 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
645                               StringRef name) {
646   ensureTerminator(*result.addRegion(), builder, result.location);
647   result.attributes.push_back(builder.getNamedAttr(
648       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
649 }
650 
651 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
652   auto attr = mapping()
653                   .get(op->getName().getIdentifier())
654                   .dyn_cast_or_null<FlatSymbolRefAttr>();
655   if (!attr)
656     return nullptr;
657   return lookupSymbol<FuncOp>(attr);
658 }
659 
660 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
661                                    OperationState &result) {
662   // Parse the op name.
663   StringAttr nameAttr;
664   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
665                              result.attributes))
666     return failure();
667 
668   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
669     return failure();
670 
671   auto *bodyRegion = result.addRegion();
672   if (parser.parseRegion(*bodyRegion))
673     return failure();
674 
675   FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
676                                       result.location);
677   if (parser.parseKeyword("mapping"))
678     return failure();
679 
680   DictionaryAttr mappingAttr;
681   if (parser.parseAttribute(mappingAttr,
682                             parser.getBuilder().getType<NoneType>(), "mapping",
683                             result.attributes))
684     return failure();
685   return success();
686 }
687 
688 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
689   p << op.getOperationName() << ' ';
690   p.printSymbolName(op.getName());
691   p.printOptionalAttrDictWithKeyword(
692       op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
693   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
694                 /*printBlockTerminators=*/false);
695   p << " mapping ";
696   p.printAttributeWithoutType(op.mappingAttr());
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // GetExtentOp
701 //===----------------------------------------------------------------------===//
702 
703 Optional<int64_t> GetExtentOp::getConstantDim() {
704   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
705     return constSizeOp.value().getLimitedValue();
706   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
707     return constantOp.value().cast<IntegerAttr>().getInt();
708   return llvm::None;
709 }
710 
711 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
712   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
713   if (!elements)
714     return nullptr;
715   Optional<int64_t> dim = getConstantDim();
716   if (!dim.hasValue())
717     return nullptr;
718   if (dim.getValue() >= elements.getNumElements())
719     return nullptr;
720   return elements.getValue({(uint64_t)dim.getValue()});
721 }
722 
723 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
724                         int64_t dim) {
725   auto loc = result.location;
726   auto dimAttr = builder.getIndexAttr(dim);
727   if (shape.getType().isa<ShapeType>()) {
728     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
729     build(builder, result, builder.getType<SizeType>(), shape, dim);
730   } else {
731     Value dim =
732         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
733     build(builder, result, builder.getIndexType(), shape, dim);
734   }
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // IsBroadcastableOp
739 //===----------------------------------------------------------------------===//
740 
741 static LogicalResult verify(IsBroadcastableOp op) {
742   // Ensure that AssumingAllOp contains at least one operand
743   if (op.getNumOperands() < 2)
744     return op.emitOpError("required at least 2 input shapes");
745   return success();
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // RankOp
750 //===----------------------------------------------------------------------===//
751 
752 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
753   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
754   if (!shape)
755     return {};
756   int64_t rank = shape.getNumElements();
757   Builder builder(getContext());
758   return builder.getIndexAttr(rank);
759 }
760 
761 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
762 /// Constant folding fails in cases where only the rank is constant, not the
763 /// shape itself.
764 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
765 ///
766 /// Example:
767 ///
768 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
769 /// %rank = shape.rank %shape
770 ///
771 /// becomes
772 ///
773 /// %rank = shape.const_size 3
774 
775 namespace {
776 struct RankShapeOfCanonicalizationPattern
777     : public OpRewritePattern<shape::RankOp> {
778   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
779 
780   LogicalResult matchAndRewrite(shape::RankOp op,
781                                 PatternRewriter &rewriter) const override {
782     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
783     if (!shapeOfOp)
784       return failure();
785     auto rankedTensorType =
786         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
787     if (!rankedTensorType)
788       return failure();
789     int64_t rank = rankedTensorType.getRank();
790     if (op.getType().isa<IndexType>()) {
791       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
792     } else if (op.getType().isa<shape::SizeType>()) {
793       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
794     } else {
795       return failure();
796     }
797     return success();
798   }
799 };
800 } // namespace
801 
802 void shape::RankOp::getCanonicalizationPatterns(
803     OwningRewritePatternList &patterns, MLIRContext *context) {
804   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // NumElementsOp
809 //===----------------------------------------------------------------------===//
810 
811 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
812 
813   // Fold only when argument constant.
814   Attribute shape = operands[0];
815   if (!shape)
816     return {};
817 
818   APInt product(64, 1);
819   for (auto value : shape.cast<DenseIntElementsAttr>())
820     product *= value;
821   Builder builder(getContext());
822   return builder.getIndexAttr(product.getLimitedValue());
823 }
824 
825 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
826                           Value shape) {
827   if (shape.getType().isa<ShapedType>()) {
828     auto type = builder.getIndexType();
829     return build(builder, result, type, shape);
830   }
831   auto type = SizeType::get(builder.getContext());
832   return build(builder, result, type, shape);
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // MulOp
837 //===----------------------------------------------------------------------===//
838 
839 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
840   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
841   if (!lhs)
842     return nullptr;
843   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
844   if (!rhs)
845     return nullptr;
846   APInt folded = lhs.getValue() * rhs.getValue();
847   Type indexTy = IndexType::get(getContext());
848   return IntegerAttr::get(indexTy, folded);
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // ShapeOfOp
853 //===----------------------------------------------------------------------===//
854 
855 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
856   auto type = getOperand().getType().dyn_cast<ShapedType>();
857   if (!type || !type.hasStaticShape())
858     return nullptr;
859   Builder builder(getContext());
860   return builder.getIndexTensorAttr(type.getShape());
861 }
862 
863 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
864   Type type = arg.getType().isa<ShapedType>()
865                   ? (Type)getExtentTensorType(builder.getContext())
866                   : (Type)builder.getType<ShapeType>();
867   return ShapeOfOp::build(builder, result, type, arg);
868 }
869 
870 namespace {
871 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
872   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
873 
874   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
875                                 PatternRewriter &rewriter) const override {
876     if (!op.arg().getType().isa<ShapedType>())
877       return failure();
878     if (op.getType().isa<ShapedType>())
879       return failure();
880 
881     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
882     return success();
883   }
884 };
885 } // namespace
886 
887 void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
888                                             MLIRContext *context) {
889   patterns.insert<ShapeOfWithTensor>(context);
890 }
891 
892 //===----------------------------------------------------------------------===//
893 // SizeToIndexOp
894 //===----------------------------------------------------------------------===//
895 
896 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
897   // Constant values of both types, `shape.size` and `index`, are represented as
898   // `IntegerAttr`s which makes constant folding simple.
899   if (Attribute arg = operands[0])
900     return arg;
901   return impl::foldCastOp(*this);
902 }
903 
904 void SizeToIndexOp::getCanonicalizationPatterns(
905     OwningRewritePatternList &patterns, MLIRContext *context) {
906   patterns.insert<IndexToSizeToIndexCanonicalization>(context);
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // YieldOp
911 //===----------------------------------------------------------------------===//
912 
913 static LogicalResult verify(shape::YieldOp op) {
914   auto *parentOp = op->getParentOp();
915   auto results = parentOp->getResults();
916   auto operands = op.getOperands();
917 
918   if (parentOp->getNumResults() != op.getNumOperands())
919     return op.emitOpError() << "number of operands does not match number of "
920                                "results of its parent";
921   for (auto e : llvm::zip(results, operands))
922     if (std::get<0>(e).getType() != std::get<1>(e).getType())
923       return op.emitOpError()
924              << "types mismatch between yield op and its parent";
925 
926   return success();
927 }
928 
929 //===----------------------------------------------------------------------===//
930 // SplitAtOp
931 //===----------------------------------------------------------------------===//
932 
933 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
934                               SmallVectorImpl<OpFoldResult> &results) {
935   if (!operands[0] || !operands[1])
936     return failure();
937   auto shapeVec = llvm::to_vector<6>(
938       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
939   auto shape = llvm::makeArrayRef(shapeVec);
940   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
941   // Verify that the split point is in the correct range.
942   // TODO: Constant fold to an "error".
943   int64_t rank = shape.size();
944   if (!(-rank <= splitPoint && splitPoint <= rank))
945     return failure();
946   if (splitPoint < 0)
947     splitPoint += shape.size();
948   Builder builder(operands[0].getContext());
949   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
950   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
951   return success();
952 }
953 
954 //===----------------------------------------------------------------------===//
955 // ToExtentTensorOp
956 //===----------------------------------------------------------------------===//
957 
958 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
959   if (!operands[0])
960     return impl::foldCastOp(*this);
961   Builder builder(getContext());
962   auto shape = llvm::to_vector<6>(
963       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
964   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
965                                     builder.getIndexType());
966   return DenseIntElementsAttr::get(type, shape);
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // ReduceOp
971 //===----------------------------------------------------------------------===//
972 
973 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
974                      ValueRange initVals) {
975   result.addOperands(shape);
976   result.addOperands(initVals);
977 
978   Region *bodyRegion = result.addRegion();
979   bodyRegion->push_back(new Block);
980   Block &bodyBlock = bodyRegion->front();
981   bodyBlock.addArgument(builder.getIndexType());
982 
983   Type elementType;
984   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
985     elementType = tensorType.getElementType();
986   else
987     elementType = SizeType::get(builder.getContext());
988   bodyBlock.addArgument(elementType);
989 
990   for (Type initValType : initVals.getTypes()) {
991     bodyBlock.addArgument(initValType);
992     result.addTypes(initValType);
993   }
994 }
995 
996 static LogicalResult verify(ReduceOp op) {
997   // Verify block arg types.
998   Block &block = op.region().front();
999 
1000   // The block takes index, extent, and aggregated values as arguments.
1001   auto blockArgsCount = op.initVals().size() + 2;
1002   if (block.getNumArguments() != blockArgsCount)
1003     return op.emitOpError() << "ReduceOp body is expected to have "
1004                             << blockArgsCount << " arguments";
1005 
1006   // The first block argument is the index and must always be of type `index`.
1007   if (!block.getArgument(0).getType().isa<IndexType>())
1008     return op.emitOpError(
1009         "argument 0 of ReduceOp body is expected to be of IndexType");
1010 
1011   // The second block argument is the extent and must be of type `size` or
1012   // `index`, depending on whether the reduce operation is applied to a shape or
1013   // to an extent tensor.
1014   Type extentTy = block.getArgument(1).getType();
1015   if (op.shape().getType().isa<ShapeType>()) {
1016     if (!extentTy.isa<SizeType>())
1017       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1018                             "SizeType if the ReduceOp operates on a ShapeType");
1019   } else {
1020     if (!extentTy.isa<IndexType>())
1021       return op.emitOpError(
1022           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1023           "ReduceOp operates on an extent tensor");
1024   }
1025 
1026   for (auto type : llvm::enumerate(op.initVals()))
1027     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1028       return op.emitOpError()
1029              << "type mismatch between argument " << type.index() + 2
1030              << " of ReduceOp body and initial value " << type.index();
1031   return success();
1032 }
1033 
1034 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1035   // Parse operands.
1036   SmallVector<OpAsmParser::OperandType, 3> operands;
1037   Type shapeOrExtentTensorType;
1038   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1039                               OpAsmParser::Delimiter::Paren) ||
1040       parser.parseColonType(shapeOrExtentTensorType) ||
1041       parser.parseOptionalArrowTypeList(result.types))
1042     return failure();
1043 
1044   // Resolve operands.
1045   auto initVals = llvm::makeArrayRef(operands).drop_front();
1046   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1047                             result.operands) ||
1048       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1049                              result.operands))
1050     return failure();
1051 
1052   // Parse the body.
1053   Region *body = result.addRegion();
1054   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1055     return failure();
1056 
1057   // Parse attributes.
1058   if (parser.parseOptionalAttrDict(result.attributes))
1059     return failure();
1060 
1061   return success();
1062 }
1063 
1064 static void print(OpAsmPrinter &p, ReduceOp op) {
1065   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
1066     << ") : " << op.shape().getType();
1067   p.printOptionalArrowTypeList(op.getResultTypes());
1068   p.printRegion(op.region());
1069   p.printOptionalAttrDict(op.getAttrs());
1070 }
1071 
// Op class definitions for this dialect come from ShapeOps.cpp.inc (presumably
// ODS/TableGen output); GET_OP_CLASSES must be defined before the include.
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1074