1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Traits.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::shape;
31 
32 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
33 
34 namespace {
35 #include "ShapeCanonicalization.inc"
36 } // namespace
37 
38 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
39   return RankedTensorType::get({rank}, IndexType::get(ctx));
40 }
41 
42 bool shape::isExtentTensorType(Type type) {
43   auto ranked = type.dyn_cast<RankedTensorType>();
44   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
45 }
46 
47 LogicalResult shape::getShapeVec(Value input,
48                                  SmallVectorImpl<int64_t> &shapeValues) {
49   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
50     auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
51     if (!type.hasRank())
52       return failure();
53     shapeValues = llvm::to_vector<6>(type.getShape());
54     return success();
55   }
56   if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
57     shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
58     return success();
59   }
60   if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
61     shapeValues = llvm::to_vector<6>(
62         inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
63     return success();
64   }
65   return failure();
66 }
67 
68 static bool isErrorPropagationPossible(TypeRange operandTypes) {
69   return llvm::any_of(operandTypes, [](Type ty) {
70     return ty.isa<SizeType, ShapeType, ValueShapeType>();
71   });
72 }
73 
74 static LogicalResult verifySizeOrIndexOp(Operation *op) {
75   assert(op != nullptr && op->getNumResults() == 1);
76   Type resultTy = op->getResultTypes().front();
77   if (isErrorPropagationPossible(op->getOperandTypes())) {
78     if (!resultTy.isa<SizeType>())
79       return op->emitOpError()
80              << "if at least one of the operands can hold error values then "
81                 "the result must be of type `size` to propagate them";
82   }
83   return success();
84 }
85 
86 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
87   assert(op != nullptr && op->getNumResults() == 1);
88   Type resultTy = op->getResultTypes().front();
89   if (isErrorPropagationPossible(op->getOperandTypes())) {
90     if (!resultTy.isa<ShapeType>())
91       return op->emitOpError()
92              << "if at least one of the operands can hold error values then "
93                 "the result must be of type `shape` to propagate them";
94   }
95   return success();
96 }
97 
/// Returns true iff `typeRange` contains exactly one type and that type is
/// one of the `Ty...` alternatives.
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}
102 
/// Variadic overload: true iff every given range holds exactly one type from
/// the `Ty...` alternatives.
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}
107 
108 //===----------------------------------------------------------------------===//
109 // InlinerInterface
110 //===----------------------------------------------------------------------===//
111 
112 namespace {
113 /// This class defines the interface for inlining shape dialect ops.
/// This class defines the interface for inlining shape dialect ops. All shape
/// ops and regions are unconditionally legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
132 } // namespace
133 
/// Registers the dialect's ops (from the ODS-generated list), types, and the
/// inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
146 
/// Materializes a single constant op for `value` of the given `type`, as used
/// by folding. Returns nullptr if no suitable constant op exists.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Both `!shape.shape` and extent tensors are materialized as const_shape;
  // this check must precede the arith fallback, which could also build an
  // index-tensor constant.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  // Fall back to an arith constant for plain builtin types (e.g. index).
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}
161 
/// Parse a type registered to this dialect. Accepts the keywords `shape`,
/// `size`, `value_shape`, and `witness`; returns the null Type (after
/// emitting a diagnostic) on anything else.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();

  if (keyword == "shape")
    return ShapeType::get(getContext());
  if (keyword == "size")
    return SizeType::get(getContext());
  if (keyword == "value_shape")
    return ValueShapeType::get(getContext());
  if (keyword == "witness")
    return WitnessType::get(getContext());

  // Unknown keyword: report it and signal failure with a null type.
  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
  return Type();
}
180 
181 /// Print a type registered to this dialect.
182 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
183   TypeSwitch<Type>(type)
184       .Case<ShapeType>([&](Type) { os << "shape"; })
185       .Case<SizeType>([&](Type) { os << "size"; })
186       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
187       .Case<WitnessType>([&](Type) { os << "witness"; })
188       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
189 }
190 
191 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
192                                                      NamedAttribute attribute) {
193   // Verify shape.lib attribute.
194   if (attribute.getName() == "shape.lib") {
195     if (!op->hasTrait<OpTrait::SymbolTable>())
196       return op->emitError(
197           "shape.lib attribute may only be on op implementing SymbolTable");
198 
199     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
200       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
201       if (!symbol)
202         return op->emitError("shape function library ")
203                << symbolRef << " not found";
204       return isa<shape::FunctionLibraryOp>(symbol)
205                  ? success()
206                  : op->emitError()
207                        << symbolRef << " required to be shape function library";
208     }
209 
210     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
211       // Verify all entries are function libraries and mappings in libraries
212       // refer to unique ops.
213       DenseSet<StringAttr> key;
214       for (auto it : arr) {
215         if (!it.isa<SymbolRefAttr>())
216           return op->emitError(
217               "only SymbolRefAttr allowed in shape.lib attribute array");
218 
219         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
220             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
221         if (!shapeFnLib)
222           return op->emitError()
223                  << it << " does not refer to FunctionLibraryOp";
224         for (auto mapping : shapeFnLib.getMapping()) {
225           if (!key.insert(mapping.getName()).second) {
226             return op->emitError("only one op to shape mapping allowed, found "
227                                  "multiple for `")
228                    << mapping.getName() << "`";
229           }
230         }
231       }
232       return success();
233     }
234 
235     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
236                          "allowed as shape.lib attribute");
237   }
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // AnyOp
243 //===----------------------------------------------------------------------===//
244 
245 // TODO: Canonicalization should be implemented for shapes that can be
246 // determined through mixtures of the known dimensions of the inputs.
247 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
248   // Only the last operand is checked because AnyOp is commutative.
249   if (operands.back())
250     return operands.back();
251 
252   return nullptr;
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // AssumingOp
257 //===----------------------------------------------------------------------===//
258 
/// Parses an assuming op of the form:
///   shape.assuming %witness [-> (result-types)] region [attr-dict]
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is always the witness.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
285 
286 static void print(OpAsmPrinter &p, AssumingOp op) {
287   bool yieldsResults = !op.getResults().empty();
288 
289   p << " " << op.getWitness();
290   if (yieldsResults)
291     p << " -> (" << op.getResultTypes() << ")";
292   p << ' ';
293   p.printRegion(op.getDoRegion(),
294                 /*printEntryBlockArgs=*/false,
295                 /*printBlockTerminators=*/yieldsResults);
296   p.printOptionalAttrDict(op->getAttrs());
297 }
298 
namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the witness is a const_witness known to pass.
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};
314 
// Shrinks an AssumingOp to only the results that have uses: rebuilds the op
// with a trimmed yield and moves the body over, so dead yielded values can be
// DCE'd inside the region afterwards.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results map to a null replacement; replaceOp keeps those
    // positions (they had no uses by construction).
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
360 } // namespace
361 
/// Registers AssumingOp canonicalizations: dropping unused results and
/// inlining the region when the witness statically passes.
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
366 
367 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
368 void AssumingOp::getSuccessorRegions(
369     Optional<unsigned> index, ArrayRef<Attribute> operands,
370     SmallVectorImpl<RegionSuccessor> &regions) {
371   // AssumingOp has unconditional control flow into the region and back to the
372   // parent, so return the correct RegionSuccessor purely based on the index
373   // being None or 0.
374   if (index.hasValue()) {
375     regions.push_back(RegionSuccessor(getResults()));
376     return;
377   }
378 
379   regions.push_back(RegionSuccessor(&getDoRegion()));
380 }
381 
/// Splices the body of `op` into its parent block in place of the op itself:
/// the parent block is split at the op, the region is inlined between the two
/// halves, the op's results are replaced by the yield's operands, and the
/// three blocks are merged back together.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
401 
402 void AssumingOp::build(
403     OpBuilder &builder, OperationState &result, Value witness,
404     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
405 
406   result.addOperands(witness);
407   Region *bodyRegion = result.addRegion();
408   bodyRegion->push_back(new Block);
409   Block &bodyBlock = bodyRegion->front();
410 
411   // Build body.
412   OpBuilder::InsertionGuard guard(builder);
413   builder.setInsertionPointToStart(&bodyBlock);
414   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
415   builder.create<AssumingYieldOp>(result.location, yieldValues);
416 
417   SmallVector<Type, 2> assumingTypes;
418   for (Value v : yieldValues)
419     assumingTypes.push_back(v.getType());
420   result.addTypes(assumingTypes);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // AddOp
425 //===----------------------------------------------------------------------===//
426 
427 LogicalResult mlir::shape::AddOp::inferReturnTypes(
428     MLIRContext *context, Optional<Location> location, ValueRange operands,
429     DictionaryAttr attributes, RegionRange regions,
430     SmallVectorImpl<Type> &inferredReturnTypes) {
431   if (operands[0].getType().isa<SizeType>() ||
432       operands[1].getType().isa<SizeType>())
433     inferredReturnTypes.assign({SizeType::get(context)});
434   else
435     inferredReturnTypes.assign({IndexType::get(context)});
436   return success();
437 }
438 
/// Inferred and declared result types match if each is a single `size` or
/// `index` type.
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
443 
444 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
445   // add(x, 0) -> x
446   if (matchPattern(getRhs(), m_Zero()))
447     return getLhs();
448 
449   return constFoldBinaryOp<IntegerAttr>(
450       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // AssumingAllOp
455 //===----------------------------------------------------------------------===//
456 
namespace {
// Rewrites `assuming_all` over `cstr_eq` witnesses into a single `cstr_eq`
// over the union of their shapes. Requires every cstr_eq to share at least
// one shape with the previously collected ones, so equality is transitive
// across the merged set.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // Disjoint shape sets would not be linked by transitivity; bail out.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
480 
481 template <typename OpTy>
482 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
483   using OpRewritePattern<OpTy>::OpRewritePattern;
484 
485   LogicalResult matchAndRewrite(OpTy op,
486                                 PatternRewriter &rewriter) const override {
487     // Find unique operands.
488     SmallVector<Value, 2> unique;
489     for (Value v : op.getOperands()) {
490       if (!llvm::is_contained(unique, v))
491         unique.push_back(v);
492     }
493 
494     // Reduce op to equivalent with unique operands.
495     if (unique.size() < op.getNumOperands()) {
496       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
497                                         op->getAttrs());
498       return success();
499     }
500 
501     return failure();
502   }
503 };
504 } // namespace
505 
/// Registers AssumingAllOp canonicalizations: collapsing to a single witness,
/// merging cstr_eq inputs, and removing duplicate witness operands.
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
511 
/// Folds assuming_all: constant witnesses are stripped from the operand list
/// (note this mutates the op during folding); the op folds to false as soon
/// as any witness is statically false, and to true once all operands were
/// statically passing.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
532 
533 static LogicalResult verify(AssumingAllOp op) {
534   // Ensure that AssumingAllOp contains at least one operand
535   if (op.getNumOperands() == 0)
536     return op.emitOpError("no operands specified");
537 
538   return success();
539 }
540 
/// Convenience builder that supplies the witness result type.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
545 
546 //===----------------------------------------------------------------------===//
547 // BroadcastOp
548 //===----------------------------------------------------------------------===//
549 
/// Folds broadcast: a single operand of matching type folds to itself; two
/// constant operands fold to their broadcasted shape when compatible.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both operands must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
578 
/// If any operand can hold an error value, the result must be `shape`.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
582 
583 namespace {
584 template <typename OpTy>
585 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
586   using OpRewritePattern<OpTy>::OpRewritePattern;
587 
588   LogicalResult matchAndRewrite(OpTy op,
589                                 PatternRewriter &rewriter) const override {
590     auto isPotentiallyNonEmptyShape = [](Value shape) {
591       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
592         if (extentTensorTy.getDimSize(0) == 0)
593           return false;
594       }
595       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
596         if (constShape.getShape().empty())
597           return false;
598       }
599       return true;
600     };
601     auto newOperands = llvm::to_vector<8>(
602         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
603 
604     // Reduce op to equivalent without empty shape operands.
605     if (newOperands.size() < op.getNumOperands()) {
606       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
607                                         op->getAttrs());
608       return success();
609     }
610 
611     return failure();
612   }
613 };
614 
// A broadcast with a single operand is the identity; forward the operand,
// inserting a cast when its type differs from the result type.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        // Extent tensor -> `!shape.shape`.
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        // Both sides must be extent tensors here; cast between them.
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
643 
// Folds all broadcastable constant operands of a broadcast into a single
// const_shape operand, keeping the non-constant operands as-is.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Incompatible constants are left as regular operands rather than
        // folded; the broadcast keeps its runtime semantics.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
681 
682 template <typename OpTy>
683 struct CanonicalizeCastExtentTensorOperandsPattern
684     : public OpRewritePattern<OpTy> {
685   using OpRewritePattern<OpTy>::OpRewritePattern;
686 
687   LogicalResult matchAndRewrite(OpTy op,
688                                 PatternRewriter &rewriter) const override {
689     // Canonicalize operands.
690     bool anyChange = false;
691     auto canonicalizeOperand = [&](Value operand) {
692       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
693         // Only eliminate the cast if it holds no shape information.
694         bool isInformationLoosingCast =
695             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
696         if (isInformationLoosingCast) {
697           anyChange = true;
698           return castOp.source();
699         }
700       }
701       return operand;
702     };
703     auto newOperands = llvm::to_vector<8>(
704         llvm::map_range(op.getOperands(), canonicalizeOperand));
705 
706     // Rewrite op if any change required.
707     if (!anyChange)
708       return failure();
709     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
710     return success();
711   }
712 };
713 
// Concretizes a broadcast's dynamically sized extent tensor result type when
// all operand ranks are statically known (result rank = max operand rank),
// inserting a cast back to the original type for existing users.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
744 } // namespace
745 
/// Registers the full set of BroadcastOp canonicalizations defined above.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
755 
756 //===----------------------------------------------------------------------===//
757 // ConcatOp
758 //===----------------------------------------------------------------------===//
759 
760 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
761   if (!operands[0] || !operands[1])
762     return nullptr;
763   auto lhsShape = llvm::to_vector<6>(
764       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
765   auto rhsShape = llvm::to_vector<6>(
766       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
767   SmallVector<int64_t, 6> resultShape;
768   resultShape.append(lhsShape.begin(), lhsShape.end());
769   resultShape.append(rhsShape.begin(), rhsShape.end());
770   Builder builder(getContext());
771   return builder.getIndexTensorAttr(resultShape);
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // ConstShapeOp
776 //===----------------------------------------------------------------------===//
777 
/// Prints a const_shape op as `shape.const_shape [1, 2, 3] : type`.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << " ";
  // The shape attribute is rendered in the custom bracketed form below.
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.getShape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
787 
/// Parses a const_shape op: `[attr-dict] [e0, e1, ...] : result-type`.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every extent must be an integer attribute.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  // Store the extents in the canonical index-tensor form.
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
817 
// A const_shape op trivially folds to its "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
819 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // TensorCastConstShape is a generated pattern (see ShapeCanonicalization.inc).
  patterns.add<TensorCastConstShape>(context);
}
824 
825 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
826     MLIRContext *context, Optional<Location> location, ValueRange operands,
827     DictionaryAttr attributes, RegionRange regions,
828     SmallVectorImpl<Type> &inferredReturnTypes) {
829   Builder b(context);
830   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
831   if (!shape)
832     return emitOptionalError(location, "missing shape attribute");
833   inferredReturnTypes.assign({RankedTensorType::get(
834       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
835   return success();
836 }
837 
838 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
839                                                         TypeRange r) {
840   if (l.size() != 1 || r.size() != 1)
841     return false;
842 
843   Type lhs = l.front();
844   Type rhs = r.front();
845 
846   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
847     // Shape type is compatible with all other valid return types.
848     return true;
849   return lhs == rhs;
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // CstrBroadcastableOp
854 //===----------------------------------------------------------------------===//
855 
// Registers canonicalization patterns for cstr_broadcastable.
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
866 
// Return true if there is at most one attribute not representing a scalar
// broadcast. (A null attribute, i.e. an unknown shape, conservatively counts
// as non-scalar.)
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      // A second non-scalar makes broadcasting non-trivial.
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
880 
// Folds to a passing witness when broadcastability can be established
// statically; otherwise does not fold (a failing witness represents an
// eventual runtime assertion and must not be constant-folded away).
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, try with the constant operand shapes: all must be constant and
  // statically broadcastable.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
915 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Ensure that the cstr_broadcastable op has at least two shape operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
922 
923 //===----------------------------------------------------------------------===//
924 // CstrEqOp
925 //===----------------------------------------------------------------------===//
926 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  // (CstrEqEqOps is a generated pattern; see ShapeCanonicalization.inc.)
  patterns.add<CstrEqEqOps>(context);
}
932 
933 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
934   if (llvm::all_of(operands,
935                    [&](Attribute a) { return a && a == operands[0]; }))
936     return BoolAttr::get(getContext(), true);
937 
938   // Because a failing witness result here represents an eventual assertion
939   // failure, we do not try to replace it with a constant witness. Similarly, we
940   // cannot if there are any non-const inputs.
941   return nullptr;
942 }
943 
944 //===----------------------------------------------------------------------===//
945 // ConstSizeOp
946 //===----------------------------------------------------------------------===//
947 
// Convenience builder: wraps a raw int64_t into an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
952 
// A const_size op trivially folds to its "value" attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
954 
955 void ConstSizeOp::getAsmResultNames(
956     llvm::function_ref<void(Value, StringRef)> setNameFn) {
957   SmallString<4> buffer;
958   llvm::raw_svector_ostream os(buffer);
959   os << "c" << getValue();
960   setNameFn(getResult(), os.str());
961 }
962 
963 //===----------------------------------------------------------------------===//
964 // ConstWitnessOp
965 //===----------------------------------------------------------------------===//
966 
// A const_witness op trivially folds to its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
970 
971 //===----------------------------------------------------------------------===//
972 // CstrRequireOp
973 //===----------------------------------------------------------------------===//
974 
// Folds to the predicate operand when it is constant (a null attribute, i.e.
// a non-constant predicate, yields null and prevents folding).
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
978 
979 //===----------------------------------------------------------------------===//
980 // DivOp
981 //===----------------------------------------------------------------------===//
982 
// Folds division of two constant operands using floor semantics (round
// toward negative infinity), producing an `index`-typed attribute.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  // NOTE(review): a constant zero divisor reaches APInt::sdivrem unchecked —
  // confirm divide-by-zero cannot occur here.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  // Adjust a truncated negative quotient down by one to get floor behavior.
  if (quotient.isNegative() && !remainder.isNullValue()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}
1002 
1003 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1004     MLIRContext *context, Optional<Location> location, ValueRange operands,
1005     DictionaryAttr attributes, RegionRange regions,
1006     SmallVectorImpl<Type> &inferredReturnTypes) {
1007   if (operands[0].getType().isa<SizeType>() ||
1008       operands[1].getType().isa<SizeType>())
1009     inferredReturnTypes.assign({SizeType::get(context)});
1010   else
1011     inferredReturnTypes.assign({IndexType::get(context)});
1012   return success();
1013 }
1014 
// Return types differing only in size-vs-index are considered compatible.
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1019 
1020 //===----------------------------------------------------------------------===//
1021 // ShapeEqOp
1022 //===----------------------------------------------------------------------===//
1023 
1024 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1025   bool allSame = true;
1026   if (!operands.empty() && !operands[0])
1027     return {};
1028   for (Attribute operand : operands.drop_front(1)) {
1029     if (!operand)
1030       return {};
1031     allSame = allSame && operand == operands[0];
1032   }
1033   return BoolAttr::get(getContext(), allSame);
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // IndexToSizeOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1041   // Constant values of both types, `shape.size` and `index`, are represented as
1042   // `IntegerAttr`s which makes constant folding simple.
1043   if (Attribute arg = operands[0])
1044     return arg;
1045   return {};
1046 }
1047 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Generated pattern; see ShapeCanonicalization.inc.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1052 
1053 //===----------------------------------------------------------------------===//
1054 // FromExtentsOp
1055 //===----------------------------------------------------------------------===//
1056 
1057 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1058   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1059     return nullptr;
1060   SmallVector<int64_t, 6> extents;
1061   for (auto attr : operands)
1062     extents.push_back(attr.cast<IntegerAttr>().getInt());
1063   Builder builder(getContext());
1064   return builder.getIndexTensorAttr(extents);
1065 }
1066 
1067 //===----------------------------------------------------------------------===//
1068 // FunctionLibraryOp
1069 //===----------------------------------------------------------------------===//
1070 
// Builds a function library with only its symbol name set; the mapping
// attribute and body region are expected to be populated by the caller.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1076 
// Looks up the shape function registered for `op`'s name in the mapping
// attribute; returns null if there is no entry or it is not a symbol ref.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1085 
// Custom parser: symbol name, optional attr-dict, body region, then the
// `mapping` keyword followed by a dictionary attribute.
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The dictionary is stored under the "mapping" attribute name.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1111 
// Custom printer, the inverse of parseFunctionLibraryOp above.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << ' ';
  p.printSymbolName(op.getName());
  // Symbol name and mapping are printed explicitly, so elide them here.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.getMappingAttr());
}
1123 
1124 //===----------------------------------------------------------------------===//
1125 // GetExtentOp
1126 //===----------------------------------------------------------------------===//
1127 
1128 Optional<int64_t> GetExtentOp::getConstantDim() {
1129   if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1130     return constSizeOp.getValue().getLimitedValue();
1131   if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1132     return constantOp.getValue().cast<IntegerAttr>().getInt();
1133   return llvm::None;
1134 }
1135 
1136 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1137   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1138   if (!elements)
1139     return nullptr;
1140   Optional<int64_t> dim = getConstantDim();
1141   if (!dim.hasValue())
1142     return nullptr;
1143   if (dim.getValue() >= elements.getNumElements())
1144     return nullptr;
1145   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1146 }
1147 
1148 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1149                         int64_t dim) {
1150   auto loc = result.location;
1151   auto dimAttr = builder.getIndexAttr(dim);
1152   if (shape.getType().isa<ShapeType>()) {
1153     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1154     build(builder, result, builder.getType<SizeType>(), shape, dim);
1155   } else {
1156     Value dim =
1157         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1158     build(builder, result, builder.getIndexType(), shape, dim);
1159   }
1160 }
1161 
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Always infer `index`; a `!shape.size` result remains valid because
  // isCompatibleReturnTypes below treats the two as interchangeable.
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1169 
// Return types differing only in size-vs-index are considered compatible.
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1175 
1176 //===----------------------------------------------------------------------===//
1177 // IsBroadcastableOp
1178 //===----------------------------------------------------------------------===//
1179 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate operands never change the broadcastability result.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1184 
1185 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1186   // Can always broadcast fewer than two shapes.
1187   if (operands.size() < 2) {
1188     return BoolAttr::get(getContext(), true);
1189   }
1190 
1191   return nullptr;
1192 }
1193 
1194 //===----------------------------------------------------------------------===//
1195 // MeetOp
1196 //===----------------------------------------------------------------------===//
1197 
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result type mirrors the first operand's type.
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
1205 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // NOTE(review): both ranges have size 1 here, so if TypeRange equality is
  // elementwise the `l == r` check above already covers `lhs == rhs`, making
  // everything past the next check unreachable — confirm whether a
  // compatibility check (rather than strict equality) was intended below.
  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1225 
1226 //===----------------------------------------------------------------------===//
1227 // RankOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1231   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1232   if (!shape)
1233     return {};
1234   int64_t rank = shape.getNumElements();
1235   Builder builder(getContext());
1236   return builder.getIndexAttr(rank);
1237 }
1238 
/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the operand comes from shape_of of a ranked tensor;
    // the rank is then statically known even if the extents are not.
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the rank with a constant op matching the result type.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1280 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // See RankShapeOfCanonicalizationPattern above.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1285 
1286 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1287     MLIRContext *context, Optional<Location> location, ValueRange operands,
1288     DictionaryAttr attributes, RegionRange regions,
1289     SmallVectorImpl<Type> &inferredReturnTypes) {
1290   if (operands[0].getType().isa<ShapeType>())
1291     inferredReturnTypes.assign({SizeType::get(context)});
1292   else
1293     inferredReturnTypes.assign({IndexType::get(context)});
1294   return success();
1295 }
1296 
// Return types differing only in size-vs-index are considered compatible.
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1301 
1302 //===----------------------------------------------------------------------===//
1303 // NumElementsOp
1304 //===----------------------------------------------------------------------===//
1305 
1306 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1307 
1308   // Fold only when argument constant.
1309   Attribute shape = operands[0];
1310   if (!shape)
1311     return {};
1312 
1313   APInt product(64, 1);
1314   for (auto value : shape.cast<DenseIntElementsAttr>())
1315     product *= value;
1316   Builder builder(getContext());
1317   return builder.getIndexAttr(product.getLimitedValue());
1318 }
1319 
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // `!shape.shape` operands produce `!shape.size`; extent tensors produce
  // a plain `index`.
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1330 
// Return types differing only in size-vs-index are considered compatible.
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1336 
1337 //===----------------------------------------------------------------------===//
1338 // MaxOp
1339 //===----------------------------------------------------------------------===//
1340 
1341 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1342   // If operands are equal, just propagate one.
1343   if (getLhs() == getRhs())
1344     return getLhs();
1345   return nullptr;
1346 }
1347 
1348 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1349     MLIRContext *context, Optional<Location> location, ValueRange operands,
1350     DictionaryAttr attributes, RegionRange regions,
1351     SmallVectorImpl<Type> &inferredReturnTypes) {
1352   if (operands[0].getType() == operands[1].getType())
1353     inferredReturnTypes.assign({operands[0].getType()});
1354   else
1355     inferredReturnTypes.assign({SizeType::get(context)});
1356   return success();
1357 }
1358 
1359 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1360   if (l.size() != 1 || r.size() != 1)
1361     return false;
1362   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1363     return true;
1364   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1365     return true;
1366   return false;
1367 }
1368 
1369 //===----------------------------------------------------------------------===//
1370 // MinOp
1371 //===----------------------------------------------------------------------===//
1372 
1373 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1374   // If operands are equal, just propagate one.
1375   if (getLhs() == getRhs())
1376     return getLhs();
1377   return nullptr;
1378 }
1379 
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Identical operand types propagate unchanged; otherwise fall back to the
  // more general `!shape.size`.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1390 
1391 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1392   if (l.size() != 1 || r.size() != 1)
1393     return false;
1394   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1395     return true;
1396   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1397     return true;
1398   return false;
1399 }
1400 
1401 //===----------------------------------------------------------------------===//
1402 // MulOp
1403 //===----------------------------------------------------------------------===//
1404 
// Folds multiplication of two constant operands to an `index`-typed constant.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}
1416 
1417 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1418     MLIRContext *context, Optional<Location> location, ValueRange operands,
1419     DictionaryAttr attributes, RegionRange regions,
1420     SmallVectorImpl<Type> &inferredReturnTypes) {
1421   if (operands[0].getType().isa<SizeType>() ||
1422       operands[1].getType().isa<SizeType>())
1423     inferredReturnTypes.assign({SizeType::get(context)});
1424   else
1425     inferredReturnTypes.assign({IndexType::get(context)});
1426   return success();
1427 }
1428 
// Return types differing only in size-vs-index are considered compatible.
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1433 //===----------------------------------------------------------------------===//
1434 // ShapeOfOp
1435 //===----------------------------------------------------------------------===//
1436 
1437 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1438   auto type = getOperand().getType().dyn_cast<ShapedType>();
1439   if (!type || !type.hasStaticShape())
1440     return nullptr;
1441   Builder builder(getContext());
1442   return builder.getIndexTensorAttr(type.getShape());
1443 }
1444 
1445 namespace {
// Rewrites `shape.shape_of` on a shaped (tensor) argument whose result type
// is currently not a tensor, recreating the op so its result type is
// re-inferred (presumably yielding an extent tensor — see
// ShapeOfOp::inferReturnTypes below).
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the argument is shaped but the result is not.
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};
1461 
// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The cast target must be a rank-1 tensor (i.e. an extent tensor).
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    // Recreate shape_of directly with the cast's result type, absorbing the
    // cast.
    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
1493 } // namespace
1494 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // See the pattern structs above; ExtractFromShapeOfExtentTensor is a
  // generated pattern (ShapeCanonicalization.inc).
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1500 
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // A value-shape operand produces an opaque `!shape.shape`; a shaped operand
  // produces an extent tensor whose length is the (possibly dynamic) rank.
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}
1517 
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // Only shape or shaped (tensor) types are valid results at all.
  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  // Two shaped types are compatible if their shapes do not conflict.
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1538 
1539 //===----------------------------------------------------------------------===//
1540 // SizeToIndexOp
1541 //===----------------------------------------------------------------------===//
1542 
1543 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1544   // Constant values of both types, `shape.size` and `index`, are represented as
1545   // `IntegerAttr`s which makes constant folding simple.
1546   if (Attribute arg = operands[0])
1547     return arg;
1548   return impl::foldCastOp(*this);
1549 }
1550 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Generated pattern; see ShapeCanonicalization.inc.
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1555 
1556 //===----------------------------------------------------------------------===//
1557 // YieldOp
1558 //===----------------------------------------------------------------------===//
1559 
1560 static LogicalResult verify(shape::YieldOp op) {
1561   auto *parentOp = op->getParentOp();
1562   auto results = parentOp->getResults();
1563   auto operands = op.getOperands();
1564 
1565   if (parentOp->getNumResults() != op.getNumOperands())
1566     return op.emitOpError() << "number of operands does not match number of "
1567                                "results of its parent";
1568   for (auto e : llvm::zip(results, operands))
1569     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1570       return op.emitOpError()
1571              << "types mismatch between yield op and its parent";
1572 
1573   return success();
1574 }
1575 
1576 //===----------------------------------------------------------------------===//
1577 // SplitAtOp
1578 //===----------------------------------------------------------------------===//
1579 
// Folds split_at with a constant shape and split point into two constant
// extent tensors (head and tail). Negative split points count from the back.
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  // Both the shape and the split point must be constant.
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // Normalize a negative split point to a non-negative offset from the front.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1600 
1601 //===----------------------------------------------------------------------===//
1602 // ToExtentTensorOp
1603 //===----------------------------------------------------------------------===//
1604 
1605 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1606   if (!operands[0])
1607     return impl::foldCastOp(*this);
1608   Builder builder(getContext());
1609   auto shape = llvm::to_vector<6>(
1610       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1611   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1612                                     builder.getIndexType());
1613   return DenseIntElementsAttr::get(type, shape);
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // ReduceOp
1618 //===----------------------------------------------------------------------===//
1619 
1620 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1621                      ValueRange initVals) {
1622   result.addOperands(shape);
1623   result.addOperands(initVals);
1624 
1625   Region *bodyRegion = result.addRegion();
1626   bodyRegion->push_back(new Block);
1627   Block &bodyBlock = bodyRegion->front();
1628   bodyBlock.addArgument(builder.getIndexType(), result.location);
1629 
1630   Type elementType;
1631   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1632     elementType = tensorType.getElementType();
1633   else
1634     elementType = SizeType::get(builder.getContext());
1635   bodyBlock.addArgument(elementType, shape.getLoc());
1636 
1637   for (Value initVal : initVals) {
1638     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1639     result.addTypes(initVal.getType());
1640   }
1641 }
1642 
1643 static LogicalResult verify(ReduceOp op) {
1644   // Verify block arg types.
1645   Block &block = op.getRegion().front();
1646 
1647   // The block takes index, extent, and aggregated values as arguments.
1648   auto blockArgsCount = op.getInitVals().size() + 2;
1649   if (block.getNumArguments() != blockArgsCount)
1650     return op.emitOpError() << "ReduceOp body is expected to have "
1651                             << blockArgsCount << " arguments";
1652 
1653   // The first block argument is the index and must always be of type `index`.
1654   if (!block.getArgument(0).getType().isa<IndexType>())
1655     return op.emitOpError(
1656         "argument 0 of ReduceOp body is expected to be of IndexType");
1657 
1658   // The second block argument is the extent and must be of type `size` or
1659   // `index`, depending on whether the reduce operation is applied to a shape or
1660   // to an extent tensor.
1661   Type extentTy = block.getArgument(1).getType();
1662   if (op.getShape().getType().isa<ShapeType>()) {
1663     if (!extentTy.isa<SizeType>())
1664       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1665                             "SizeType if the ReduceOp operates on a ShapeType");
1666   } else {
1667     if (!extentTy.isa<IndexType>())
1668       return op.emitOpError(
1669           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1670           "ReduceOp operates on an extent tensor");
1671   }
1672 
1673   for (const auto &type : llvm::enumerate(op.getInitVals()))
1674     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1675       return op.emitOpError()
1676              << "type mismatch between argument " << type.index() + 2
1677              << " of ReduceOp body and initial value " << type.index();
1678   return success();
1679 }
1680 
1681 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1682   // Parse operands.
1683   SmallVector<OpAsmParser::OperandType, 3> operands;
1684   Type shapeOrExtentTensorType;
1685   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1686                               OpAsmParser::Delimiter::Paren) ||
1687       parser.parseColonType(shapeOrExtentTensorType) ||
1688       parser.parseOptionalArrowTypeList(result.types))
1689     return failure();
1690 
1691   // Resolve operands.
1692   auto initVals = llvm::makeArrayRef(operands).drop_front();
1693   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1694                             result.operands) ||
1695       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1696                              result.operands))
1697     return failure();
1698 
1699   // Parse the body.
1700   Region *body = result.addRegion();
1701   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1702     return failure();
1703 
1704   // Parse attributes.
1705   if (parser.parseOptionalAttrDict(result.attributes))
1706     return failure();
1707 
1708   return success();
1709 }
1710 
1711 static void print(OpAsmPrinter &p, ReduceOp op) {
1712   p << '(' << op.getShape() << ", " << op.getInitVals()
1713     << ") : " << op.getShape().getType();
1714   p.printOptionalArrowTypeList(op.getResultTypes());
1715   p << ' ';
1716   p.printRegion(op.getRegion());
1717   p.printOptionalAttrDict(op->getAttrs());
1718 }
1719 
1720 #define GET_OP_CLASSES
1721 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1722