1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Traits.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 using namespace mlir;
28 using namespace mlir::shape;
29 
30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
31 
32 namespace {
33 #include "ShapeCanonicalization.inc"
34 } // namespace
35 
36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
37   return RankedTensorType::get({rank}, IndexType::get(ctx));
38 }
39 
40 bool shape::isExtentTensorType(Type type) {
41   auto ranked = type.dyn_cast<RankedTensorType>();
42   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
43 }
44 
45 LogicalResult shape::getShapeVec(Value input,
46                                  SmallVectorImpl<int64_t> &shapeValues) {
47   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
48     auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
49     if (!type.hasRank())
50       return failure();
51     shapeValues = llvm::to_vector<6>(type.getShape());
52     return success();
53   }
54   if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
55     shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
56     return success();
57   }
58   if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
59     shapeValues = llvm::to_vector<6>(
60         inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
61     return success();
62   }
63   return failure();
64 }
65 
66 static bool isErrorPropagationPossible(TypeRange operandTypes) {
67   return llvm::any_of(operandTypes, [](Type ty) {
68     return ty.isa<SizeType, ShapeType, ValueShapeType>();
69   });
70 }
71 
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73   assert(op != nullptr && op->getNumResults() == 1);
74   Type resultTy = op->getResultTypes().front();
75   if (isErrorPropagationPossible(op->getOperandTypes())) {
76     if (!resultTy.isa<SizeType>())
77       return op->emitOpError()
78              << "if at least one of the operands can hold error values then "
79                 "the result must be of type `size` to propagate them";
80   }
81   return success();
82 }
83 
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85   assert(op != nullptr && op->getNumResults() == 1);
86   Type resultTy = op->getResultTypes().front();
87   if (isErrorPropagationPossible(op->getOperandTypes())) {
88     if (!resultTy.isa<ShapeType>())
89       return op->emitOpError()
90              << "if at least one of the operands can hold error values then "
91                 "the result must be of type `shape` to propagate them";
92   }
93   return success();
94 }
95 
/// Base case: returns true iff `typeRange` contains exactly one type and that
/// type is one of `Ty...`.
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}
100 
/// Variadic case: returns true iff every given type range passes the
/// single-type check above.
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}
105 
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
/// This class defines the interface for inlining shape dialect ops. All shape
/// dialect ops and regions are unconditionally legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Shape dialect regions carry no inlining restrictions.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  // Shape dialect ops carry no inlining restrictions.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
130 } // namespace
131 
/// Registers the dialect's ops (from the tablegen-generated list), types, and
/// the inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
144 
145 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
146                                              Attribute value, Type type,
147                                              Location loc) {
148   if (type.isa<ShapeType>() || isExtentTensorType(type))
149     return builder.create<ConstShapeOp>(loc, type,
150                                         value.cast<DenseIntElementsAttr>());
151   if (type.isa<SizeType>())
152     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
153   if (type.isa<WitnessType>())
154     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
155   if (arith::ConstantOp::isBuildableWith(value, type))
156     return builder.create<arith::ConstantOp>(loc, type, value);
157   return nullptr;
158 }
159 
160 /// Parse a type registered to this dialect.
161 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
162   StringRef keyword;
163   if (parser.parseKeyword(&keyword))
164     return Type();
165 
166   if (keyword == "shape")
167     return ShapeType::get(getContext());
168   if (keyword == "size")
169     return SizeType::get(getContext());
170   if (keyword == "value_shape")
171     return ValueShapeType::get(getContext());
172   if (keyword == "witness")
173     return WitnessType::get(getContext());
174 
175   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
176   return Type();
177 }
178 
179 /// Print a type registered to this dialect.
180 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
181   TypeSwitch<Type>(type)
182       .Case<ShapeType>([&](Type) { os << "shape"; })
183       .Case<SizeType>([&](Type) { os << "size"; })
184       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
185       .Case<WitnessType>([&](Type) { os << "witness"; })
186       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
187 }
188 
/// Verifies dialect attributes attached to arbitrary operations. Currently
/// only checks `shape.lib`: it must sit on a symbol-table op and be either a
/// single SymbolRefAttr or an array of them, each resolving to a
/// shape::FunctionLibraryOp whose op-to-shape mappings are pairwise disjoint.
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    // Symbol lookup below requires the op to define a symbol table.
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    // Case 1: a single symbol reference to a function library.
    if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    // Case 2: an array of symbol references.
    if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        // Each op name may be mapped by at most one library in the array.
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    // Any other attribute kind is rejected.
    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}
238 
239 //===----------------------------------------------------------------------===//
240 // AnyOp
241 //===----------------------------------------------------------------------===//
242 
243 // TODO: Canonicalization should be implemented for shapes that can be
244 // determined through mixtures of the known dimensions of the inputs.
245 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
246   // Only the last operand is checked because AnyOp is commutative.
247   if (operands.back())
248     return operands.back();
249 
250   return nullptr;
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // AssumingOp
255 //===----------------------------------------------------------------------===//
256 
/// Parses an AssumingOp of the form:
///   shape.assuming %witness [-> (result-types)] { region } [attr-dict]
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The sole operand is the witness, always of witness type.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
283 
284 static void print(OpAsmPrinter &p, AssumingOp op) {
285   bool yieldsResults = !op.getResults().empty();
286 
287   p << " " << op.getWitness();
288   if (yieldsResults) {
289     p << " -> (" << op.getResultTypes() << ")";
290   }
291   p.printRegion(op.getDoRegion(),
292                 /*printEntryBlockArgs=*/false,
293                 /*printBlockTerminators=*/yieldsResults);
294   p.printOptionalAttrDict(op->getAttrs());
295 }
296 
297 namespace {
298 // Removes AssumingOp with a passing witness and inlines the region.
299 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
300   using OpRewritePattern<AssumingOp>::OpRewritePattern;
301 
302   LogicalResult matchAndRewrite(AssumingOp op,
303                                 PatternRewriter &rewriter) const override {
304     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
305     if (!witness || !witness.getPassingAttr())
306       return failure();
307 
308     AssumingOp::inlineRegionIntoParent(op, rewriter);
309     return success();
310   }
311 };
312 
// Rewrites an AssumingOp that has unused results into one that only yields
// (and returns) the used values.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values: keep only the yield operands whose corresponding op
    // results have at least one use.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones. Unused results
    // are replaced with null, which is legal since they have no uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
358 } // namespace
359 
/// Registers the AssumingOp canonicalization patterns defined above.
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
364 
365 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
366 void AssumingOp::getSuccessorRegions(
367     Optional<unsigned> index, ArrayRef<Attribute> operands,
368     SmallVectorImpl<RegionSuccessor> &regions) {
369   // AssumingOp has unconditional control flow into the region and back to the
370   // parent, so return the correct RegionSuccessor purely based on the index
371   // being None or 0.
372   if (index.hasValue()) {
373     regions.push_back(RegionSuccessor(getResults()));
374     return;
375   }
376 
377   regions.push_back(RegionSuccessor(&getDoRegion()));
378 }
379 
/// Inlines the body of `op` into its parent block, replacing the op's results
/// with the yielded values and erasing the op and its terminator.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the parent block at the rewriter's insertion point so the body can
  // be spliced in between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
399 
400 void AssumingOp::build(
401     OpBuilder &builder, OperationState &result, Value witness,
402     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
403 
404   result.addOperands(witness);
405   Region *bodyRegion = result.addRegion();
406   bodyRegion->push_back(new Block);
407   Block &bodyBlock = bodyRegion->front();
408 
409   // Build body.
410   OpBuilder::InsertionGuard guard(builder);
411   builder.setInsertionPointToStart(&bodyBlock);
412   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
413   builder.create<AssumingYieldOp>(result.location, yieldValues);
414 
415   SmallVector<Type, 2> assumingTypes;
416   for (Value v : yieldValues)
417     assumingTypes.push_back(v.getType());
418   result.addTypes(assumingTypes);
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // AddOp
423 //===----------------------------------------------------------------------===//
424 
425 LogicalResult mlir::shape::AddOp::inferReturnTypes(
426     MLIRContext *context, Optional<Location> location, ValueRange operands,
427     DictionaryAttr attributes, RegionRange regions,
428     SmallVectorImpl<Type> &inferredReturnTypes) {
429   if (operands[0].getType().isa<SizeType>() ||
430       operands[1].getType().isa<SizeType>())
431     inferredReturnTypes.assign({SizeType::get(context)});
432   else
433     inferredReturnTypes.assign({IndexType::get(context)});
434   return success();
435 }
436 
/// Returns true if both type ranges are a single `size` or `index` type.
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
441 
442 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
443   // add(x, 0) -> x
444   if (matchPattern(getRhs(), m_Zero()))
445     return getLhs();
446 
447   return constFoldBinaryOp<IntegerAttr>(operands,
448                                         [](APInt a, APInt b) { return a + b; });
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // AssumingAllOp
453 //===----------------------------------------------------------------------===//
454 
455 namespace {
// Rewrites an assuming_all whose inputs are all cstr_eq ops into a single
// cstr_eq over the union of their shape operands. Only applies when each
// cstr_eq shares at least one shape operand with the previously collected
// ones, so that equality transitively covers all shapes.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      // Every input witness must come from a cstr_eq.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // `disjointShapes` is true when this cstr_eq shares no shape operand
      // with those collected so far; such a witness cannot be merged.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
478 
479 template <typename OpTy>
480 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
481   using OpRewritePattern<OpTy>::OpRewritePattern;
482 
483   LogicalResult matchAndRewrite(OpTy op,
484                                 PatternRewriter &rewriter) const override {
485     // Find unique operands.
486     SmallVector<Value, 2> unique;
487     for (Value v : op.getOperands()) {
488       if (!llvm::is_contained(unique, v))
489         unique.push_back(v);
490     }
491 
492     // Reduce op to equivalent with unique operands.
493     if (unique.size() < op.getNumOperands()) {
494       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
495                                         op->getAttrs());
496       return success();
497     }
498 
499     return failure();
500   }
501 };
502 } // namespace
503 
/// Registers AssumingAllOp canonicalization patterns. AssumingAllOneOp comes
/// from the tablegen-generated ShapeCanonicalization.inc.
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
509 
/// Folds assuming_all over statically-known witnesses. NOTE: this folder also
/// erases constant operands from the op as a side effect, even when no fold
/// result is produced.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
530 
531 static LogicalResult verify(AssumingAllOp op) {
532   // Ensure that AssumingAllOp contains at least one operand
533   if (op.getNumOperands() == 0)
534     return op.emitOpError("no operands specified");
535 
536   return success();
537 }
538 
/// Convenience builder: the single result is always of witness type.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
543 
544 //===----------------------------------------------------------------------===//
545 // BroadcastOp
546 //===----------------------------------------------------------------------===//
547 
/// Folds shape.broadcast: a single operand folds to itself (when no type cast
/// would be needed), and two constant operands fold to their broadcasted
/// shape when compatible.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both operands must be constant.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
576 
/// Verifies that the result type of shape.broadcast is error-capable (`shape`)
/// whenever any operand may hold an error value.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
580 
581 namespace {
582 template <typename OpTy>
583 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
584   using OpRewritePattern<OpTy>::OpRewritePattern;
585 
586   LogicalResult matchAndRewrite(OpTy op,
587                                 PatternRewriter &rewriter) const override {
588     auto isPotentiallyNonEmptyShape = [](Value shape) {
589       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
590         if (extentTensorTy.getDimSize(0) == 0)
591           return false;
592       }
593       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
594         if (constShape.getShape().empty())
595           return false;
596       }
597       return true;
598     };
599     auto newOperands = llvm::to_vector<8>(
600         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
601 
602     // Reduce op to equivalent without empty shape operands.
603     if (newOperands.size() < op.getNumOperands()) {
604       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
605                                         op->getAttrs());
606       return success();
607     }
608 
609     return failure();
610   }
611 };
612 
// Replaces a single-operand broadcast with its operand, inserting a cast when
// the operand type differs from the result type.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        // Extent tensor -> !shape.shape conversion.
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        // Otherwise both types must be extent tensors; cast between them.
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
641 
// Folds all constant-shape operands of a broadcast into a single
// shape.const_shape operand (appended last), keeping the non-constant
// operands as-is.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      // Accumulate broadcastable constant shapes; incompatible constants are
      // kept as regular operands.
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
679 
// Replaces tensor.cast operands of `OpTy` with their sources when the cast
// only erases static shape information (casts to a dynamic dim-0 type).
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLoosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLoosingCast) {
          anyChange = true;
          return castOp.source();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};
711 
// Narrows a broadcast's dynamic extent tensor result type to a static one
// when the ranks of all operands are statically known, inserting a
// tensor.cast back to the original result type for existing users.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible. The broadcasted rank is the
    // maximum of the operand ranks.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
742 } // namespace
743 
/// Registers the BroadcastOp canonicalization patterns defined above.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
753 
754 //===----------------------------------------------------------------------===//
755 // ConcatOp
756 //===----------------------------------------------------------------------===//
757 
758 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
759   if (!operands[0] || !operands[1])
760     return nullptr;
761   auto lhsShape = llvm::to_vector<6>(
762       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
763   auto rhsShape = llvm::to_vector<6>(
764       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
765   SmallVector<int64_t, 6> resultShape;
766   resultShape.append(lhsShape.begin(), lhsShape.end());
767   resultShape.append(rhsShape.begin(), rhsShape.end());
768   Builder builder(getContext());
769   return builder.getIndexTensorAttr(resultShape);
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // ConstShapeOp
774 //===----------------------------------------------------------------------===//
775 
/// Prints a ConstShapeOp as: attr-dict `[` extents `]` `:` result-type.
/// The `shape` attribute is elided from the dict since its extents are
/// printed inline.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << " ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.getShape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
785 
/// Parses a ConstShapeOp of the form: attr-dict `[` extents `]` `:` type.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Convert the parsed array of integer attributes into the stored
  // index-tensor `shape` attribute.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  // Parse the trailing result type.
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
815 
816 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
817 
818 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
819                                                MLIRContext *context) {
820   patterns.add<TensorCastConstShape>(context);
821 }
822 
823 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
824     MLIRContext *context, Optional<Location> location, ValueRange operands,
825     DictionaryAttr attributes, RegionRange regions,
826     SmallVectorImpl<Type> &inferredReturnTypes) {
827   Builder b(context);
828   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
829   if (!shape)
830     return emitOptionalError(location, "missing shape attribute");
831   inferredReturnTypes.assign({RankedTensorType::get(
832       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
833   return success();
834 }
835 
836 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
837                                                         TypeRange r) {
838   if (l.size() != 1 || r.size() != 1)
839     return false;
840 
841   Type lhs = l.front();
842   Type rhs = r.front();
843 
844   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
845     // Shape type is compatible with all other valid return types.
846     return true;
847   return lhs == rhs;
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // CstrBroadcastableOp
852 //===----------------------------------------------------------------------===//
853 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  // The registered patterns normalize casted operands and drop duplicate,
  // equal, and empty-shape operands.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
864 
865 // Return true if there is exactly one attribute not representing a scalar
866 // broadcast.
867 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
868   bool nonScalarSeen = false;
869   for (Attribute a : attributes) {
870     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
871       if (nonScalarSeen)
872         return false;
873       nonScalarSeen = true;
874     }
875   }
876   return true;
877 }
878 
// Folds to a passing witness when broadcastability can be proven statically,
// either trivially (at most one non-scalar operand), from fully constant
// operand shapes, or from shape information recovered from the operands'
// defining ops.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, check whether all operand shapes are constant and statically
  // broadcastable.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
913 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Broadcastability is only meaningful between two or more shapes, so the op
  // must have at least two operands.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
920 
921 //===----------------------------------------------------------------------===//
922 // CstrEqOp
923 //===----------------------------------------------------------------------===//
924 
925 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
926                                            MLIRContext *context) {
927   // If inputs are equal, return passing witness
928   patterns.add<CstrEqEqOps>(context);
929 }
930 
931 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
932   if (llvm::all_of(operands,
933                    [&](Attribute a) { return a && a == operands[0]; }))
934     return BoolAttr::get(getContext(), true);
935 
936   // Because a failing witness result here represents an eventual assertion
937   // failure, we do not try to replace it with a constant witness. Similarly, we
938   // cannot if there are any non-const inputs.
939   return nullptr;
940 }
941 
942 //===----------------------------------------------------------------------===//
943 // ConstSizeOp
944 //===----------------------------------------------------------------------===//
945 
946 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
947                         int64_t value) {
948   build(builder, result, builder.getIndexAttr(value));
949 }
950 
951 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
952 
953 void ConstSizeOp::getAsmResultNames(
954     llvm::function_ref<void(Value, StringRef)> setNameFn) {
955   SmallString<4> buffer;
956   llvm::raw_svector_ostream os(buffer);
957   os << "c" << getValue();
958   setNameFn(getResult(), os.str());
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // ConstWitnessOp
963 //===----------------------------------------------------------------------===//
964 
965 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
966   return getPassingAttr();
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // CstrRequireOp
971 //===----------------------------------------------------------------------===//
972 
973 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
974   return operands[0];
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // DivOp
979 //===----------------------------------------------------------------------===//
980 
981 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
982   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
983   if (!lhs)
984     return nullptr;
985   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
986   if (!rhs)
987     return nullptr;
988 
989   // Division in APInt does not follow floor(lhs, rhs) when the result is
990   // negative. Rather, APInt rounds toward zero.
991   APInt quotient, remainder;
992   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
993   if (quotient.isNegative() && !remainder.isNullValue()) {
994     quotient -= 1;
995   }
996 
997   Type indexTy = IndexType::get(getContext());
998   return IntegerAttr::get(indexTy, quotient);
999 }
1000 
1001 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1002     MLIRContext *context, Optional<Location> location, ValueRange operands,
1003     DictionaryAttr attributes, RegionRange regions,
1004     SmallVectorImpl<Type> &inferredReturnTypes) {
1005   if (operands[0].getType().isa<SizeType>() ||
1006       operands[1].getType().isa<SizeType>())
1007     inferredReturnTypes.assign({SizeType::get(context)});
1008   else
1009     inferredReturnTypes.assign({IndexType::get(context)});
1010   return success();
1011 }
1012 
1013 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1014   // SizeType is compatible with IndexType.
1015   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // ShapeEqOp
1020 //===----------------------------------------------------------------------===//
1021 
1022 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1023   bool allSame = true;
1024   if (!operands.empty() && !operands[0])
1025     return {};
1026   for (Attribute operand : operands.drop_front(1)) {
1027     if (!operand)
1028       return {};
1029     allSame = allSame && operand == operands[0];
1030   }
1031   return BoolAttr::get(getContext(), allSame);
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // IndexToSizeOp
1036 //===----------------------------------------------------------------------===//
1037 
1038 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1039   // Constant values of both types, `shape.size` and `index`, are represented as
1040   // `IntegerAttr`s which makes constant folding simple.
1041   if (Attribute arg = operands[0])
1042     return arg;
1043   return {};
1044 }
1045 
1046 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1047                                                 MLIRContext *context) {
1048   patterns.add<SizeToIndexToSizeCanonicalization>(context);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // FromExtentsOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1056   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1057     return nullptr;
1058   SmallVector<int64_t, 6> extents;
1059   for (auto attr : operands)
1060     extents.push_back(attr.cast<IntegerAttr>().getInt());
1061   Builder builder(getContext());
1062   return builder.getIndexTensorAttr(extents);
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // FunctionLibraryOp
1067 //===----------------------------------------------------------------------===//
1068 
1069 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1070                               StringRef name) {
1071   result.attributes.push_back(builder.getNamedAttr(
1072       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1073 }
1074 
1075 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1076   auto attr = getMapping()
1077                   .get(op->getName().getIdentifier())
1078                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1079   if (!attr)
1080     return nullptr;
1081   return lookupSymbol<FuncOp>(attr);
1082 }
1083 
// Parses a function library op:
//   symbol-name attr-dict-with-keyword region `mapping` dictionary-attr
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // The body region holds the library's function definitions.
  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // The trailing dictionary maps op names to function symbols in the body.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1109 
// Prints a function library op in the form accepted by
// parseFunctionLibraryOp:
//   shape.function_library @name <attr-dict> { ... } mapping { ... }
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << ' ';
  p.printSymbolName(op.getName());
  // The symbol name and mapping are printed explicitly, so elide them from
  // the attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.getMappingAttr());
}
1120 
1121 //===----------------------------------------------------------------------===//
1122 // GetExtentOp
1123 //===----------------------------------------------------------------------===//
1124 
1125 Optional<int64_t> GetExtentOp::getConstantDim() {
1126   if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1127     return constSizeOp.getValue().getLimitedValue();
1128   if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1129     return constantOp.getValue().cast<IntegerAttr>().getInt();
1130   return llvm::None;
1131 }
1132 
1133 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1134   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1135   if (!elements)
1136     return nullptr;
1137   Optional<int64_t> dim = getConstantDim();
1138   if (!dim.hasValue())
1139     return nullptr;
1140   if (dim.getValue() >= elements.getNumElements())
1141     return nullptr;
1142   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1143 }
1144 
1145 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1146                         int64_t dim) {
1147   auto loc = result.location;
1148   auto dimAttr = builder.getIndexAttr(dim);
1149   if (shape.getType().isa<ShapeType>()) {
1150     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1151     build(builder, result, builder.getType<SizeType>(), shape, dim);
1152   } else {
1153     Value dim =
1154         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1155     build(builder, result, builder.getIndexType(), shape, dim);
1156   }
1157 }
1158 
1159 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1160     MLIRContext *context, Optional<Location> location, ValueRange operands,
1161     DictionaryAttr attributes, RegionRange regions,
1162     SmallVectorImpl<Type> &inferredReturnTypes) {
1163   inferredReturnTypes.assign({IndexType::get(context)});
1164   return success();
1165 }
1166 
1167 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1168                                                        TypeRange r) {
1169   // SizeType is compatible with IndexType.
1170   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1171 }
1172 
1173 //===----------------------------------------------------------------------===//
1174 // IsBroadcastableOp
1175 //===----------------------------------------------------------------------===//
1176 
1177 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1178                                                     MLIRContext *context) {
1179   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1180 }
1181 
1182 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1183   // Can always broadcast fewer than two shapes.
1184   if (operands.size() < 2) {
1185     return BoolAttr::get(getContext(), true);
1186   }
1187 
1188   return nullptr;
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // MeetOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1196     MLIRContext *context, Optional<Location> location, ValueRange operands,
1197     DictionaryAttr attributes, RegionRange regions,
1198     SmallVectorImpl<Type> &inferredReturnTypes) {
1199   inferredReturnTypes.assign({operands[0].getType()});
1200   return success();
1201 }
1202 
1203 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1204   if (l.size() != 1 || r.size() != 1)
1205     return false;
1206   if (l == r)
1207     return true;
1208 
1209   Type lhs = l.front();
1210   Type rhs = r.front();
1211 
1212   if (lhs != rhs)
1213     return false;
1214 
1215   if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1216     return true;
1217 
1218   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1219     return true;
1220   return false;
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // RankOp
1225 //===----------------------------------------------------------------------===//
1226 
1227 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1228   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1229   if (!shape)
1230     return {};
1231   int64_t rank = shape.getNumElements();
1232   Builder builder(getContext());
1233   return builder.getIndexAttr(rank);
1234 }
1235 
1236 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1237 /// Constant folding fails in cases where only the rank is constant, not the
1238 /// shape itself.
1239 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1240 ///
1241 /// Example:
1242 ///
1243 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1244 /// %rank = shape.rank %shape
1245 ///
1246 /// becomes
1247 ///
1248 /// %rank = shape.const_size 3
1249 
1250 namespace {
// Replaces `shape.rank(shape.shape_of(%t))` with a constant when %t is a
// ranked tensor, materializing either an `arith.constant` index or a
// `shape.const_size` to match the rank op's result type.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the shape operand is produced by shape_of.
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // The rank is only statically known for ranked tensor arguments.
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant with the same type the rank op produced.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
1276 } // namespace
1277 
1278 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1279                                                 MLIRContext *context) {
1280   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1281 }
1282 
1283 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1284     MLIRContext *context, Optional<Location> location, ValueRange operands,
1285     DictionaryAttr attributes, RegionRange regions,
1286     SmallVectorImpl<Type> &inferredReturnTypes) {
1287   if (operands[0].getType().isa<ShapeType>())
1288     inferredReturnTypes.assign({SizeType::get(context)});
1289   else
1290     inferredReturnTypes.assign({IndexType::get(context)});
1291   return success();
1292 }
1293 
1294 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1295   // SizeType is compatible with IndexType.
1296   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // NumElementsOp
1301 //===----------------------------------------------------------------------===//
1302 
1303 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1304 
1305   // Fold only when argument constant.
1306   Attribute shape = operands[0];
1307   if (!shape)
1308     return {};
1309 
1310   APInt product(64, 1);
1311   for (auto value : shape.cast<DenseIntElementsAttr>())
1312     product *= value;
1313   Builder builder(getContext());
1314   return builder.getIndexAttr(product.getLimitedValue());
1315 }
1316 
1317 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1318     MLIRContext *context, Optional<Location> location, ValueRange operands,
1319     DictionaryAttr attributes, RegionRange regions,
1320     SmallVectorImpl<Type> &inferredReturnTypes) {
1321   if (operands[0].getType().isa<ShapeType>())
1322     inferredReturnTypes.assign({SizeType::get(context)});
1323   else
1324     inferredReturnTypes.assign({IndexType::get(context)});
1325   return success();
1326 }
1327 
1328 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1329                                                          TypeRange r) {
1330   // SizeType is compatible with IndexType.
1331   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1332 }
1333 
1334 //===----------------------------------------------------------------------===//
1335 // MaxOp
1336 //===----------------------------------------------------------------------===//
1337 
1338 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1339   // If operands are equal, just propagate one.
1340   if (getLhs() == getRhs())
1341     return getLhs();
1342   return nullptr;
1343 }
1344 
1345 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1346     MLIRContext *context, Optional<Location> location, ValueRange operands,
1347     DictionaryAttr attributes, RegionRange regions,
1348     SmallVectorImpl<Type> &inferredReturnTypes) {
1349   if (operands[0].getType() == operands[1].getType())
1350     inferredReturnTypes.assign({operands[0].getType()});
1351   else
1352     inferredReturnTypes.assign({SizeType::get(context)});
1353   return success();
1354 }
1355 
1356 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1357   if (l.size() != 1 || r.size() != 1)
1358     return false;
1359   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1360     return true;
1361   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1362     return true;
1363   return false;
1364 }
1365 
1366 //===----------------------------------------------------------------------===//
1367 // MinOp
1368 //===----------------------------------------------------------------------===//
1369 
1370 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1371   // If operands are equal, just propagate one.
1372   if (getLhs() == getRhs())
1373     return getLhs();
1374   return nullptr;
1375 }
1376 
1377 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1378     MLIRContext *context, Optional<Location> location, ValueRange operands,
1379     DictionaryAttr attributes, RegionRange regions,
1380     SmallVectorImpl<Type> &inferredReturnTypes) {
1381   if (operands[0].getType() == operands[1].getType())
1382     inferredReturnTypes.assign({operands[0].getType()});
1383   else
1384     inferredReturnTypes.assign({SizeType::get(context)});
1385   return success();
1386 }
1387 
1388 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1389   if (l.size() != 1 || r.size() != 1)
1390     return false;
1391   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1392     return true;
1393   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1394     return true;
1395   return false;
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // MulOp
1400 //===----------------------------------------------------------------------===//
1401 
1402 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1403   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1404   if (!lhs)
1405     return nullptr;
1406   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1407   if (!rhs)
1408     return nullptr;
1409   APInt folded = lhs.getValue() * rhs.getValue();
1410   Type indexTy = IndexType::get(getContext());
1411   return IntegerAttr::get(indexTy, folded);
1412 }
1413 
1414 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1415     MLIRContext *context, Optional<Location> location, ValueRange operands,
1416     DictionaryAttr attributes, RegionRange regions,
1417     SmallVectorImpl<Type> &inferredReturnTypes) {
1418   if (operands[0].getType().isa<SizeType>() ||
1419       operands[1].getType().isa<SizeType>())
1420     inferredReturnTypes.assign({SizeType::get(context)});
1421   else
1422     inferredReturnTypes.assign({IndexType::get(context)});
1423   return success();
1424 }
1425 
1426 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1427   // SizeType is compatible with IndexType.
1428   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1429 }
1430 //===----------------------------------------------------------------------===//
1431 // ShapeOfOp
1432 //===----------------------------------------------------------------------===//
1433 
1434 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1435   auto type = getOperand().getType().dyn_cast<ShapedType>();
1436   if (!type || !type.hasStaticShape())
1437     return nullptr;
1438   Builder builder(getContext());
1439   return builder.getIndexTensorAttr(type.getShape());
1440 }
1441 
1442 namespace {
// Rewrites a `shape.shape_of` whose result is not a shaped type (presumably
// `!shape.shape` -- confirm against the op definition) but whose argument is
// a tensor, into a fresh `shape_of` whose result type is re-inferred from the
// argument.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    // Only fires for shaped (tensor) arguments ...
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    // ... whose result is not already a shaped (extent tensor) type.
    if (op.getType().isa<ShapedType>())
      return failure();

    // The replacement op's result type is inferred from the argument.
    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};
1458 
1459 // Canonicalize
1460 // ```
1461 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1462 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1463 // ```
1464 // to
1465 // ```
1466 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1467 // ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // Only casts to 1D ranked tensors (extent tensors) are of interest.
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    // The cast's source must be produced by shape_of.
    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    // Emit shape_of directly at the cast's result type, eliding the cast.
    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
1490 } // namespace
1491 
1492 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1493                                             MLIRContext *context) {
1494   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1495                ExtractFromShapeOfExtentTensor>(context);
1496 }
1497 
1498 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1499     MLIRContext *context, Optional<Location> location, ValueRange operands,
1500     DictionaryAttr attributes, RegionRange regions,
1501     SmallVectorImpl<Type> &inferredReturnTypes) {
1502   if (operands[0].getType().isa<ValueShapeType>())
1503     inferredReturnTypes.assign({ShapeType::get(context)});
1504   else {
1505     auto shapedTy = operands[0].getType().cast<ShapedType>();
1506     int64_t rank =
1507         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1508     Type indexTy = IndexType::get(context);
1509     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1510     inferredReturnTypes.assign({extentTensorTy});
1511   }
1512   return success();
1513 }
1514 
1515 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1516   if (l.size() != 1 || r.size() != 1)
1517     return false;
1518   if (l == r)
1519     return true;
1520 
1521   Type lhs = l.front();
1522   Type rhs = r.front();
1523 
1524   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1525     return false;
1526 
1527   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1528     // Shape type is compatible with all other valid return types.
1529     return true;
1530 
1531   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1532     return true;
1533   return false;
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // SizeToIndexOp
1538 //===----------------------------------------------------------------------===//
1539 
1540 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1541   // Constant values of both types, `shape.size` and `index`, are represented as
1542   // `IntegerAttr`s which makes constant folding simple.
1543   if (Attribute arg = operands[0])
1544     return arg;
1545   return impl::foldCastOp(*this);
1546 }
1547 
1548 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1549                                                 MLIRContext *context) {
1550   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // YieldOp
1555 //===----------------------------------------------------------------------===//
1556 
1557 static LogicalResult verify(shape::YieldOp op) {
1558   auto *parentOp = op->getParentOp();
1559   auto results = parentOp->getResults();
1560   auto operands = op.getOperands();
1561 
1562   if (parentOp->getNumResults() != op.getNumOperands())
1563     return op.emitOpError() << "number of operands does not match number of "
1564                                "results of its parent";
1565   for (auto e : llvm::zip(results, operands))
1566     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1567       return op.emitOpError()
1568              << "types mismatch between yield op and its parent";
1569 
1570   return success();
1571 }
1572 
1573 //===----------------------------------------------------------------------===//
1574 // SplitAtOp
1575 //===----------------------------------------------------------------------===//
1576 
// Folds split_at with a constant shape and split point into the two constant
// head/tail shapes. A negative split point counts from the back. Folding
// fails when either operand is unknown or the split point is outside
// [-rank, rank].
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // Normalize a negative split point to its non-negative equivalent.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  // First result: extents before the split point; second: the remainder.
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1597 
1598 //===----------------------------------------------------------------------===//
1599 // ToExtentTensorOp
1600 //===----------------------------------------------------------------------===//
1601 
1602 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1603   if (!operands[0])
1604     return impl::foldCastOp(*this);
1605   Builder builder(getContext());
1606   auto shape = llvm::to_vector<6>(
1607       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1608   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1609                                     builder.getIndexType());
1610   return DenseIntElementsAttr::get(type, shape);
1611 }
1612 
1613 //===----------------------------------------------------------------------===//
1614 // ReduceOp
1615 //===----------------------------------------------------------------------===//
1616 
1617 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1618                      ValueRange initVals) {
1619   result.addOperands(shape);
1620   result.addOperands(initVals);
1621 
1622   Region *bodyRegion = result.addRegion();
1623   bodyRegion->push_back(new Block);
1624   Block &bodyBlock = bodyRegion->front();
1625   bodyBlock.addArgument(builder.getIndexType());
1626 
1627   Type elementType;
1628   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1629     elementType = tensorType.getElementType();
1630   else
1631     elementType = SizeType::get(builder.getContext());
1632   bodyBlock.addArgument(elementType);
1633 
1634   for (Type initValType : initVals.getTypes()) {
1635     bodyBlock.addArgument(initValType);
1636     result.addTypes(initValType);
1637   }
1638 }
1639 
1640 static LogicalResult verify(ReduceOp op) {
1641   // Verify block arg types.
1642   Block &block = op.getRegion().front();
1643 
1644   // The block takes index, extent, and aggregated values as arguments.
1645   auto blockArgsCount = op.getInitVals().size() + 2;
1646   if (block.getNumArguments() != blockArgsCount)
1647     return op.emitOpError() << "ReduceOp body is expected to have "
1648                             << blockArgsCount << " arguments";
1649 
1650   // The first block argument is the index and must always be of type `index`.
1651   if (!block.getArgument(0).getType().isa<IndexType>())
1652     return op.emitOpError(
1653         "argument 0 of ReduceOp body is expected to be of IndexType");
1654 
1655   // The second block argument is the extent and must be of type `size` or
1656   // `index`, depending on whether the reduce operation is applied to a shape or
1657   // to an extent tensor.
1658   Type extentTy = block.getArgument(1).getType();
1659   if (op.getShape().getType().isa<ShapeType>()) {
1660     if (!extentTy.isa<SizeType>())
1661       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1662                             "SizeType if the ReduceOp operates on a ShapeType");
1663   } else {
1664     if (!extentTy.isa<IndexType>())
1665       return op.emitOpError(
1666           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1667           "ReduceOp operates on an extent tensor");
1668   }
1669 
1670   for (auto type : llvm::enumerate(op.getInitVals()))
1671     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1672       return op.emitOpError()
1673              << "type mismatch between argument " << type.index() + 2
1674              << " of ReduceOp body and initial value " << type.index();
1675   return success();
1676 }
1677 
1678 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1679   // Parse operands.
1680   SmallVector<OpAsmParser::OperandType, 3> operands;
1681   Type shapeOrExtentTensorType;
1682   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1683                               OpAsmParser::Delimiter::Paren) ||
1684       parser.parseColonType(shapeOrExtentTensorType) ||
1685       parser.parseOptionalArrowTypeList(result.types))
1686     return failure();
1687 
1688   // Resolve operands.
1689   auto initVals = llvm::makeArrayRef(operands).drop_front();
1690   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1691                             result.operands) ||
1692       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1693                              result.operands))
1694     return failure();
1695 
1696   // Parse the body.
1697   Region *body = result.addRegion();
1698   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1699     return failure();
1700 
1701   // Parse attributes.
1702   if (parser.parseOptionalAttrDict(result.attributes))
1703     return failure();
1704 
1705   return success();
1706 }
1707 
1708 static void print(OpAsmPrinter &p, ReduceOp op) {
1709   p << '(' << op.getShape() << ", " << op.getInitVals()
1710     << ") : " << op.getShape().getType();
1711   p.printOptionalArrowTypeList(op.getResultTypes());
1712   p.printRegion(op.getRegion());
1713   p.printOptionalAttrDict(op->getAttrs());
1714 }
1715 
1716 #define GET_OP_CLASSES
1717 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1718