1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 using namespace mlir;
25 using namespace mlir::shape;
26 
27 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
28 
29 namespace {
30 #include "ShapeCanonicalization.inc"
31 }
32 
33 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
34   return RankedTensorType::get({rank}, IndexType::get(ctx));
35 }
36 
37 bool shape::isExtentTensorType(Type type) {
38   auto ranked = type.dyn_cast<RankedTensorType>();
39   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
40 }
41 
42 LogicalResult shape::getShapeVec(Value input,
43                                  SmallVectorImpl<int64_t> &shapeValues) {
44   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
45     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
46     if (!type.hasRank())
47       return failure();
48     shapeValues = llvm::to_vector<6>(type.getShape());
49     return success();
50   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
51     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
52     return success();
53   } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
54     shapeValues = llvm::to_vector<6>(
55         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
56     return success();
57   } else {
58     return failure();
59   }
60 }
61 
62 static bool isErrorPropagationPossible(TypeRange operandTypes) {
63   return llvm::any_of(operandTypes, [](Type ty) {
64     return ty.isa<SizeType, ShapeType, ValueShapeType>();
65   });
66 }
67 
68 static LogicalResult verifySizeOrIndexOp(Operation *op) {
69   assert(op != nullptr && op->getNumResults() == 1);
70   Type resultTy = op->getResultTypes().front();
71   if (isErrorPropagationPossible(op->getOperandTypes())) {
72     if (!resultTy.isa<SizeType>())
73       return op->emitOpError()
74              << "if at least one of the operands can hold error values then "
75                 "the result must be of type `size` to propagate them";
76   }
77   return success();
78 }
79 
80 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
81   assert(op != nullptr && op->getNumResults() == 1);
82   Type resultTy = op->getResultTypes().front();
83   if (isErrorPropagationPossible(op->getOperandTypes())) {
84     if (!resultTy.isa<ShapeType>())
85       return op->emitOpError()
86              << "if at least one of the operands can hold error values then "
87                 "the result must be of type `shape` to propagate them";
88   }
89   return success();
90 }
91 
92 template <typename... Ty>
93 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
94   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
95 }
96 
97 template <typename... Ty, typename... ranges>
98 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
99   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // InlinerInterface
104 //===----------------------------------------------------------------------===//
105 
106 namespace {
107 /// This class defines the interface for inlining shape dialect ops.
108 struct ShapeInlinerInterface : public DialectInlinerInterface {
109   using DialectInlinerInterface::DialectInlinerInterface;
110 
111   // Returns true if the given region 'src' can be inlined into the region
112   // 'dest' that is attached to an operation registered to the current dialect.
113   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
114                        BlockAndValueMapping &) const final {
115     return true;
116   }
117 
118   // Returns true if the given operation 'op', that is registered to this
119   // dialect, can be inlined into the region 'dest' that is attached to an
120   // operation registered to the current dialect.
121   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
122                        BlockAndValueMapping &) const final {
123     return true;
124   }
125 };
126 } // namespace
127 
128 void ShapeDialect::initialize() {
129   addOperations<
130 #define GET_OP_LIST
131 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
132       >();
133   addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
134   addInterfaces<ShapeInlinerInterface>();
135   // Allow unknown operations during prototyping and testing. As the dialect is
136   // still evolving it makes it simple to start with an unregistered ops and
137   // try different variants before actually defining the op.
138   allowUnknownOperations();
139 }
140 
141 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
142                                              Attribute value, Type type,
143                                              Location loc) {
144   if (type.isa<ShapeType>() || isExtentTensorType(type))
145     return builder.create<ConstShapeOp>(loc, type,
146                                         value.cast<DenseIntElementsAttr>());
147   if (type.isa<SizeType>())
148     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
149   if (type.isa<WitnessType>())
150     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
151   if (ConstantOp::isBuildableWith(value, type))
152     return builder.create<ConstantOp>(loc, type, value);
153   return nullptr;
154 }
155 
156 /// Parse a type registered to this dialect.
157 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
158   StringRef keyword;
159   if (parser.parseKeyword(&keyword))
160     return Type();
161 
162   if (keyword == "shape")
163     return ShapeType::get(getContext());
164   if (keyword == "size")
165     return SizeType::get(getContext());
166   if (keyword == "value_shape")
167     return ValueShapeType::get(getContext());
168   if (keyword == "witness")
169     return WitnessType::get(getContext());
170 
171   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
172   return Type();
173 }
174 
175 /// Print a type registered to this dialect.
176 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
177   TypeSwitch<Type>(type)
178       .Case<ShapeType>([&](Type) { os << "shape"; })
179       .Case<SizeType>([&](Type) { os << "size"; })
180       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
181       .Case<WitnessType>([&](Type) { os << "witness"; })
182       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
183 }
184 
185 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
186                                                      NamedAttribute attribute) {
187   // Verify shape.lib attribute.
188   if (attribute.first == "shape.lib") {
189     if (!op->hasTrait<OpTrait::SymbolTable>())
190       return op->emitError(
191           "shape.lib attribute may only be on op implementing SymbolTable");
192 
193     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
194       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
195       if (!symbol)
196         return op->emitError("shape function library ")
197                << symbolRef << " not found";
198       return isa<shape::FunctionLibraryOp>(symbol)
199                  ? success()
200                  : op->emitError()
201                        << symbolRef << " required to be shape function library";
202     }
203 
204     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
205       // Verify all entries are function libraries and mappings in libraries
206       // refer to unique ops.
207       DenseSet<Identifier> key;
208       for (auto it : arr) {
209         if (!it.isa<SymbolRefAttr>())
210           return op->emitError(
211               "only SymbolRefAttr allowed in shape.lib attribute array");
212 
213         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
214             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
215         if (!shapeFnLib)
216           return op->emitError()
217                  << it << " does not refer to FunctionLibraryOp";
218         for (auto mapping : shapeFnLib.mapping()) {
219           if (!key.insert(mapping.first).second) {
220             return op->emitError("only one op to shape mapping allowed, found "
221                                  "multiple for `")
222                    << mapping.first << "`";
223           }
224         }
225       }
226       return success();
227     }
228 
229     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
230                          "allowed as shape.lib attribute");
231   }
232   return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // AnyOp
237 //===----------------------------------------------------------------------===//
238 
239 // TODO: Canonicalization should be implemented for shapes that can be
240 // determined through mixtures of the known dimensions of the inputs.
241 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
242   // Only the last operand is checked because AnyOp is commutative.
243   if (operands.back())
244     return operands.back();
245 
246   return nullptr;
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // AssumingOp
251 //===----------------------------------------------------------------------===//
252 
253 static ParseResult parseAssumingOp(OpAsmParser &parser,
254                                    OperationState &result) {
255   result.regions.reserve(1);
256   Region *doRegion = result.addRegion();
257 
258   auto &builder = parser.getBuilder();
259   OpAsmParser::OperandType cond;
260   if (parser.parseOperand(cond) ||
261       parser.resolveOperand(cond, builder.getType<WitnessType>(),
262                             result.operands))
263     return failure();
264 
265   // Parse optional results type list.
266   if (parser.parseOptionalArrowTypeList(result.types))
267     return failure();
268 
269   // Parse the region and add a terminator if elided.
270   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
271     return failure();
272   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
273 
274   // Parse the optional attribute list.
275   if (parser.parseOptionalAttrDict(result.attributes))
276     return failure();
277   return success();
278 }
279 
280 static void print(OpAsmPrinter &p, AssumingOp op) {
281   bool yieldsResults = !op.results().empty();
282 
283   p << " " << op.witness();
284   if (yieldsResults) {
285     p << " -> (" << op.getResultTypes() << ")";
286   }
287   p.printRegion(op.doRegion(),
288                 /*printEntryBlockArgs=*/false,
289                 /*printBlockTerminators=*/yieldsResults);
290   p.printOptionalAttrDict(op->getAttrs());
291 }
292 
293 namespace {
294 // Removes AssumingOp with a passing witness and inlines the region.
295 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
296   using OpRewritePattern<AssumingOp>::OpRewritePattern;
297 
298   LogicalResult matchAndRewrite(AssumingOp op,
299                                 PatternRewriter &rewriter) const override {
300     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
301     if (!witness || !witness.passingAttr())
302       return failure();
303 
304     AssumingOp::inlineRegionIntoParent(op, rewriter);
305     return success();
306   }
307 };
308 
309 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
310   using OpRewritePattern<AssumingOp>::OpRewritePattern;
311 
312   LogicalResult matchAndRewrite(AssumingOp op,
313                                 PatternRewriter &rewriter) const override {
314     Block *body = op.getBody();
315     auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
316 
317     // Find used values.
318     SmallVector<Value, 4> newYieldOperands;
319     Value opResult, yieldOperand;
320     for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
321       std::tie(opResult, yieldOperand) = it;
322       if (!opResult.getUses().empty()) {
323         newYieldOperands.push_back(yieldOperand);
324       }
325     }
326 
327     // Rewrite only if redundant results exist.
328     if (newYieldOperands.size() == yieldOp->getNumOperands())
329       return failure();
330 
331     // Replace yield op in the old assuming op's body and move the entire region
332     // to the new assuming op.
333     rewriter.setInsertionPointToEnd(body);
334     auto newYieldOp =
335         rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
336     rewriter.setInsertionPoint(op);
337     auto newOp = rewriter.create<AssumingOp>(
338         op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
339     newOp.doRegion().takeBody(op.doRegion());
340 
341     // Use the new results to replace the previously used ones.
342     SmallVector<Value, 4> replacementValues;
343     auto src = newOp.getResults().begin();
344     for (auto it : op.getResults()) {
345       if (it.getUses().empty())
346         replacementValues.push_back(nullptr);
347       else
348         replacementValues.push_back(*src++);
349     }
350     rewriter.replaceOp(op, replacementValues);
351     return success();
352   }
353 };
354 } // namespace
355 
356 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
357                                              MLIRContext *context) {
358   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
359 }
360 
361 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
362 void AssumingOp::getSuccessorRegions(
363     Optional<unsigned> index, ArrayRef<Attribute> operands,
364     SmallVectorImpl<RegionSuccessor> &regions) {
365   // AssumingOp has unconditional control flow into the region and back to the
366   // parent, so return the correct RegionSuccessor purely based on the index
367   // being None or 0.
368   if (index.hasValue()) {
369     regions.push_back(RegionSuccessor(getResults()));
370     return;
371   }
372 
373   regions.push_back(RegionSuccessor(&doRegion()));
374 }
375 
376 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
377                                         PatternRewriter &rewriter) {
378   auto *blockBeforeAssuming = rewriter.getInsertionBlock();
379   auto *assumingBlock = op.getBody();
380   auto initPosition = rewriter.getInsertionPoint();
381   auto *blockAfterAssuming =
382       rewriter.splitBlock(blockBeforeAssuming, initPosition);
383 
384   // Remove the AssumingOp and AssumingYieldOp.
385   auto &yieldOp = assumingBlock->back();
386   rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
387   rewriter.replaceOp(op, yieldOp.getOperands());
388   rewriter.eraseOp(&yieldOp);
389 
390   // Merge blocks together as there was no branching behavior from the
391   // AssumingOp.
392   rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
393   rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
394 }
395 
396 void AssumingOp::build(
397     OpBuilder &builder, OperationState &result, Value witness,
398     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
399 
400   result.addOperands(witness);
401   Region *bodyRegion = result.addRegion();
402   bodyRegion->push_back(new Block);
403   Block &bodyBlock = bodyRegion->front();
404 
405   // Build body.
406   OpBuilder::InsertionGuard guard(builder);
407   builder.setInsertionPointToStart(&bodyBlock);
408   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
409   builder.create<AssumingYieldOp>(result.location, yieldValues);
410 
411   SmallVector<Type, 2> assumingTypes;
412   for (Value v : yieldValues)
413     assumingTypes.push_back(v.getType());
414   result.addTypes(assumingTypes);
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // AddOp
419 //===----------------------------------------------------------------------===//
420 
421 LogicalResult mlir::shape::AddOp::inferReturnTypes(
422     MLIRContext *context, Optional<Location> location, ValueRange operands,
423     DictionaryAttr attributes, RegionRange regions,
424     SmallVectorImpl<Type> &inferredReturnTypes) {
425   if (operands[0].getType().isa<SizeType>() ||
426       operands[1].getType().isa<SizeType>())
427     inferredReturnTypes.assign({SizeType::get(context)});
428   else
429     inferredReturnTypes.assign({IndexType::get(context)});
430   return success();
431 }
432 
433 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
434   // SizeType is compatible with IndexType.
435   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // AssumingAllOp
440 //===----------------------------------------------------------------------===//
441 
442 namespace {
443 struct AssumingAllToCstrEqCanonicalization
444     : public OpRewritePattern<AssumingAllOp> {
445   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
446 
447   LogicalResult matchAndRewrite(AssumingAllOp op,
448                                 PatternRewriter &rewriter) const override {
449     SmallVector<Value, 8> shapes;
450     for (Value w : op.inputs()) {
451       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
452       if (!cstrEqOp)
453         return failure();
454       bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
455         return llvm::is_contained(shapes, s);
456       });
457       if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
458         return failure();
459       shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
460     }
461     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
462     return success();
463   }
464 };
465 
466 template <typename OpTy>
467 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
468   using OpRewritePattern<OpTy>::OpRewritePattern;
469 
470   LogicalResult matchAndRewrite(OpTy op,
471                                 PatternRewriter &rewriter) const override {
472     // Find unique operands.
473     SmallVector<Value, 2> unique;
474     for (Value v : op.getOperands()) {
475       if (!llvm::is_contained(unique, v))
476         unique.push_back(v);
477     }
478 
479     // Reduce op to equivalent with unique operands.
480     if (unique.size() < op.getNumOperands()) {
481       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
482                                         op->getAttrs());
483       return success();
484     }
485 
486     return failure();
487   }
488 };
489 } // namespace
490 
491 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
492                                                 MLIRContext *context) {
493   patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
494                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
495 }
496 
497 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
498   // Iterate in reverse to first handle all constant operands. They are
499   // guaranteed to be the tail of the inputs because this is commutative.
500   for (int idx = operands.size() - 1; idx >= 0; idx--) {
501     Attribute a = operands[idx];
502     // Cannot fold if any inputs are not constant;
503     if (!a)
504       return nullptr;
505 
506     // We do not need to keep statically known values after handling them in
507     // this method.
508     getOperation()->eraseOperand(idx);
509 
510     // Always false if any input is statically known false
511     if (!a.cast<BoolAttr>().getValue())
512       return a;
513   }
514   // If this is reached, all inputs were statically known passing.
515   return BoolAttr::get(getContext(), true);
516 }
517 
518 static LogicalResult verify(AssumingAllOp op) {
519   // Ensure that AssumingAllOp contains at least one operand
520   if (op.getNumOperands() == 0)
521     return op.emitOpError("no operands specified");
522 
523   return success();
524 }
525 
526 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
527                           ValueRange inputs) {
528   build(b, state, b.getType<WitnessType>(), inputs);
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // BroadcastOp
533 //===----------------------------------------------------------------------===//
534 
535 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
536   if (shapes().size() == 1) {
537     // Otherwise, we need a cast which would be a canonicalization, not folding.
538     if (shapes().front().getType() != getType())
539       return nullptr;
540     return shapes().front();
541   }
542 
543   // TODO: Support folding with more than 2 input shapes
544   if (shapes().size() > 2)
545     return nullptr;
546 
547   if (!operands[0] || !operands[1])
548     return nullptr;
549   auto lhsShape = llvm::to_vector<6>(
550       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
551   auto rhsShape = llvm::to_vector<6>(
552       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
553   SmallVector<int64_t, 6> resultShape;
554 
555   // If the shapes are not compatible, we can't fold it.
556   // TODO: Fold to an "error".
557   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
558     return nullptr;
559 
560   Builder builder(getContext());
561   return builder.getIndexTensorAttr(resultShape);
562 }
563 
564 static LogicalResult verify(BroadcastOp op) {
565   return verifyShapeOrExtentTensorOp(op);
566 }
567 
568 namespace {
569 template <typename OpTy>
570 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
571   using OpRewritePattern<OpTy>::OpRewritePattern;
572 
573   LogicalResult matchAndRewrite(OpTy op,
574                                 PatternRewriter &rewriter) const override {
575     auto isPotentiallyNonEmptyShape = [](Value shape) {
576       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
577         if (extentTensorTy.getDimSize(0) == 0)
578           return false;
579       }
580       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
581         if (constShape.shape().empty())
582           return false;
583       }
584       return true;
585     };
586     auto newOperands = llvm::to_vector<8>(
587         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
588 
589     // Reduce op to equivalent without empty shape operands.
590     if (newOperands.size() < op.getNumOperands()) {
591       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
592                                         op->getAttrs());
593       return success();
594     }
595 
596     return failure();
597   }
598 };
599 
600 struct BroadcastForwardSingleOperandPattern
601     : public OpRewritePattern<BroadcastOp> {
602   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
603 
604   LogicalResult matchAndRewrite(BroadcastOp op,
605                                 PatternRewriter &rewriter) const override {
606     if (op.getNumOperands() != 1)
607       return failure();
608     Value replacement = op.shapes().front();
609 
610     // Insert cast if needed.
611     if (replacement.getType() != op.getType()) {
612       auto loc = op.getLoc();
613       if (op.getType().isa<ShapeType>()) {
614         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
615       } else {
616         assert(!op.getType().isa<ShapeType>() &&
617                !replacement.getType().isa<ShapeType>() &&
618                "expect extent tensor cast");
619         replacement =
620             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
621       }
622     }
623 
624     rewriter.replaceOp(op, replacement);
625     return success();
626   }
627 };
628 
629 struct BroadcastFoldConstantOperandsPattern
630     : public OpRewritePattern<BroadcastOp> {
631   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
632 
633   LogicalResult matchAndRewrite(BroadcastOp op,
634                                 PatternRewriter &rewriter) const override {
635     SmallVector<int64_t, 8> foldedConstantShape;
636     SmallVector<Value, 8> newShapeOperands;
637     for (Value shape : op.shapes()) {
638       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
639         SmallVector<int64_t, 8> newFoldedConstantShape;
640         if (OpTrait::util::getBroadcastedShape(
641                 foldedConstantShape,
642                 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
643                 newFoldedConstantShape)) {
644           foldedConstantShape = newFoldedConstantShape;
645           continue;
646         }
647       }
648       newShapeOperands.push_back(shape);
649     }
650 
651     // Need at least two constant operands to fold anything.
652     if (op.getNumOperands() - newShapeOperands.size() < 2)
653       return failure();
654 
655     auto foldedConstantOperandsTy = RankedTensorType::get(
656         {static_cast<int64_t>(foldedConstantShape.size())},
657         rewriter.getIndexType());
658     newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
659         op.getLoc(), foldedConstantOperandsTy,
660         rewriter.getIndexTensorAttr(foldedConstantShape)));
661     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
662                                              newShapeOperands);
663     return success();
664   }
665 };
666 
667 template <typename OpTy>
668 struct CanonicalizeCastExtentTensorOperandsPattern
669     : public OpRewritePattern<OpTy> {
670   using OpRewritePattern<OpTy>::OpRewritePattern;
671 
672   LogicalResult matchAndRewrite(OpTy op,
673                                 PatternRewriter &rewriter) const override {
674     // Canonicalize operands.
675     bool anyChange = false;
676     auto canonicalizeOperand = [&](Value operand) {
677       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
678         // Only eliminate the cast if it holds no shape information.
679         bool isInformationLoosingCast =
680             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
681         if (isInformationLoosingCast) {
682           anyChange = true;
683           return castOp.source();
684         }
685       }
686       return operand;
687     };
688     auto newOperands = llvm::to_vector<8>(
689         llvm::map_range(op.getOperands(), canonicalizeOperand));
690 
691     // Rewrite op if any change required.
692     if (!anyChange)
693       return failure();
694     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
695     return success();
696   }
697 };
698 
699 struct BroadcastConcretizeResultTypePattern
700     : public OpRewritePattern<BroadcastOp> {
701   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
702 
703   LogicalResult matchAndRewrite(BroadcastOp op,
704                                 PatternRewriter &rewriter) const override {
705     // Only concretize dynamic extent tensor result types.
706     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
707     if (!resultTy || !resultTy.isDynamicDim(0))
708       return failure();
709 
710     // Infer resulting shape rank if possible.
711     int64_t maxRank = 0;
712     for (Value shape : op.shapes()) {
713       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
714         // Cannot infer resulting shape rank if any operand is dynamically
715         // ranked.
716         if (extentTensorTy.isDynamicDim(0))
717           return failure();
718         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
719       }
720     }
721 
722     auto newOp = rewriter.create<BroadcastOp>(
723         op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
724     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
725     return success();
726   }
727 };
728 } // namespace
729 
730 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
731                                               MLIRContext *context) {
732   patterns.add<BroadcastConcretizeResultTypePattern,
733                BroadcastFoldConstantOperandsPattern,
734                BroadcastForwardSingleOperandPattern,
735                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
736                RemoveDuplicateOperandsPattern<BroadcastOp>,
737                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
738 }
739 
740 //===----------------------------------------------------------------------===//
741 // ConcatOp
742 //===----------------------------------------------------------------------===//
743 
744 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
745   if (!operands[0] || !operands[1])
746     return nullptr;
747   auto lhsShape = llvm::to_vector<6>(
748       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
749   auto rhsShape = llvm::to_vector<6>(
750       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
751   SmallVector<int64_t, 6> resultShape;
752   resultShape.append(lhsShape.begin(), lhsShape.end());
753   resultShape.append(rhsShape.begin(), rhsShape.end());
754   Builder builder(getContext());
755   return builder.getIndexTensorAttr(resultShape);
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // ConstShapeOp
760 //===----------------------------------------------------------------------===//
761 
762 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
763   p << " ";
764   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
765   p << "[";
766   interleaveComma(op.shape().getValues<int64_t>(), p,
767                   [&](int64_t i) { p << i; });
768   p << "] : ";
769   p.printType(op.getType());
770 }
771 
772 static ParseResult parseConstShapeOp(OpAsmParser &parser,
773                                      OperationState &result) {
774   if (parser.parseOptionalAttrDict(result.attributes))
775     return failure();
776   // We piggy-back on ArrayAttr parsing, though we don't internally store the
777   // shape as an ArrayAttr.
778   // TODO: Implement custom parser and maybe make syntax a bit more concise.
779   Attribute extentsRaw;
780   NamedAttrList dummy;
781   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
782     return failure();
783   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
784   if (!extentsArray)
785     return failure();
786   SmallVector<int64_t, 6> ints;
787   for (Attribute extent : extentsArray) {
788     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
789     if (!attr)
790       return failure();
791     ints.push_back(attr.getInt());
792   }
793   Builder &builder = parser.getBuilder();
794   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
795   Type resultTy;
796   if (parser.parseColonType(resultTy))
797     return failure();
798   result.types.push_back(resultTy);
799   return success();
800 }
801 
802 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
803 
804 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
805                                                MLIRContext *context) {
806   patterns.add<TensorCastConstShape>(context);
807 }
808 
809 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
810     MLIRContext *context, Optional<Location> location, ValueRange operands,
811     DictionaryAttr attributes, RegionRange regions,
812     SmallVectorImpl<Type> &inferredReturnTypes) {
813   Builder b(context);
814   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
815   if (!shape)
816     return emitOptionalError(location, "missing shape attribute");
817   inferredReturnTypes.assign({RankedTensorType::get(
818       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
819   return success();
820 }
821 
822 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
823                                                         TypeRange r) {
824   if (l.size() != 1 || r.size() != 1)
825     return false;
826 
827   Type lhs = l.front();
828   Type rhs = r.front();
829 
830   if (lhs == rhs)
831     return true;
832 
833   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
834     // Shape type is compatible with all other valid return types.
835     return true;
836 
837   return succeeded(verifyCompatibleShapes(lhs, rhs));
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // CstrBroadcastableOp
842 //===----------------------------------------------------------------------===//
843 
844 void CstrBroadcastableOp::getCanonicalizationPatterns(
845     RewritePatternSet &patterns, MLIRContext *context) {
846   // Canonicalization patterns have overlap with the considerations during
847   // folding in case additional shape information is inferred at some point that
848   // does not result in folding.
849   patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
850                CstrBroadcastableEqOps,
851                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
852                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
853 }
854 
855 // Return true if there is exactly one attribute not representing a scalar
856 // broadcast.
857 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
858   bool nonScalarSeen = false;
859   for (Attribute a : attributes) {
860     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
861       if (nonScalarSeen)
862         return false;
863       nonScalarSeen = true;
864     }
865   }
866   return true;
867 }
868 
869 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
870   // No broadcasting is needed if all operands but one are scalar.
871   if (hasAtMostSingleNonScalar(operands))
872     return BoolAttr::get(getContext(), true);
873 
874   if ([&] {
875         SmallVector<SmallVector<int64_t, 6>, 6> extents;
876         for (const auto &operand : operands) {
877           if (!operand)
878             return false;
879           extents.push_back(llvm::to_vector<6>(
880               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
881         }
882         return OpTrait::util::staticallyKnownBroadcastable(extents);
883       }())
884     return BoolAttr::get(getContext(), true);
885 
886   // Lastly, see if folding can be completed based on what constraints are known
887   // on the input shapes.
888   if ([&] {
889         SmallVector<SmallVector<int64_t, 6>, 6> extents;
890         for (auto shapeValue : shapes()) {
891           extents.emplace_back();
892           if (failed(getShapeVec(shapeValue, extents.back())))
893             return false;
894         }
895         return OpTrait::util::staticallyKnownBroadcastable(extents);
896       }())
897     return BoolAttr::get(getContext(), true);
898 
899   // Because a failing witness result here represents an eventual assertion
900   // failure, we do not replace it with a constant witness.
901   return nullptr;
902 }
903 
904 static LogicalResult verify(CstrBroadcastableOp op) {
905   // Ensure that AssumingAllOp contains at least one operand
906   if (op.getNumOperands() < 2)
907     return op.emitOpError("required at least 2 input shapes");
908   return success();
909 }
910 
911 //===----------------------------------------------------------------------===//
912 // CstrEqOp
913 //===----------------------------------------------------------------------===//
914 
915 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
916                                            MLIRContext *context) {
917   // If inputs are equal, return passing witness
918   patterns.add<CstrEqEqOps>(context);
919 }
920 
921 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
922   if (llvm::all_of(operands,
923                    [&](Attribute a) { return a && a == operands[0]; }))
924     return BoolAttr::get(getContext(), true);
925 
926   // Because a failing witness result here represents an eventual assertion
927   // failure, we do not try to replace it with a constant witness. Similarly, we
928   // cannot if there are any non-const inputs.
929   return nullptr;
930 }
931 
932 //===----------------------------------------------------------------------===//
933 // ConstSizeOp
934 //===----------------------------------------------------------------------===//
935 
936 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
937                         int64_t value) {
938   build(builder, result, builder.getIndexAttr(value));
939 }
940 
941 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
942 
943 void ConstSizeOp::getAsmResultNames(
944     llvm::function_ref<void(Value, StringRef)> setNameFn) {
945   SmallString<4> buffer;
946   llvm::raw_svector_ostream os(buffer);
947   os << "c" << value();
948   setNameFn(getResult(), os.str());
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // ConstWitnessOp
953 //===----------------------------------------------------------------------===//
954 
955 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
956 
957 //===----------------------------------------------------------------------===//
958 // CstrRequireOp
959 //===----------------------------------------------------------------------===//
960 
961 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
962   return operands[0];
963 }
964 
965 //===----------------------------------------------------------------------===//
966 // DivOp
967 //===----------------------------------------------------------------------===//
968 
969 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
970   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
971   if (!lhs)
972     return nullptr;
973   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
974   if (!rhs)
975     return nullptr;
976 
977   // Division in APInt does not follow floor(lhs, rhs) when the result is
978   // negative. Rather, APInt rounds toward zero.
979   APInt quotient, remainder;
980   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
981   if (quotient.isNegative() && !remainder.isNullValue()) {
982     quotient -= 1;
983   }
984 
985   Type indexTy = IndexType::get(getContext());
986   return IntegerAttr::get(indexTy, quotient);
987 }
988 
989 LogicalResult mlir::shape::DivOp::inferReturnTypes(
990     MLIRContext *context, Optional<Location> location, ValueRange operands,
991     DictionaryAttr attributes, RegionRange regions,
992     SmallVectorImpl<Type> &inferredReturnTypes) {
993   if (operands[0].getType().isa<SizeType>() ||
994       operands[1].getType().isa<SizeType>())
995     inferredReturnTypes.assign({SizeType::get(context)});
996   else
997     inferredReturnTypes.assign({IndexType::get(context)});
998   return success();
999 }
1000 
1001 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1002   // SizeType is compatible with IndexType.
1003   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // ShapeEqOp
1008 //===----------------------------------------------------------------------===//
1009 
1010 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1011   bool allSame = true;
1012   if (!operands.empty() && !operands[0])
1013     return {};
1014   for (Attribute operand : operands.drop_front(1)) {
1015     if (!operand)
1016       return {};
1017     allSame = allSame && operand == operands[0];
1018   }
1019   return BoolAttr::get(getContext(), allSame);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // IndexToSizeOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1027   // Constant values of both types, `shape.size` and `index`, are represented as
1028   // `IntegerAttr`s which makes constant folding simple.
1029   if (Attribute arg = operands[0])
1030     return arg;
1031   return {};
1032 }
1033 
1034 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1035                                                 MLIRContext *context) {
1036   patterns.add<SizeToIndexToSizeCanonicalization>(context);
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // FromExtentsOp
1041 //===----------------------------------------------------------------------===//
1042 
1043 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1044   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1045     return nullptr;
1046   SmallVector<int64_t, 6> extents;
1047   for (auto attr : operands)
1048     extents.push_back(attr.cast<IntegerAttr>().getInt());
1049   Builder builder(getContext());
1050   return builder.getIndexTensorAttr(extents);
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // FunctionLibraryOp
1055 //===----------------------------------------------------------------------===//
1056 
1057 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1058                               StringRef name) {
1059   result.attributes.push_back(builder.getNamedAttr(
1060       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1061 }
1062 
1063 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1064   auto attr = mapping()
1065                   .get(op->getName().getIdentifier())
1066                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1067   if (!attr)
1068     return nullptr;
1069   return lookupSymbol<FuncOp>(attr);
1070 }
1071 
1072 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
1073                                    OperationState &result) {
1074   // Parse the op name.
1075   StringAttr nameAttr;
1076   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1077                              result.attributes))
1078     return failure();
1079 
1080   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1081     return failure();
1082 
1083   auto *bodyRegion = result.addRegion();
1084   if (parser.parseRegion(*bodyRegion))
1085     return failure();
1086 
1087   if (parser.parseKeyword("mapping"))
1088     return failure();
1089 
1090   DictionaryAttr mappingAttr;
1091   if (parser.parseAttribute(mappingAttr,
1092                             parser.getBuilder().getType<NoneType>(), "mapping",
1093                             result.attributes))
1094     return failure();
1095   return success();
1096 }
1097 
1098 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
1099   p << ' ';
1100   p.printSymbolName(op.getName());
1101   p.printOptionalAttrDictWithKeyword(
1102       op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
1103   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
1104                 /*printBlockTerminators=*/false);
1105   p << " mapping ";
1106   p.printAttributeWithoutType(op.mappingAttr());
1107 }
1108 
1109 //===----------------------------------------------------------------------===//
1110 // GetExtentOp
1111 //===----------------------------------------------------------------------===//
1112 
1113 Optional<int64_t> GetExtentOp::getConstantDim() {
1114   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
1115     return constSizeOp.value().getLimitedValue();
1116   if (auto constantOp = dim().getDefiningOp<ConstantOp>())
1117     return constantOp.value().cast<IntegerAttr>().getInt();
1118   return llvm::None;
1119 }
1120 
1121 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1122   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1123   if (!elements)
1124     return nullptr;
1125   Optional<int64_t> dim = getConstantDim();
1126   if (!dim.hasValue())
1127     return nullptr;
1128   if (dim.getValue() >= elements.getNumElements())
1129     return nullptr;
1130   return elements.getValue({(uint64_t)dim.getValue()});
1131 }
1132 
1133 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1134                         int64_t dim) {
1135   auto loc = result.location;
1136   auto dimAttr = builder.getIndexAttr(dim);
1137   if (shape.getType().isa<ShapeType>()) {
1138     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1139     build(builder, result, builder.getType<SizeType>(), shape, dim);
1140   } else {
1141     Value dim =
1142         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
1143     build(builder, result, builder.getIndexType(), shape, dim);
1144   }
1145 }
1146 
1147 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1148     MLIRContext *context, Optional<Location> location, ValueRange operands,
1149     DictionaryAttr attributes, RegionRange regions,
1150     SmallVectorImpl<Type> &inferredReturnTypes) {
1151   inferredReturnTypes.assign({IndexType::get(context)});
1152   return success();
1153 }
1154 
1155 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1156                                                        TypeRange r) {
1157   // SizeType is compatible with IndexType.
1158   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // IsBroadcastableOp
1163 //===----------------------------------------------------------------------===//
1164 
1165 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1166                                                     MLIRContext *context) {
1167   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1168 }
1169 
1170 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1171   // Can always broadcast fewer than two shapes.
1172   if (operands.size() < 2) {
1173     return BoolAttr::get(getContext(), true);
1174   }
1175 
1176   return nullptr;
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // JoinOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 LogicalResult mlir::shape::JoinOp::inferReturnTypes(
1184     MLIRContext *context, Optional<Location> location, ValueRange operands,
1185     DictionaryAttr attributes, RegionRange regions,
1186     SmallVectorImpl<Type> &inferredReturnTypes) {
1187   inferredReturnTypes.assign({operands[0].getType()});
1188   return success();
1189 }
1190 
1191 bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1192   if (l.size() != 1 || r.size() != 1)
1193     return false;
1194   if (l == r)
1195     return true;
1196 
1197   Type lhs = l.front();
1198   Type rhs = r.front();
1199 
1200   if (lhs != rhs)
1201     return false;
1202 
1203   if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1204     return true;
1205 
1206   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1207     return true;
1208   return false;
1209 }
1210 
1211 //===----------------------------------------------------------------------===//
1212 // RankOp
1213 //===----------------------------------------------------------------------===//
1214 
1215 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1216   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1217   if (!shape)
1218     return {};
1219   int64_t rank = shape.getNumElements();
1220   Builder builder(getContext());
1221   return builder.getIndexAttr(rank);
1222 }
1223 
1224 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1225 /// Constant folding fails in cases where only the rank is constant, not the
1226 /// shape itself.
1227 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1228 ///
1229 /// Example:
1230 ///
1231 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1232 /// %rank = shape.rank %shape
1233 ///
1234 /// becomes
1235 ///
1236 /// %rank = shape.const_size 3
1237 
1238 namespace {
1239 struct RankShapeOfCanonicalizationPattern
1240     : public OpRewritePattern<shape::RankOp> {
1241   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1242 
1243   LogicalResult matchAndRewrite(shape::RankOp op,
1244                                 PatternRewriter &rewriter) const override {
1245     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
1246     if (!shapeOfOp)
1247       return failure();
1248     auto rankedTensorType =
1249         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1250     if (!rankedTensorType)
1251       return failure();
1252     int64_t rank = rankedTensorType.getRank();
1253     if (op.getType().isa<IndexType>()) {
1254       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
1255     } else if (op.getType().isa<shape::SizeType>()) {
1256       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1257     } else {
1258       return failure();
1259     }
1260     return success();
1261   }
1262 };
1263 } // namespace
1264 
1265 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1266                                                 MLIRContext *context) {
1267   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1268 }
1269 
1270 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1271     MLIRContext *context, Optional<Location> location, ValueRange operands,
1272     DictionaryAttr attributes, RegionRange regions,
1273     SmallVectorImpl<Type> &inferredReturnTypes) {
1274   if (operands[0].getType().isa<ShapeType>())
1275     inferredReturnTypes.assign({SizeType::get(context)});
1276   else
1277     inferredReturnTypes.assign({IndexType::get(context)});
1278   return success();
1279 }
1280 
1281 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1282   // SizeType is compatible with IndexType.
1283   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1284 }
1285 
1286 //===----------------------------------------------------------------------===//
1287 // NumElementsOp
1288 //===----------------------------------------------------------------------===//
1289 
1290 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1291 
1292   // Fold only when argument constant.
1293   Attribute shape = operands[0];
1294   if (!shape)
1295     return {};
1296 
1297   APInt product(64, 1);
1298   for (auto value : shape.cast<DenseIntElementsAttr>())
1299     product *= value;
1300   Builder builder(getContext());
1301   return builder.getIndexAttr(product.getLimitedValue());
1302 }
1303 
1304 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1305     MLIRContext *context, Optional<Location> location, ValueRange operands,
1306     DictionaryAttr attributes, RegionRange regions,
1307     SmallVectorImpl<Type> &inferredReturnTypes) {
1308   if (operands[0].getType().isa<ShapeType>())
1309     inferredReturnTypes.assign({SizeType::get(context)});
1310   else
1311     inferredReturnTypes.assign({IndexType::get(context)});
1312   return success();
1313 }
1314 
1315 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1316                                                          TypeRange r) {
1317   // SizeType is compatible with IndexType.
1318   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // MaxOp
1323 //===----------------------------------------------------------------------===//
1324 
1325 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1326   // If operands are equal, just propagate one.
1327   if (lhs() == rhs())
1328     return lhs();
1329   return nullptr;
1330 }
1331 
1332 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1333     MLIRContext *context, Optional<Location> location, ValueRange operands,
1334     DictionaryAttr attributes, RegionRange regions,
1335     SmallVectorImpl<Type> &inferredReturnTypes) {
1336   if (operands[0].getType() == operands[1].getType())
1337     inferredReturnTypes.assign({operands[0].getType()});
1338   else
1339     inferredReturnTypes.assign({SizeType::get(context)});
1340   return success();
1341 }
1342 
1343 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1344   if (l.size() != 1 || r.size() != 1)
1345     return false;
1346   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1347     return true;
1348   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1349     return true;
1350   return false;
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // MinOp
1355 //===----------------------------------------------------------------------===//
1356 
1357 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1358   // If operands are equal, just propagate one.
1359   if (lhs() == rhs())
1360     return lhs();
1361   return nullptr;
1362 }
1363 
1364 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1365     MLIRContext *context, Optional<Location> location, ValueRange operands,
1366     DictionaryAttr attributes, RegionRange regions,
1367     SmallVectorImpl<Type> &inferredReturnTypes) {
1368   if (operands[0].getType() == operands[1].getType())
1369     inferredReturnTypes.assign({operands[0].getType()});
1370   else
1371     inferredReturnTypes.assign({SizeType::get(context)});
1372   return success();
1373 }
1374 
1375 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1376   if (l.size() != 1 || r.size() != 1)
1377     return false;
1378   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1379     return true;
1380   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1381     return true;
1382   return false;
1383 }
1384 
1385 //===----------------------------------------------------------------------===//
1386 // MulOp
1387 //===----------------------------------------------------------------------===//
1388 
1389 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1390   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1391   if (!lhs)
1392     return nullptr;
1393   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1394   if (!rhs)
1395     return nullptr;
1396   APInt folded = lhs.getValue() * rhs.getValue();
1397   Type indexTy = IndexType::get(getContext());
1398   return IntegerAttr::get(indexTy, folded);
1399 }
1400 
1401 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1402     MLIRContext *context, Optional<Location> location, ValueRange operands,
1403     DictionaryAttr attributes, RegionRange regions,
1404     SmallVectorImpl<Type> &inferredReturnTypes) {
1405   if (operands[0].getType().isa<SizeType>() ||
1406       operands[1].getType().isa<SizeType>())
1407     inferredReturnTypes.assign({SizeType::get(context)});
1408   else
1409     inferredReturnTypes.assign({IndexType::get(context)});
1410   return success();
1411 }
1412 
1413 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1414   // SizeType is compatible with IndexType.
1415   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1416 }
1417 //===----------------------------------------------------------------------===//
1418 // ShapeOfOp
1419 //===----------------------------------------------------------------------===//
1420 
1421 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1422   auto type = getOperand().getType().dyn_cast<ShapedType>();
1423   if (!type || !type.hasStaticShape())
1424     return nullptr;
1425   Builder builder(getContext());
1426   return builder.getIndexTensorAttr(type.getShape());
1427 }
1428 
1429 namespace {
1430 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1431   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1432 
1433   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1434                                 PatternRewriter &rewriter) const override {
1435     if (!op.arg().getType().isa<ShapedType>())
1436       return failure();
1437     if (op.getType().isa<ShapedType>())
1438       return failure();
1439 
1440     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
1441     return success();
1442   }
1443 };
1444 
1445 // Canonicalize
1446 // ```
1447 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1448 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1449 // ```
1450 // to
1451 // ```
1452 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1453 // ```
1454 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1455   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1456 
1457   LogicalResult matchAndRewrite(tensor::CastOp op,
1458                                 PatternRewriter &rewriter) const override {
1459     auto ty = op.getType().dyn_cast<RankedTensorType>();
1460     if (!ty || ty.getRank() != 1)
1461       return failure();
1462 
1463     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1464     if (!shapeOfOp)
1465       return failure();
1466 
1467     // Argument type must be ranked and must not conflict.
1468     auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1469     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1470       return failure();
1471 
1472     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
1473     return success();
1474   }
1475 };
1476 } // namespace
1477 
1478 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1479                                             MLIRContext *context) {
1480   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
1481 }
1482 
1483 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1484     MLIRContext *context, Optional<Location> location, ValueRange operands,
1485     DictionaryAttr attributes, RegionRange regions,
1486     SmallVectorImpl<Type> &inferredReturnTypes) {
1487   if (operands[0].getType().isa<ValueShapeType>())
1488     inferredReturnTypes.assign({ShapeType::get(context)});
1489   else {
1490     auto shapedTy = operands[0].getType().cast<ShapedType>();
1491     int64_t rank =
1492         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1493     Type indexTy = IndexType::get(context);
1494     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1495     inferredReturnTypes.assign({extentTensorTy});
1496   }
1497   return success();
1498 }
1499 
1500 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1501   if (l.size() != 1 || r.size() != 1)
1502     return false;
1503   if (l == r)
1504     return true;
1505 
1506   Type lhs = l.front();
1507   Type rhs = r.front();
1508 
1509   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1510     return false;
1511 
1512   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1513     // Shape type is compatible with all other valid return types.
1514     return true;
1515 
1516   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1517     return true;
1518   return false;
1519 }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // SizeToIndexOp
1523 //===----------------------------------------------------------------------===//
1524 
1525 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1526   // Constant values of both types, `shape.size` and `index`, are represented as
1527   // `IntegerAttr`s which makes constant folding simple.
1528   if (Attribute arg = operands[0])
1529     return arg;
1530   return impl::foldCastOp(*this);
1531 }
1532 
1533 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1534                                                 MLIRContext *context) {
1535   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1536 }
1537 
1538 //===----------------------------------------------------------------------===//
1539 // YieldOp
1540 //===----------------------------------------------------------------------===//
1541 
1542 static LogicalResult verify(shape::YieldOp op) {
1543   auto *parentOp = op->getParentOp();
1544   auto results = parentOp->getResults();
1545   auto operands = op.getOperands();
1546 
1547   if (parentOp->getNumResults() != op.getNumOperands())
1548     return op.emitOpError() << "number of operands does not match number of "
1549                                "results of its parent";
1550   for (auto e : llvm::zip(results, operands))
1551     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1552       return op.emitOpError()
1553              << "types mismatch between yield op and its parent";
1554 
1555   return success();
1556 }
1557 
1558 //===----------------------------------------------------------------------===//
1559 // SplitAtOp
1560 //===----------------------------------------------------------------------===//
1561 
1562 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1563                               SmallVectorImpl<OpFoldResult> &results) {
1564   if (!operands[0] || !operands[1])
1565     return failure();
1566   auto shapeVec = llvm::to_vector<6>(
1567       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1568   auto shape = llvm::makeArrayRef(shapeVec);
1569   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1570   // Verify that the split point is in the correct range.
1571   // TODO: Constant fold to an "error".
1572   int64_t rank = shape.size();
1573   if (!(-rank <= splitPoint && splitPoint <= rank))
1574     return failure();
1575   if (splitPoint < 0)
1576     splitPoint += shape.size();
1577   Builder builder(operands[0].getContext());
1578   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1579   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1580   return success();
1581 }
1582 
1583 //===----------------------------------------------------------------------===//
1584 // ToExtentTensorOp
1585 //===----------------------------------------------------------------------===//
1586 
1587 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1588   if (!operands[0])
1589     return impl::foldCastOp(*this);
1590   Builder builder(getContext());
1591   auto shape = llvm::to_vector<6>(
1592       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1593   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1594                                     builder.getIndexType());
1595   return DenseIntElementsAttr::get(type, shape);
1596 }
1597 
1598 //===----------------------------------------------------------------------===//
1599 // ReduceOp
1600 //===----------------------------------------------------------------------===//
1601 
1602 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1603                      ValueRange initVals) {
1604   result.addOperands(shape);
1605   result.addOperands(initVals);
1606 
1607   Region *bodyRegion = result.addRegion();
1608   bodyRegion->push_back(new Block);
1609   Block &bodyBlock = bodyRegion->front();
1610   bodyBlock.addArgument(builder.getIndexType());
1611 
1612   Type elementType;
1613   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1614     elementType = tensorType.getElementType();
1615   else
1616     elementType = SizeType::get(builder.getContext());
1617   bodyBlock.addArgument(elementType);
1618 
1619   for (Type initValType : initVals.getTypes()) {
1620     bodyBlock.addArgument(initValType);
1621     result.addTypes(initValType);
1622   }
1623 }
1624 
1625 static LogicalResult verify(ReduceOp op) {
1626   // Verify block arg types.
1627   Block &block = op.region().front();
1628 
1629   // The block takes index, extent, and aggregated values as arguments.
1630   auto blockArgsCount = op.initVals().size() + 2;
1631   if (block.getNumArguments() != blockArgsCount)
1632     return op.emitOpError() << "ReduceOp body is expected to have "
1633                             << blockArgsCount << " arguments";
1634 
1635   // The first block argument is the index and must always be of type `index`.
1636   if (!block.getArgument(0).getType().isa<IndexType>())
1637     return op.emitOpError(
1638         "argument 0 of ReduceOp body is expected to be of IndexType");
1639 
1640   // The second block argument is the extent and must be of type `size` or
1641   // `index`, depending on whether the reduce operation is applied to a shape or
1642   // to an extent tensor.
1643   Type extentTy = block.getArgument(1).getType();
1644   if (op.shape().getType().isa<ShapeType>()) {
1645     if (!extentTy.isa<SizeType>())
1646       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1647                             "SizeType if the ReduceOp operates on a ShapeType");
1648   } else {
1649     if (!extentTy.isa<IndexType>())
1650       return op.emitOpError(
1651           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1652           "ReduceOp operates on an extent tensor");
1653   }
1654 
1655   for (auto type : llvm::enumerate(op.initVals()))
1656     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1657       return op.emitOpError()
1658              << "type mismatch between argument " << type.index() + 2
1659              << " of ReduceOp body and initial value " << type.index();
1660   return success();
1661 }
1662 
1663 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1664   // Parse operands.
1665   SmallVector<OpAsmParser::OperandType, 3> operands;
1666   Type shapeOrExtentTensorType;
1667   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1668                               OpAsmParser::Delimiter::Paren) ||
1669       parser.parseColonType(shapeOrExtentTensorType) ||
1670       parser.parseOptionalArrowTypeList(result.types))
1671     return failure();
1672 
1673   // Resolve operands.
1674   auto initVals = llvm::makeArrayRef(operands).drop_front();
1675   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1676                             result.operands) ||
1677       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1678                              result.operands))
1679     return failure();
1680 
1681   // Parse the body.
1682   Region *body = result.addRegion();
1683   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1684     return failure();
1685 
1686   // Parse attributes.
1687   if (parser.parseOptionalAttrDict(result.attributes))
1688     return failure();
1689 
1690   return success();
1691 }
1692 
1693 static void print(OpAsmPrinter &p, ReduceOp op) {
1694   p << '(' << op.shape() << ", " << op.initVals()
1695     << ") : " << op.shape().getType();
1696   p.printOptionalArrowTypeList(op.getResultTypes());
1697   p.printRegion(op.region());
1698   p.printOptionalAttrDict(op->getAttrs());
1699 }
1700 
1701 #define GET_OP_CLASSES
1702 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1703