1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/CommonFolders.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Traits.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 using namespace mlir;
29 using namespace mlir::shape;
30 
31 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
32 
33 namespace {
34 #include "ShapeCanonicalization.inc"
35 } // namespace
36 
37 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
38   return RankedTensorType::get({rank}, IndexType::get(ctx));
39 }
40 
41 bool shape::isExtentTensorType(Type type) {
42   auto ranked = type.dyn_cast<RankedTensorType>();
43   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
44 }
45 
46 LogicalResult shape::getShapeVec(Value input,
47                                  SmallVectorImpl<int64_t> &shapeValues) {
48   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
49     auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
50     if (!type.hasRank())
51       return failure();
52     shapeValues = llvm::to_vector<6>(type.getShape());
53     return success();
54   }
55   if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
56     shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
57     return success();
58   }
59   if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
60     shapeValues = llvm::to_vector<6>(
61         inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
62     return success();
63   }
64   return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68   return llvm::any_of(operandTypes, [](Type ty) {
69     return ty.isa<SizeType, ShapeType, ValueShapeType>();
70   });
71 }
72 
73 static LogicalResult verifySizeOrIndexOp(Operation *op) {
74   assert(op != nullptr && op->getNumResults() == 1);
75   Type resultTy = op->getResultTypes().front();
76   if (isErrorPropagationPossible(op->getOperandTypes())) {
77     if (!resultTy.isa<SizeType>())
78       return op->emitOpError()
79              << "if at least one of the operands can hold error values then "
80                 "the result must be of type `size` to propagate them";
81   }
82   return success();
83 }
84 
85 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
86   assert(op != nullptr && op->getNumResults() == 1);
87   Type resultTy = op->getResultTypes().front();
88   if (isErrorPropagationPossible(op->getOperandTypes())) {
89     if (!resultTy.isa<ShapeType>())
90       return op->emitOpError()
91              << "if at least one of the operands can hold error values then "
92                 "the result must be of type `shape` to propagate them";
93   }
94   return success();
95 }
96 
97 template <typename... Ty>
98 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
99   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
100 }
101 
102 template <typename... Ty, typename... ranges>
103 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
104   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // InlinerInterface
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 /// This class defines the interface for inlining shape dialect ops.
/// This class defines the interface for inlining shape dialect ops. All shape
/// ops and regions are unconditionally legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Shape dialect regions impose no restrictions, so always allow it.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Always legal for shape ops.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
131 } // namespace
132 
// Registers all ops generated from the ShapeOps TableGen definitions, the
// dialect's types, and the inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
145 
146 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
147                                              Attribute value, Type type,
148                                              Location loc) {
149   if (type.isa<ShapeType>() || isExtentTensorType(type))
150     return builder.create<ConstShapeOp>(loc, type,
151                                         value.cast<DenseIntElementsAttr>());
152   if (type.isa<SizeType>())
153     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
154   if (type.isa<WitnessType>())
155     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
156   if (arith::ConstantOp::isBuildableWith(value, type))
157     return builder.create<arith::ConstantOp>(loc, type, value);
158   return nullptr;
159 }
160 
161 /// Parse a type registered to this dialect.
162 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
163   StringRef keyword;
164   if (parser.parseKeyword(&keyword))
165     return Type();
166 
167   if (keyword == "shape")
168     return ShapeType::get(getContext());
169   if (keyword == "size")
170     return SizeType::get(getContext());
171   if (keyword == "value_shape")
172     return ValueShapeType::get(getContext());
173   if (keyword == "witness")
174     return WitnessType::get(getContext());
175 
176   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
177   return Type();
178 }
179 
180 /// Print a type registered to this dialect.
181 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
182   TypeSwitch<Type>(type)
183       .Case<ShapeType>([&](Type) { os << "shape"; })
184       .Case<SizeType>([&](Type) { os << "size"; })
185       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
186       .Case<WitnessType>([&](Type) { os << "witness"; })
187       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
188 }
189 
190 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
191                                                      NamedAttribute attribute) {
192   // Verify shape.lib attribute.
193   if (attribute.getName() == "shape.lib") {
194     if (!op->hasTrait<OpTrait::SymbolTable>())
195       return op->emitError(
196           "shape.lib attribute may only be on op implementing SymbolTable");
197 
198     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
199       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
200       if (!symbol)
201         return op->emitError("shape function library ")
202                << symbolRef << " not found";
203       return isa<shape::FunctionLibraryOp>(symbol)
204                  ? success()
205                  : op->emitError()
206                        << symbolRef << " required to be shape function library";
207     }
208 
209     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
210       // Verify all entries are function libraries and mappings in libraries
211       // refer to unique ops.
212       DenseSet<StringAttr> key;
213       for (auto it : arr) {
214         if (!it.isa<SymbolRefAttr>())
215           return op->emitError(
216               "only SymbolRefAttr allowed in shape.lib attribute array");
217 
218         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
219             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
220         if (!shapeFnLib)
221           return op->emitError()
222                  << it << " does not refer to FunctionLibraryOp";
223         for (auto mapping : shapeFnLib.getMapping()) {
224           if (!key.insert(mapping.getName()).second) {
225             return op->emitError("only one op to shape mapping allowed, found "
226                                  "multiple for `")
227                    << mapping.getName() << "`";
228           }
229         }
230       }
231       return success();
232     }
233 
234     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
235                          "allowed as shape.lib attribute");
236   }
237   return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AnyOp
242 //===----------------------------------------------------------------------===//
243 
244 // TODO: Canonicalization should be implemented for shapes that can be
245 // determined through mixtures of the known dimensions of the inputs.
246 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
247   // Only the last operand is checked because AnyOp is commutative.
248   if (operands.back())
249     return operands.back();
250 
251   return nullptr;
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // AssumingOp
256 //===----------------------------------------------------------------------===//
257 
// Parses an AssumingOp from the custom assembly format:
//   shape.assuming %witness [-> (result-types)] { region } [attr-dict]
// The region terminator may be elided in the textual form; it is recreated
// here via ensureTerminator.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  // The single operand is the witness, always of !shape.witness type.
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
284 
285 static void print(OpAsmPrinter &p, AssumingOp op) {
286   bool yieldsResults = !op.getResults().empty();
287 
288   p << " " << op.getWitness();
289   if (yieldsResults)
290     p << " -> (" << op.getResultTypes() << ")";
291   p << ' ';
292   p.printRegion(op.getDoRegion(),
293                 /*printEntryBlockArgs=*/false,
294                 /*printBlockTerminators=*/yieldsResults);
295   p.printOptionalAttrDict(op->getAttrs());
296 }
297 
298 namespace {
299 // Removes AssumingOp with a passing witness and inlines the region.
300 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
301   using OpRewritePattern<AssumingOp>::OpRewritePattern;
302 
303   LogicalResult matchAndRewrite(AssumingOp op,
304                                 PatternRewriter &rewriter) const override {
305     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
306     if (!witness || !witness.getPassingAttr())
307       return failure();
308 
309     AssumingOp::inlineRegionIntoParent(op, rewriter);
310     return success();
311   }
312 };
313 
// Rebuilds an AssumingOp without the results that have no uses, pruning the
// corresponding operands from the terminating yield.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    // Unused results are replaced with a null Value, which replaceOp accepts
    // for results that have no uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
359 } // namespace
360 
361 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
362                                              MLIRContext *context) {
363   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
364 }
365 
366 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
367 void AssumingOp::getSuccessorRegions(
368     Optional<unsigned> index, ArrayRef<Attribute> operands,
369     SmallVectorImpl<RegionSuccessor> &regions) {
370   // AssumingOp has unconditional control flow into the region and back to the
371   // parent, so return the correct RegionSuccessor purely based on the index
372   // being None or 0.
373   if (index.hasValue()) {
374     regions.push_back(RegionSuccessor(getResults()));
375     return;
376   }
377 
378   regions.push_back(RegionSuccessor(&getDoRegion()));
379 }
380 
// Inlines the single-block body of `op` into the parent block in place of the
// AssumingOp itself; the op's results are replaced by the operands of its
// terminating yield. Used when the witness is statically known to pass.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block at the op so the body can be spliced in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
400 
401 void AssumingOp::build(
402     OpBuilder &builder, OperationState &result, Value witness,
403     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
404 
405   result.addOperands(witness);
406   Region *bodyRegion = result.addRegion();
407   bodyRegion->push_back(new Block);
408   Block &bodyBlock = bodyRegion->front();
409 
410   // Build body.
411   OpBuilder::InsertionGuard guard(builder);
412   builder.setInsertionPointToStart(&bodyBlock);
413   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
414   builder.create<AssumingYieldOp>(result.location, yieldValues);
415 
416   SmallVector<Type, 2> assumingTypes;
417   for (Value v : yieldValues)
418     assumingTypes.push_back(v.getType());
419   result.addTypes(assumingTypes);
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // AddOp
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult mlir::shape::AddOp::inferReturnTypes(
427     MLIRContext *context, Optional<Location> location, ValueRange operands,
428     DictionaryAttr attributes, RegionRange regions,
429     SmallVectorImpl<Type> &inferredReturnTypes) {
430   if (operands[0].getType().isa<SizeType>() ||
431       operands[1].getType().isa<SizeType>())
432     inferredReturnTypes.assign({SizeType::get(context)});
433   else
434     inferredReturnTypes.assign({IndexType::get(context)});
435   return success();
436 }
437 
438 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
439   // SizeType is compatible with IndexType.
440   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
441 }
442 
443 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
444   // add(x, 0) -> x
445   if (matchPattern(getRhs(), m_Zero()))
446     return getLhs();
447 
448   return constFoldBinaryOp<IntegerAttr>(
449       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // AssumingAllOp
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
// Rewrites an assuming_all whose inputs all stem from cstr_eq ops into a
// single cstr_eq over the union of the compared shapes. This is only sound
// when every cstr_eq shares at least one shape with those collected so far,
// so equality is transitively implied across all of them.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      // Every input witness must be produced by a cstr_eq op.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // Bail out if this cstr_eq has no shape in common with the accumulated
      // set: merging disjoint equality groups would strengthen the constraint.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
479 
480 template <typename OpTy>
481 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
482   using OpRewritePattern<OpTy>::OpRewritePattern;
483 
484   LogicalResult matchAndRewrite(OpTy op,
485                                 PatternRewriter &rewriter) const override {
486     // Find unique operands.
487     SmallVector<Value, 2> unique;
488     for (Value v : op.getOperands()) {
489       if (!llvm::is_contained(unique, v))
490         unique.push_back(v);
491     }
492 
493     // Reduce op to equivalent with unique operands.
494     if (unique.size() < op.getNumOperands()) {
495       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
496                                         op->getAttrs());
497       return success();
498     }
499 
500     return failure();
501   }
502 };
503 } // namespace
504 
505 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
506                                                 MLIRContext *context) {
507   patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
508                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
509 }
510 
// Folds assuming_all over constant witnesses: returns false if any constant
// input is false, true if all inputs are constant and passing, and null
// otherwise. NOTE: this mutates the op while folding by erasing the constant
// operands it has handled; iterating back-to-front keeps indices valid.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
531 
532 static LogicalResult verify(AssumingAllOp op) {
533   // Ensure that AssumingAllOp contains at least one operand
534   if (op.getNumOperands() == 0)
535     return op.emitOpError("no operands specified");
536 
537   return success();
538 }
539 
540 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
541                           ValueRange inputs) {
542   build(b, state, b.getType<WitnessType>(), inputs);
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // BroadcastOp
547 //===----------------------------------------------------------------------===//
548 
// Folds shape.broadcast: a single operand of matching type folds to itself,
// and exactly two constant operands fold to their broadcasted shape when the
// shapes are compatible.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both operands must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
577 
// A broadcast result must be `shape` whenever an operand can hold an error.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
581 
582 namespace {
583 template <typename OpTy>
584 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
585   using OpRewritePattern<OpTy>::OpRewritePattern;
586 
587   LogicalResult matchAndRewrite(OpTy op,
588                                 PatternRewriter &rewriter) const override {
589     auto isPotentiallyNonEmptyShape = [](Value shape) {
590       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
591         if (extentTensorTy.getDimSize(0) == 0)
592           return false;
593       }
594       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
595         if (constShape.getShape().empty())
596           return false;
597       }
598       return true;
599     };
600     auto newOperands = llvm::to_vector<8>(
601         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
602 
603     // Reduce op to equivalent without empty shape operands.
604     if (newOperands.size() < op.getNumOperands()) {
605       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
606                                         op->getAttrs());
607       return success();
608     }
609 
610     return failure();
611   }
612 };
613 
614 struct BroadcastForwardSingleOperandPattern
615     : public OpRewritePattern<BroadcastOp> {
616   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
617 
618   LogicalResult matchAndRewrite(BroadcastOp op,
619                                 PatternRewriter &rewriter) const override {
620     if (op.getNumOperands() != 1)
621       return failure();
622     Value replacement = op.getShapes().front();
623 
624     // Insert cast if needed.
625     if (replacement.getType() != op.getType()) {
626       auto loc = op.getLoc();
627       if (op.getType().isa<ShapeType>()) {
628         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
629       } else {
630         assert(!op.getType().isa<ShapeType>() &&
631                !replacement.getType().isa<ShapeType>() &&
632                "expect extent tensor cast");
633         replacement =
634             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
635       }
636     }
637 
638     rewriter.replaceOp(op, replacement);
639     return success();
640   }
641 };
642 
// Partially folds a broadcast's constant operands: all const_shape operands
// that broadcast together are collapsed into a single new const_shape operand
// appended after the remaining dynamic operands (broadcast is commutative, so
// reordering is fine). Requires at least two foldable constants to fire.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Fold this constant into the accumulator only if the two broadcast
        // successfully; otherwise keep it as a regular operand.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
680 
// Strips tensor.cast ops from operands when the cast only loses information
// (i.e. its result type is dynamically sized), replacing the operand with the
// cast's source. Casts that add static shape information are preserved.
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLoosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLoosingCast) {
          anyChange = true;
          // NOTE(review): uses the unprefixed `source()` accessor while the
          // rest of this file uses get-prefixed accessors — presumably the
          // tensor dialect still generates both; confirm before changing.
          return castOp.source();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};
712 
// Concretizes a broadcast's dynamically sized extent tensor result type to a
// statically sized one when the rank can be inferred from the operands, then
// casts back to the original type.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible. The result rank of a broadcast
    // is the maximum of its operand ranks.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
      // NOTE(review): operands of non-extent-tensor type (e.g. !shape.shape)
      // are skipped here and do not contribute to maxRank — presumably they
      // cannot occur together with an extent tensor result; confirm.
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
743 } // namespace
744 
745 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
746                                               MLIRContext *context) {
747   patterns.add<BroadcastConcretizeResultTypePattern,
748                BroadcastFoldConstantOperandsPattern,
749                BroadcastForwardSingleOperandPattern,
750                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
751                RemoveDuplicateOperandsPattern<BroadcastOp>,
752                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // ConcatOp
757 //===----------------------------------------------------------------------===//
758 
759 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
760   if (!operands[0] || !operands[1])
761     return nullptr;
762   auto lhsShape = llvm::to_vector<6>(
763       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
764   auto rhsShape = llvm::to_vector<6>(
765       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
766   SmallVector<int64_t, 6> resultShape;
767   resultShape.append(lhsShape.begin(), lhsShape.end());
768   resultShape.append(rhsShape.begin(), rhsShape.end());
769   Builder builder(getContext());
770   return builder.getIndexTensorAttr(resultShape);
771 }
772 
773 //===----------------------------------------------------------------------===//
774 // ConstShapeOp
775 //===----------------------------------------------------------------------===//
776 
777 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
778   p << " ";
779   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
780   p << "[";
781   interleaveComma(op.getShape().getValues<int64_t>(), p,
782                   [&](int64_t i) { p << i; });
783   p << "] : ";
784   p.printType(op.getType());
785 }
786 
// Parses a ConstShapeOp of the form `[e0, e1, ...] : type`. The bracketed
// extent list is parsed as an ArrayAttr of IntegerAttrs and converted to the
// index tensor stored in the `shape` attribute.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    // Every element must be an integer extent.
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
816 
817 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
818 
/// Registers the declaratively-defined canonicalization for casts of constant
/// shapes (see ShapeCanonicalization.inc).
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}
823 
824 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
825     MLIRContext *context, Optional<Location> location, ValueRange operands,
826     DictionaryAttr attributes, RegionRange regions,
827     SmallVectorImpl<Type> &inferredReturnTypes) {
828   Builder b(context);
829   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
830   if (!shape)
831     return emitOptionalError(location, "missing shape attribute");
832   inferredReturnTypes.assign({RankedTensorType::get(
833       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
834   return success();
835 }
836 
837 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
838                                                         TypeRange r) {
839   if (l.size() != 1 || r.size() != 1)
840     return false;
841 
842   Type lhs = l.front();
843   Type rhs = r.front();
844 
845   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
846     // Shape type is compatible with all other valid return types.
847     return true;
848   return lhs == rhs;
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // CstrBroadcastableOp
853 //===----------------------------------------------------------------------===//
854 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
865 
// Return true if at most one of the attributes represents a shape that is not
// a scalar (rank-0) broadcast. A null attribute is conservatively counted as
// non-scalar since its shape is unknown.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
879 
/// Folds `shape.cstr_broadcastable` to a passing witness when broadcastability
/// can be established statically, trying progressively weaker sources of shape
/// information. Never folds to a failing witness.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, check whether all operands are constant and statically broadcastable.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
914 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Broadcastability is only meaningful for two or more input shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
921 
922 //===----------------------------------------------------------------------===//
923 // CstrEqOp
924 //===----------------------------------------------------------------------===//
925 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}
931 
932 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
933   if (llvm::all_of(operands,
934                    [&](Attribute a) { return a && a == operands[0]; }))
935     return BoolAttr::get(getContext(), true);
936 
937   // Because a failing witness result here represents an eventual assertion
938   // failure, we do not try to replace it with a constant witness. Similarly, we
939   // cannot if there are any non-const inputs.
940   return nullptr;
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // ConstSizeOp
945 //===----------------------------------------------------------------------===//
946 
/// Convenience builder that wraps the integer value in an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
951 
952 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
953 
/// Suggests an SSA name of the form "c<value>" (e.g. %c3) for the result.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}
961 
962 //===----------------------------------------------------------------------===//
963 // ConstWitnessOp
964 //===----------------------------------------------------------------------===//
965 
/// A const_witness op always folds to its "passing" attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
969 
970 //===----------------------------------------------------------------------===//
971 // CstrRequireOp
972 //===----------------------------------------------------------------------===//
973 
/// Folds to the predicate operand's constant value when known (null otherwise).
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
977 
978 //===----------------------------------------------------------------------===//
979 // DivOp
980 //===----------------------------------------------------------------------===//
981 
982 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
983   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
984   if (!lhs)
985     return nullptr;
986   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
987   if (!rhs)
988     return nullptr;
989 
990   // Division in APInt does not follow floor(lhs, rhs) when the result is
991   // negative. Rather, APInt rounds toward zero.
992   APInt quotient, remainder;
993   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
994   if (quotient.isNegative() && !remainder.isNullValue()) {
995     quotient -= 1;
996   }
997 
998   Type indexTy = IndexType::get(getContext());
999   return IntegerAttr::get(indexTy, quotient);
1000 }
1001 
1002 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1003     MLIRContext *context, Optional<Location> location, ValueRange operands,
1004     DictionaryAttr attributes, RegionRange regions,
1005     SmallVectorImpl<Type> &inferredReturnTypes) {
1006   if (operands[0].getType().isa<SizeType>() ||
1007       operands[1].getType().isa<SizeType>())
1008     inferredReturnTypes.assign({SizeType::get(context)});
1009   else
1010     inferredReturnTypes.assign({IndexType::get(context)});
1011   return success();
1012 }
1013 
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1018 
1019 //===----------------------------------------------------------------------===//
1020 // ShapeEqOp
1021 //===----------------------------------------------------------------------===//
1022 
1023 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1024   bool allSame = true;
1025   if (!operands.empty() && !operands[0])
1026     return {};
1027   for (Attribute operand : operands.drop_front(1)) {
1028     if (!operand)
1029       return {};
1030     allSame = allSame && operand == operands[0];
1031   }
1032   return BoolAttr::get(getContext(), allSame);
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // IndexToSizeOp
1037 //===----------------------------------------------------------------------===//
1038 
1039 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1040   // Constant values of both types, `shape.size` and `index`, are represented as
1041   // `IntegerAttr`s which makes constant folding simple.
1042   if (Attribute arg = operands[0])
1043     return arg;
1044   return {};
1045 }
1046 
/// Registers the pattern that folds size->index->size round trips.
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1051 
1052 //===----------------------------------------------------------------------===//
1053 // FromExtentsOp
1054 //===----------------------------------------------------------------------===//
1055 
1056 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1057   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1058     return nullptr;
1059   SmallVector<int64_t, 6> extents;
1060   for (auto attr : operands)
1061     extents.push_back(attr.cast<IntegerAttr>().getInt());
1062   Builder builder(getContext());
1063   return builder.getIndexTensorAttr(extents);
1064 }
1065 
1066 //===----------------------------------------------------------------------===//
1067 // FunctionLibraryOp
1068 //===----------------------------------------------------------------------===//
1069 
/// Builds a function library with the given symbol name; the region and the
/// "mapping" attribute are expected to be populated by the caller/parser.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1075 
/// Looks up the shape function registered for `op`'s name in the "mapping"
/// dictionary attribute; returns null when no entry (or a non-symbol entry)
/// exists.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1084 
/// Parses a `shape.function_library` op:
///   symbol-name attr-dict-with-keyword region `mapping` dictionary-attr
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  // The trailing mapping dictionary associates op names with shape functions.
  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1110 
/// Prints a `shape.function_library` op; inverse of parseFunctionLibraryOp.
/// The symbol name and "mapping" attribute are printed explicitly and elided
/// from the attr-dict.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.getMappingAttr());
}
1122 
1123 //===----------------------------------------------------------------------===//
1124 // GetExtentOp
1125 //===----------------------------------------------------------------------===//
1126 
1127 Optional<int64_t> GetExtentOp::getConstantDim() {
1128   if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1129     return constSizeOp.getValue().getLimitedValue();
1130   if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1131     return constantOp.getValue().cast<IntegerAttr>().getInt();
1132   return llvm::None;
1133 }
1134 
1135 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1136   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1137   if (!elements)
1138     return nullptr;
1139   Optional<int64_t> dim = getConstantDim();
1140   if (!dim.hasValue())
1141     return nullptr;
1142   if (dim.getValue() >= elements.getNumElements())
1143     return nullptr;
1144   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1145 }
1146 
1147 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1148                         int64_t dim) {
1149   auto loc = result.location;
1150   auto dimAttr = builder.getIndexAttr(dim);
1151   if (shape.getType().isa<ShapeType>()) {
1152     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1153     build(builder, result, builder.getType<SizeType>(), shape, dim);
1154   } else {
1155     Value dim =
1156         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1157     build(builder, result, builder.getIndexType(), shape, dim);
1158   }
1159 }
1160 
/// Type inference always yields `index`; the builder above bypasses this and
/// selects !shape.size directly when the shape operand is a !shape.shape.
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1168 
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1174 
1175 //===----------------------------------------------------------------------===//
1176 // IsBroadcastableOp
1177 //===----------------------------------------------------------------------===//
1178 
/// Duplicate operands do not affect broadcastability and can be removed.
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1183 
1184 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1185   // Can always broadcast fewer than two shapes.
1186   if (operands.size() < 2) {
1187     return BoolAttr::get(getContext(), true);
1188   }
1189 
1190   return nullptr;
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // MeetOp
1195 //===----------------------------------------------------------------------===//
1196 
/// The meet of two values has the type of the first operand.
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
1204 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  // Equal single-element ranges are trivially compatible.
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs != rhs)
    return false;

  // NOTE(review): given the `l == r` early return above, `lhs == rhs` appears
  // impossible here, which would make everything below unreachable — confirm
  // whether TypeRange::operator== compares element-wise before cleaning up.
  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1224 
1225 //===----------------------------------------------------------------------===//
1226 // RankOp
1227 //===----------------------------------------------------------------------===//
1228 
1229 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1230   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1231   if (!shape)
1232     return {};
1233   int64_t rank = shape.getNumElements();
1234   Builder builder(getContext());
1235   return builder.getIndexAttr(rank);
1236 }
1237 
1238 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1239 /// Constant folding fails in cases where only the rank is constant, not the
1240 /// shape itself.
1241 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1242 ///
1243 /// Example:
1244 ///
1245 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1246 /// %rank = shape.rank %shape
1247 ///
1248 /// becomes
1249 ///
1250 /// %rank = shape.const_size 3
1251 
namespace {
/// Rewrites `shape.rank(shape.shape_of(%t))` where %t is a ranked tensor into
/// a constant equal to the tensor's static rank; see the comment above.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // The rank is only statically known for ranked tensor arguments.
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant in whichever type the rank op produces.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1279 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1284 
1285 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1286     MLIRContext *context, Optional<Location> location, ValueRange operands,
1287     DictionaryAttr attributes, RegionRange regions,
1288     SmallVectorImpl<Type> &inferredReturnTypes) {
1289   if (operands[0].getType().isa<ShapeType>())
1290     inferredReturnTypes.assign({SizeType::get(context)});
1291   else
1292     inferredReturnTypes.assign({IndexType::get(context)});
1293   return success();
1294 }
1295 
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1300 
1301 //===----------------------------------------------------------------------===//
1302 // NumElementsOp
1303 //===----------------------------------------------------------------------===//
1304 
/// Folds `shape.num_elements` of a constant shape to the product of its
/// extents.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  // Multiply all extents together in 64-bit arithmetic (each iterated value is
  // an APInt).
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
1318 
1319 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1320     MLIRContext *context, Optional<Location> location, ValueRange operands,
1321     DictionaryAttr attributes, RegionRange regions,
1322     SmallVectorImpl<Type> &inferredReturnTypes) {
1323   if (operands[0].getType().isa<ShapeType>())
1324     inferredReturnTypes.assign({SizeType::get(context)});
1325   else
1326     inferredReturnTypes.assign({IndexType::get(context)});
1327   return success();
1328 }
1329 
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1335 
1336 //===----------------------------------------------------------------------===//
1337 // MaxOp
1338 //===----------------------------------------------------------------------===//
1339 
/// max(x, x) == x; no other constant folding is attempted.
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}
1346 
1347 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1348     MLIRContext *context, Optional<Location> location, ValueRange operands,
1349     DictionaryAttr attributes, RegionRange regions,
1350     SmallVectorImpl<Type> &inferredReturnTypes) {
1351   if (operands[0].getType() == operands[1].getType())
1352     inferredReturnTypes.assign({operands[0].getType()});
1353   else
1354     inferredReturnTypes.assign({SizeType::get(context)});
1355   return success();
1356 }
1357 
1358 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1359   if (l.size() != 1 || r.size() != 1)
1360     return false;
1361   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1362     return true;
1363   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1364     return true;
1365   return false;
1366 }
1367 
1368 //===----------------------------------------------------------------------===//
1369 // MinOp
1370 //===----------------------------------------------------------------------===//
1371 
/// min(x, x) == x; no other constant folding is attempted.
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}
1378 
1379 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1380     MLIRContext *context, Optional<Location> location, ValueRange operands,
1381     DictionaryAttr attributes, RegionRange regions,
1382     SmallVectorImpl<Type> &inferredReturnTypes) {
1383   if (operands[0].getType() == operands[1].getType())
1384     inferredReturnTypes.assign({operands[0].getType()});
1385   else
1386     inferredReturnTypes.assign({SizeType::get(context)});
1387   return success();
1388 }
1389 
1390 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1391   if (l.size() != 1 || r.size() != 1)
1392     return false;
1393   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1394     return true;
1395   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1396     return true;
1397   return false;
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // MulOp
1402 //===----------------------------------------------------------------------===//
1403 
1404 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1405   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1406   if (!lhs)
1407     return nullptr;
1408   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1409   if (!rhs)
1410     return nullptr;
1411   APInt folded = lhs.getValue() * rhs.getValue();
1412   Type indexTy = IndexType::get(getContext());
1413   return IntegerAttr::get(indexTy, folded);
1414 }
1415 
1416 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1417     MLIRContext *context, Optional<Location> location, ValueRange operands,
1418     DictionaryAttr attributes, RegionRange regions,
1419     SmallVectorImpl<Type> &inferredReturnTypes) {
1420   if (operands[0].getType().isa<SizeType>() ||
1421       operands[1].getType().isa<SizeType>())
1422     inferredReturnTypes.assign({SizeType::get(context)});
1423   else
1424     inferredReturnTypes.assign({IndexType::get(context)});
1425   return success();
1426 }
1427 
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1432 //===----------------------------------------------------------------------===//
1433 // ShapeOfOp
1434 //===----------------------------------------------------------------------===//
1435 
1436 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1437   auto type = getOperand().getType().dyn_cast<ShapedType>();
1438   if (!type || !type.hasStaticShape())
1439     return nullptr;
1440   Builder builder(getContext());
1441   return builder.getIndexTensorAttr(type.getShape());
1442 }
1443 
namespace {
// For a shaped (tensor-like) argument, rebuild the shape_of so its result
// type is re-inferred; matches only when the current result is not already a
// shaped type.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // Only rank-1 (extent tensor) results are of interest.
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace
1493 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1499 
1500 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1501     MLIRContext *context, Optional<Location> location, ValueRange operands,
1502     DictionaryAttr attributes, RegionRange regions,
1503     SmallVectorImpl<Type> &inferredReturnTypes) {
1504   if (operands[0].getType().isa<ValueShapeType>())
1505     inferredReturnTypes.assign({ShapeType::get(context)});
1506   else {
1507     auto shapedTy = operands[0].getType().cast<ShapedType>();
1508     int64_t rank =
1509         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1510     Type indexTy = IndexType::get(context);
1511     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1512     inferredReturnTypes.assign({extentTensorTy});
1513   }
1514   return success();
1515 }
1516 
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // Valid return types are only !shape.shape and shaped (extent tensor) types.
  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  // Two shaped types are compatible when their shapes are.
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1537 
1538 //===----------------------------------------------------------------------===//
1539 // SizeToIndexOp
1540 //===----------------------------------------------------------------------===//
1541 
1542 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1543   // Constant values of both types, `shape.size` and `index`, are represented as
1544   // `IntegerAttr`s which makes constant folding simple.
1545   if (Attribute arg = operands[0])
1546     return arg;
1547   return impl::foldCastOp(*this);
1548 }
1549 
/// Registers the pattern that folds index->size->index round trips.
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1554 
1555 //===----------------------------------------------------------------------===//
1556 // YieldOp
1557 //===----------------------------------------------------------------------===//
1558 
1559 static LogicalResult verify(shape::YieldOp op) {
1560   auto *parentOp = op->getParentOp();
1561   auto results = parentOp->getResults();
1562   auto operands = op.getOperands();
1563 
1564   if (parentOp->getNumResults() != op.getNumOperands())
1565     return op.emitOpError() << "number of operands does not match number of "
1566                                "results of its parent";
1567   for (auto e : llvm::zip(results, operands))
1568     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1569       return op.emitOpError()
1570              << "types mismatch between yield op and its parent";
1571 
1572   return success();
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // SplitAtOp
1577 //===----------------------------------------------------------------------===//
1578 
/// Folds `shape.split_at` when both the shape and the split point are
/// constant, producing the head and tail extent tensors.
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // A negative split point counts from the back.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1599 
1600 //===----------------------------------------------------------------------===//
1601 // ToExtentTensorOp
1602 //===----------------------------------------------------------------------===//
1603 
1604 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1605   if (!operands[0])
1606     return impl::foldCastOp(*this);
1607   Builder builder(getContext());
1608   auto shape = llvm::to_vector<6>(
1609       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1610   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1611                                     builder.getIndexType());
1612   return DenseIntElementsAttr::get(type, shape);
1613 }
1614 
1615 //===----------------------------------------------------------------------===//
1616 // ReduceOp
1617 //===----------------------------------------------------------------------===//
1618 
1619 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1620                      ValueRange initVals) {
1621   result.addOperands(shape);
1622   result.addOperands(initVals);
1623 
1624   Region *bodyRegion = result.addRegion();
1625   bodyRegion->push_back(new Block);
1626   Block &bodyBlock = bodyRegion->front();
1627   bodyBlock.addArgument(builder.getIndexType(), result.location);
1628 
1629   Type elementType;
1630   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1631     elementType = tensorType.getElementType();
1632   else
1633     elementType = SizeType::get(builder.getContext());
1634   bodyBlock.addArgument(elementType, shape.getLoc());
1635 
1636   for (Value initVal : initVals) {
1637     bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
1638     result.addTypes(initVal.getType());
1639   }
1640 }
1641 
1642 static LogicalResult verify(ReduceOp op) {
1643   // Verify block arg types.
1644   Block &block = op.getRegion().front();
1645 
1646   // The block takes index, extent, and aggregated values as arguments.
1647   auto blockArgsCount = op.getInitVals().size() + 2;
1648   if (block.getNumArguments() != blockArgsCount)
1649     return op.emitOpError() << "ReduceOp body is expected to have "
1650                             << blockArgsCount << " arguments";
1651 
1652   // The first block argument is the index and must always be of type `index`.
1653   if (!block.getArgument(0).getType().isa<IndexType>())
1654     return op.emitOpError(
1655         "argument 0 of ReduceOp body is expected to be of IndexType");
1656 
1657   // The second block argument is the extent and must be of type `size` or
1658   // `index`, depending on whether the reduce operation is applied to a shape or
1659   // to an extent tensor.
1660   Type extentTy = block.getArgument(1).getType();
1661   if (op.getShape().getType().isa<ShapeType>()) {
1662     if (!extentTy.isa<SizeType>())
1663       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1664                             "SizeType if the ReduceOp operates on a ShapeType");
1665   } else {
1666     if (!extentTy.isa<IndexType>())
1667       return op.emitOpError(
1668           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1669           "ReduceOp operates on an extent tensor");
1670   }
1671 
1672   for (const auto &type : llvm::enumerate(op.getInitVals()))
1673     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1674       return op.emitOpError()
1675              << "type mismatch between argument " << type.index() + 2
1676              << " of ReduceOp body and initial value " << type.index();
1677   return success();
1678 }
1679 
// Parses the custom assembly form of ReduceOp:
//   `shape.reduce` `(` shape `,` init-vals `)` `:` shape-or-tensor-type
//   (`->` result-types)? region attr-dict?
// The first parenthesized operand is the shape/extent tensor; the rest are the
// initial accumulator values, which share the op's result types.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  // requiredOperandCount == -1: accept any number of operands in the parens.
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the leading operand gets the parsed shape type, the
  // remaining init values are resolved against the result types.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes (printed after the region; see the matching printer).
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1709 
1710 static void print(OpAsmPrinter &p, ReduceOp op) {
1711   p << '(' << op.getShape() << ", " << op.getInitVals()
1712     << ") : " << op.getShape().getType();
1713   p.printOptionalArrowTypeList(op.getResultTypes());
1714   p << ' ';
1715   p.printRegion(op.getRegion());
1716   p.printOptionalAttrDict(op->getAttrs());
1717 }
1718 
1719 #define GET_OP_CLASSES
1720 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1721