1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Traits.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 using namespace mlir;
28 using namespace mlir::shape;
29 
30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
31 
32 namespace {
33 #include "ShapeCanonicalization.inc"
34 }
35 
36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
37   return RankedTensorType::get({rank}, IndexType::get(ctx));
38 }
39 
40 bool shape::isExtentTensorType(Type type) {
41   auto ranked = type.dyn_cast<RankedTensorType>();
42   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
43 }
44 
45 LogicalResult shape::getShapeVec(Value input,
46                                  SmallVectorImpl<int64_t> &shapeValues) {
47   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
48     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
49     if (!type.hasRank())
50       return failure();
51     shapeValues = llvm::to_vector<6>(type.getShape());
52     return success();
53   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
54     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
55     return success();
56   } else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
57     shapeValues = llvm::to_vector<6>(
58         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
59     return success();
60   } else {
61     return failure();
62   }
63 }
64 
65 static bool isErrorPropagationPossible(TypeRange operandTypes) {
66   return llvm::any_of(operandTypes, [](Type ty) {
67     return ty.isa<SizeType, ShapeType, ValueShapeType>();
68   });
69 }
70 
71 static LogicalResult verifySizeOrIndexOp(Operation *op) {
72   assert(op != nullptr && op->getNumResults() == 1);
73   Type resultTy = op->getResultTypes().front();
74   if (isErrorPropagationPossible(op->getOperandTypes())) {
75     if (!resultTy.isa<SizeType>())
76       return op->emitOpError()
77              << "if at least one of the operands can hold error values then "
78                 "the result must be of type `size` to propagate them";
79   }
80   return success();
81 }
82 
83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84   assert(op != nullptr && op->getNumResults() == 1);
85   Type resultTy = op->getResultTypes().front();
86   if (isErrorPropagationPossible(op->getOperandTypes())) {
87     if (!resultTy.isa<ShapeType>())
88       return op->emitOpError()
89              << "if at least one of the operands can hold error values then "
90                 "the result must be of type `shape` to propagate them";
91   }
92   return success();
93 }
94 
95 template <typename... Ty>
96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
98 }
99 
100 template <typename... Ty, typename... ranges>
101 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102   return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // InlinerInterface
107 //===----------------------------------------------------------------------===//
108 
109 namespace {
110 /// This class defines the interface for inlining shape dialect ops.
111 struct ShapeInlinerInterface : public DialectInlinerInterface {
112   using DialectInlinerInterface::DialectInlinerInterface;
113 
114   // Returns true if the given region 'src' can be inlined into the region
115   // 'dest' that is attached to an operation registered to the current dialect.
116   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
117                        BlockAndValueMapping &) const final {
118     return true;
119   }
120 
121   // Returns true if the given operation 'op', that is registered to this
122   // dialect, can be inlined into the region 'dest' that is attached to an
123   // operation registered to the current dialect.
124   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
125                        BlockAndValueMapping &) const final {
126     return true;
127   }
128 };
129 } // namespace
130 
/// Registers the contents of the shape dialect: the operations listed in the
/// generated `ShapeOps.cpp.inc`, the four dialect types, and the inliner
/// interface. Unknown operations are additionally allowed.
void ShapeDialect::initialize() {
  // GET_OP_LIST makes the generated include expand to the list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
143 
144 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
145                                              Attribute value, Type type,
146                                              Location loc) {
147   if (type.isa<ShapeType>() || isExtentTensorType(type))
148     return builder.create<ConstShapeOp>(loc, type,
149                                         value.cast<DenseIntElementsAttr>());
150   if (type.isa<SizeType>())
151     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
152   if (type.isa<WitnessType>())
153     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
154   if (arith::ConstantOp::isBuildableWith(value, type))
155     return builder.create<arith::ConstantOp>(loc, type, value);
156   return nullptr;
157 }
158 
159 /// Parse a type registered to this dialect.
160 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
161   StringRef keyword;
162   if (parser.parseKeyword(&keyword))
163     return Type();
164 
165   if (keyword == "shape")
166     return ShapeType::get(getContext());
167   if (keyword == "size")
168     return SizeType::get(getContext());
169   if (keyword == "value_shape")
170     return ValueShapeType::get(getContext());
171   if (keyword == "witness")
172     return WitnessType::get(getContext());
173 
174   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
175   return Type();
176 }
177 
178 /// Print a type registered to this dialect.
179 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
180   TypeSwitch<Type>(type)
181       .Case<ShapeType>([&](Type) { os << "shape"; })
182       .Case<SizeType>([&](Type) { os << "size"; })
183       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
184       .Case<WitnessType>([&](Type) { os << "witness"; })
185       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
186 }
187 
188 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
189                                                      NamedAttribute attribute) {
190   // Verify shape.lib attribute.
191   if (attribute.first == "shape.lib") {
192     if (!op->hasTrait<OpTrait::SymbolTable>())
193       return op->emitError(
194           "shape.lib attribute may only be on op implementing SymbolTable");
195 
196     if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
197       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
198       if (!symbol)
199         return op->emitError("shape function library ")
200                << symbolRef << " not found";
201       return isa<shape::FunctionLibraryOp>(symbol)
202                  ? success()
203                  : op->emitError()
204                        << symbolRef << " required to be shape function library";
205     }
206 
207     if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
208       // Verify all entries are function libraries and mappings in libraries
209       // refer to unique ops.
210       DenseSet<Identifier> key;
211       for (auto it : arr) {
212         if (!it.isa<SymbolRefAttr>())
213           return op->emitError(
214               "only SymbolRefAttr allowed in shape.lib attribute array");
215 
216         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
217             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
218         if (!shapeFnLib)
219           return op->emitError()
220                  << it << " does not refer to FunctionLibraryOp";
221         for (auto mapping : shapeFnLib.mapping()) {
222           if (!key.insert(mapping.first).second) {
223             return op->emitError("only one op to shape mapping allowed, found "
224                                  "multiple for `")
225                    << mapping.first << "`";
226           }
227         }
228       }
229       return success();
230     }
231 
232     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
233                          "allowed as shape.lib attribute");
234   }
235   return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // AnyOp
240 //===----------------------------------------------------------------------===//
241 
242 // TODO: Canonicalization should be implemented for shapes that can be
243 // determined through mixtures of the known dimensions of the inputs.
244 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
245   // Only the last operand is checked because AnyOp is commutative.
246   if (operands.back())
247     return operands.back();
248 
249   return nullptr;
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // AssumingOp
254 //===----------------------------------------------------------------------===//
255 
256 static ParseResult parseAssumingOp(OpAsmParser &parser,
257                                    OperationState &result) {
258   result.regions.reserve(1);
259   Region *doRegion = result.addRegion();
260 
261   auto &builder = parser.getBuilder();
262   OpAsmParser::OperandType cond;
263   if (parser.parseOperand(cond) ||
264       parser.resolveOperand(cond, builder.getType<WitnessType>(),
265                             result.operands))
266     return failure();
267 
268   // Parse optional results type list.
269   if (parser.parseOptionalArrowTypeList(result.types))
270     return failure();
271 
272   // Parse the region and add a terminator if elided.
273   if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
274     return failure();
275   AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
276 
277   // Parse the optional attribute list.
278   if (parser.parseOptionalAttrDict(result.attributes))
279     return failure();
280   return success();
281 }
282 
283 static void print(OpAsmPrinter &p, AssumingOp op) {
284   bool yieldsResults = !op.results().empty();
285 
286   p << " " << op.witness();
287   if (yieldsResults) {
288     p << " -> (" << op.getResultTypes() << ")";
289   }
290   p.printRegion(op.doRegion(),
291                 /*printEntryBlockArgs=*/false,
292                 /*printBlockTerminators=*/yieldsResults);
293   p.printOptionalAttrDict(op->getAttrs());
294 }
295 
296 namespace {
297 // Removes AssumingOp with a passing witness and inlines the region.
298 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
299   using OpRewritePattern<AssumingOp>::OpRewritePattern;
300 
301   LogicalResult matchAndRewrite(AssumingOp op,
302                                 PatternRewriter &rewriter) const override {
303     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
304     if (!witness || !witness.passingAttr())
305       return failure();
306 
307     AssumingOp::inlineRegionIntoParent(op, rewriter);
308     return success();
309   }
310 };
311 
312 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
313   using OpRewritePattern<AssumingOp>::OpRewritePattern;
314 
315   LogicalResult matchAndRewrite(AssumingOp op,
316                                 PatternRewriter &rewriter) const override {
317     Block *body = op.getBody();
318     auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
319 
320     // Find used values.
321     SmallVector<Value, 4> newYieldOperands;
322     Value opResult, yieldOperand;
323     for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
324       std::tie(opResult, yieldOperand) = it;
325       if (!opResult.getUses().empty()) {
326         newYieldOperands.push_back(yieldOperand);
327       }
328     }
329 
330     // Rewrite only if redundant results exist.
331     if (newYieldOperands.size() == yieldOp->getNumOperands())
332       return failure();
333 
334     // Replace yield op in the old assuming op's body and move the entire region
335     // to the new assuming op.
336     rewriter.setInsertionPointToEnd(body);
337     auto newYieldOp =
338         rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
339     rewriter.setInsertionPoint(op);
340     auto newOp = rewriter.create<AssumingOp>(
341         op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
342     newOp.doRegion().takeBody(op.doRegion());
343 
344     // Use the new results to replace the previously used ones.
345     SmallVector<Value, 4> replacementValues;
346     auto src = newOp.getResults().begin();
347     for (auto it : op.getResults()) {
348       if (it.getUses().empty())
349         replacementValues.push_back(nullptr);
350       else
351         replacementValues.push_back(*src++);
352     }
353     rewriter.replaceOp(op, replacementValues);
354     return success();
355   }
356 };
357 } // namespace
358 
359 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
360                                              MLIRContext *context) {
361   patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
362 }
363 
364 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
365 void AssumingOp::getSuccessorRegions(
366     Optional<unsigned> index, ArrayRef<Attribute> operands,
367     SmallVectorImpl<RegionSuccessor> &regions) {
368   // AssumingOp has unconditional control flow into the region and back to the
369   // parent, so return the correct RegionSuccessor purely based on the index
370   // being None or 0.
371   if (index.hasValue()) {
372     regions.push_back(RegionSuccessor(getResults()));
373     return;
374   }
375 
376   regions.push_back(RegionSuccessor(&doRegion()));
377 }
378 
/// Replaces `op` with the contents of its body region.
///
/// The block the rewriter currently inserts into is split at the insertion
/// point, the region of `op` is moved in front of the split-off tail, the
/// results of `op` are replaced by the operands of the body's last operation
/// (the yield), and finally the three blocks are merged back into one.
/// NOTE(review): this presumably expects the rewriter's insertion point to be
/// at `op` — confirm against callers.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Everything from the insertion point onwards moves into a new block.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
398 
399 void AssumingOp::build(
400     OpBuilder &builder, OperationState &result, Value witness,
401     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
402 
403   result.addOperands(witness);
404   Region *bodyRegion = result.addRegion();
405   bodyRegion->push_back(new Block);
406   Block &bodyBlock = bodyRegion->front();
407 
408   // Build body.
409   OpBuilder::InsertionGuard guard(builder);
410   builder.setInsertionPointToStart(&bodyBlock);
411   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
412   builder.create<AssumingYieldOp>(result.location, yieldValues);
413 
414   SmallVector<Type, 2> assumingTypes;
415   for (Value v : yieldValues)
416     assumingTypes.push_back(v.getType());
417   result.addTypes(assumingTypes);
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // AddOp
422 //===----------------------------------------------------------------------===//
423 
424 LogicalResult mlir::shape::AddOp::inferReturnTypes(
425     MLIRContext *context, Optional<Location> location, ValueRange operands,
426     DictionaryAttr attributes, RegionRange regions,
427     SmallVectorImpl<Type> &inferredReturnTypes) {
428   if (operands[0].getType().isa<SizeType>() ||
429       operands[1].getType().isa<SizeType>())
430     inferredReturnTypes.assign({SizeType::get(context)});
431   else
432     inferredReturnTypes.assign({IndexType::get(context)});
433   return success();
434 }
435 
436 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
437   // SizeType is compatible with IndexType.
438   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
439 }
440 
441 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
442   // add(x, 0) -> x
443   if (matchPattern(rhs(), m_Zero()))
444     return lhs();
445 
446   return constFoldBinaryOp<IntegerAttr>(operands,
447                                         [](APInt a, APInt b) { return a + b; });
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // AssumingAllOp
452 //===----------------------------------------------------------------------===//
453 
454 namespace {
455 struct AssumingAllToCstrEqCanonicalization
456     : public OpRewritePattern<AssumingAllOp> {
457   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
458 
459   LogicalResult matchAndRewrite(AssumingAllOp op,
460                                 PatternRewriter &rewriter) const override {
461     SmallVector<Value, 8> shapes;
462     for (Value w : op.inputs()) {
463       auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
464       if (!cstrEqOp)
465         return failure();
466       bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
467         return llvm::is_contained(shapes, s);
468       });
469       if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
470         return failure();
471       shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
472     }
473     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
474     return success();
475   }
476 };
477 
478 template <typename OpTy>
479 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
480   using OpRewritePattern<OpTy>::OpRewritePattern;
481 
482   LogicalResult matchAndRewrite(OpTy op,
483                                 PatternRewriter &rewriter) const override {
484     // Find unique operands.
485     SmallVector<Value, 2> unique;
486     for (Value v : op.getOperands()) {
487       if (!llvm::is_contained(unique, v))
488         unique.push_back(v);
489     }
490 
491     // Reduce op to equivalent with unique operands.
492     if (unique.size() < op.getNumOperands()) {
493       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
494                                         op->getAttrs());
495       return success();
496     }
497 
498     return failure();
499   }
500 };
501 } // namespace
502 
503 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
504                                                 MLIRContext *context) {
505   patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
506                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
507 }
508 
/// Folds AssumingAllOp over its constant inputs.
///
/// Operands are visited from last to first. Each constant operand is erased
/// from the operation; a constant `false` is returned as the fold result as
/// soon as it is seen. The walk stops without folding at the first
/// non-constant operand (constant operands visited before it have already
/// been erased). If all operands are constant and true, a `true` BoolAttr is
/// returned.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
529 
530 static LogicalResult verify(AssumingAllOp op) {
531   // Ensure that AssumingAllOp contains at least one operand
532   if (op.getNumOperands() == 0)
533     return op.emitOpError("no operands specified");
534 
535   return success();
536 }
537 
538 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
539                           ValueRange inputs) {
540   build(b, state, b.getType<WitnessType>(), inputs);
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // BroadcastOp
545 //===----------------------------------------------------------------------===//
546 
547 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
548   if (shapes().size() == 1) {
549     // Otherwise, we need a cast which would be a canonicalization, not folding.
550     if (shapes().front().getType() != getType())
551       return nullptr;
552     return shapes().front();
553   }
554 
555   // TODO: Support folding with more than 2 input shapes
556   if (shapes().size() > 2)
557     return nullptr;
558 
559   if (!operands[0] || !operands[1])
560     return nullptr;
561   auto lhsShape = llvm::to_vector<6>(
562       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
563   auto rhsShape = llvm::to_vector<6>(
564       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
565   SmallVector<int64_t, 6> resultShape;
566 
567   // If the shapes are not compatible, we can't fold it.
568   // TODO: Fold to an "error".
569   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
570     return nullptr;
571 
572   Builder builder(getContext());
573   return builder.getIndexTensorAttr(resultShape);
574 }
575 
576 static LogicalResult verify(BroadcastOp op) {
577   return verifyShapeOrExtentTensorOp(op);
578 }
579 
580 namespace {
581 template <typename OpTy>
582 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
583   using OpRewritePattern<OpTy>::OpRewritePattern;
584 
585   LogicalResult matchAndRewrite(OpTy op,
586                                 PatternRewriter &rewriter) const override {
587     auto isPotentiallyNonEmptyShape = [](Value shape) {
588       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
589         if (extentTensorTy.getDimSize(0) == 0)
590           return false;
591       }
592       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
593         if (constShape.shape().empty())
594           return false;
595       }
596       return true;
597     };
598     auto newOperands = llvm::to_vector<8>(
599         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
600 
601     // Reduce op to equivalent without empty shape operands.
602     if (newOperands.size() < op.getNumOperands()) {
603       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
604                                         op->getAttrs());
605       return success();
606     }
607 
608     return failure();
609   }
610 };
611 
612 struct BroadcastForwardSingleOperandPattern
613     : public OpRewritePattern<BroadcastOp> {
614   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
615 
616   LogicalResult matchAndRewrite(BroadcastOp op,
617                                 PatternRewriter &rewriter) const override {
618     if (op.getNumOperands() != 1)
619       return failure();
620     Value replacement = op.shapes().front();
621 
622     // Insert cast if needed.
623     if (replacement.getType() != op.getType()) {
624       auto loc = op.getLoc();
625       if (op.getType().isa<ShapeType>()) {
626         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
627       } else {
628         assert(!op.getType().isa<ShapeType>() &&
629                !replacement.getType().isa<ShapeType>() &&
630                "expect extent tensor cast");
631         replacement =
632             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
633       }
634     }
635 
636     rewriter.replaceOp(op, replacement);
637     return success();
638   }
639 };
640 
641 struct BroadcastFoldConstantOperandsPattern
642     : public OpRewritePattern<BroadcastOp> {
643   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
644 
645   LogicalResult matchAndRewrite(BroadcastOp op,
646                                 PatternRewriter &rewriter) const override {
647     SmallVector<int64_t, 8> foldedConstantShape;
648     SmallVector<Value, 8> newShapeOperands;
649     for (Value shape : op.shapes()) {
650       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
651         SmallVector<int64_t, 8> newFoldedConstantShape;
652         if (OpTrait::util::getBroadcastedShape(
653                 foldedConstantShape,
654                 llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
655                 newFoldedConstantShape)) {
656           foldedConstantShape = newFoldedConstantShape;
657           continue;
658         }
659       }
660       newShapeOperands.push_back(shape);
661     }
662 
663     // Need at least two constant operands to fold anything.
664     if (op.getNumOperands() - newShapeOperands.size() < 2)
665       return failure();
666 
667     auto foldedConstantOperandsTy = RankedTensorType::get(
668         {static_cast<int64_t>(foldedConstantShape.size())},
669         rewriter.getIndexType());
670     newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
671         op.getLoc(), foldedConstantOperandsTy,
672         rewriter.getIndexTensorAttr(foldedConstantShape)));
673     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
674                                              newShapeOperands);
675     return success();
676   }
677 };
678 
679 template <typename OpTy>
680 struct CanonicalizeCastExtentTensorOperandsPattern
681     : public OpRewritePattern<OpTy> {
682   using OpRewritePattern<OpTy>::OpRewritePattern;
683 
684   LogicalResult matchAndRewrite(OpTy op,
685                                 PatternRewriter &rewriter) const override {
686     // Canonicalize operands.
687     bool anyChange = false;
688     auto canonicalizeOperand = [&](Value operand) {
689       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
690         // Only eliminate the cast if it holds no shape information.
691         bool isInformationLoosingCast =
692             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
693         if (isInformationLoosingCast) {
694           anyChange = true;
695           return castOp.source();
696         }
697       }
698       return operand;
699     };
700     auto newOperands = llvm::to_vector<8>(
701         llvm::map_range(op.getOperands(), canonicalizeOperand));
702 
703     // Rewrite op if any change required.
704     if (!anyChange)
705       return failure();
706     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
707     return success();
708   }
709 };
710 
711 struct BroadcastConcretizeResultTypePattern
712     : public OpRewritePattern<BroadcastOp> {
713   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
714 
715   LogicalResult matchAndRewrite(BroadcastOp op,
716                                 PatternRewriter &rewriter) const override {
717     // Only concretize dynamic extent tensor result types.
718     auto resultTy = op.getType().dyn_cast<RankedTensorType>();
719     if (!resultTy || !resultTy.isDynamicDim(0))
720       return failure();
721 
722     // Infer resulting shape rank if possible.
723     int64_t maxRank = 0;
724     for (Value shape : op.shapes()) {
725       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
726         // Cannot infer resulting shape rank if any operand is dynamically
727         // ranked.
728         if (extentTensorTy.isDynamicDim(0))
729           return failure();
730         maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
731       }
732     }
733 
734     auto newOp = rewriter.create<BroadcastOp>(
735         op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
736     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
737     return success();
738   }
739 };
740 } // namespace
741 
742 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
743                                               MLIRContext *context) {
744   patterns.add<BroadcastConcretizeResultTypePattern,
745                BroadcastFoldConstantOperandsPattern,
746                BroadcastForwardSingleOperandPattern,
747                CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
748                RemoveDuplicateOperandsPattern<BroadcastOp>,
749                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
750 }
751 
752 //===----------------------------------------------------------------------===//
753 // ConcatOp
754 //===----------------------------------------------------------------------===//
755 
756 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
757   if (!operands[0] || !operands[1])
758     return nullptr;
759   auto lhsShape = llvm::to_vector<6>(
760       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
761   auto rhsShape = llvm::to_vector<6>(
762       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
763   SmallVector<int64_t, 6> resultShape;
764   resultShape.append(lhsShape.begin(), lhsShape.end());
765   resultShape.append(rhsShape.begin(), rhsShape.end());
766   Builder builder(getContext());
767   return builder.getIndexTensorAttr(resultShape);
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // ConstShapeOp
772 //===----------------------------------------------------------------------===//
773 
774 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
775   p << " ";
776   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
777   p << "[";
778   interleaveComma(op.shape().getValues<int64_t>(), p,
779                   [&](int64_t i) { p << i; });
780   p << "] : ";
781   p.printType(op.getType());
782 }
783 
784 static ParseResult parseConstShapeOp(OpAsmParser &parser,
785                                      OperationState &result) {
786   if (parser.parseOptionalAttrDict(result.attributes))
787     return failure();
788   // We piggy-back on ArrayAttr parsing, though we don't internally store the
789   // shape as an ArrayAttr.
790   // TODO: Implement custom parser and maybe make syntax a bit more concise.
791   Attribute extentsRaw;
792   NamedAttrList dummy;
793   if (parser.parseAttribute(extentsRaw, "dummy", dummy))
794     return failure();
795   auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
796   if (!extentsArray)
797     return failure();
798   SmallVector<int64_t, 6> ints;
799   for (Attribute extent : extentsArray) {
800     IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
801     if (!attr)
802       return failure();
803     ints.push_back(attr.getInt());
804   }
805   Builder &builder = parser.getBuilder();
806   result.addAttribute("shape", builder.getIndexTensorAttr(ints));
807   Type resultTy;
808   if (parser.parseColonType(resultTy))
809     return failure();
810   result.types.push_back(resultTy);
811   return success();
812 }
813 
814 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
815 
816 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
817                                                MLIRContext *context) {
818   patterns.add<TensorCastConstShape>(context);
819 }
820 
821 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
822     MLIRContext *context, Optional<Location> location, ValueRange operands,
823     DictionaryAttr attributes, RegionRange regions,
824     SmallVectorImpl<Type> &inferredReturnTypes) {
825   Builder b(context);
826   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
827   if (!shape)
828     return emitOptionalError(location, "missing shape attribute");
829   inferredReturnTypes.assign({RankedTensorType::get(
830       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
831   return success();
832 }
833 
834 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
835                                                         TypeRange r) {
836   if (l.size() != 1 || r.size() != 1)
837     return false;
838 
839   Type lhs = l.front();
840   Type rhs = r.front();
841 
842   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
843     // Shape type is compatible with all other valid return types.
844     return true;
845   return lhs == rhs;
846 }
847 
848 //===----------------------------------------------------------------------===//
849 // CstrBroadcastableOp
850 //===----------------------------------------------------------------------===//
851 
852 void CstrBroadcastableOp::getCanonicalizationPatterns(
853     RewritePatternSet &patterns, MLIRContext *context) {
854   // Canonicalization patterns have overlap with the considerations during
855   // folding in case additional shape information is inferred at some point that
856   // does not result in folding.
857   patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
858                CstrBroadcastableEqOps,
859                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
860                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
861 }
862 
863 // Return true if there is exactly one attribute not representing a scalar
864 // broadcast.
865 static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
866   bool nonScalarSeen = false;
867   for (Attribute a : attributes) {
868     if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
869       if (nonScalarSeen)
870         return false;
871       nonScalarSeen = true;
872     }
873   }
874   return true;
875 }
876 
877 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
878   // No broadcasting is needed if all operands but one are scalar.
879   if (hasAtMostSingleNonScalar(operands))
880     return BoolAttr::get(getContext(), true);
881 
882   if ([&] {
883         SmallVector<SmallVector<int64_t, 6>, 6> extents;
884         for (const auto &operand : operands) {
885           if (!operand)
886             return false;
887           extents.push_back(llvm::to_vector<6>(
888               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
889         }
890         return OpTrait::util::staticallyKnownBroadcastable(extents);
891       }())
892     return BoolAttr::get(getContext(), true);
893 
894   // Lastly, see if folding can be completed based on what constraints are known
895   // on the input shapes.
896   if ([&] {
897         SmallVector<SmallVector<int64_t, 6>, 6> extents;
898         for (auto shapeValue : shapes()) {
899           extents.emplace_back();
900           if (failed(getShapeVec(shapeValue, extents.back())))
901             return false;
902         }
903         return OpTrait::util::staticallyKnownBroadcastable(extents);
904       }())
905     return BoolAttr::get(getContext(), true);
906 
907   // Because a failing witness result here represents an eventual assertion
908   // failure, we do not replace it with a constant witness.
909   return nullptr;
910 }
911 
912 static LogicalResult verify(CstrBroadcastableOp op) {
913   // Ensure that AssumingAllOp contains at least one operand
914   if (op.getNumOperands() < 2)
915     return op.emitOpError("required at least 2 input shapes");
916   return success();
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // CstrEqOp
921 //===----------------------------------------------------------------------===//
922 
923 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
924                                            MLIRContext *context) {
925   // If inputs are equal, return passing witness
926   patterns.add<CstrEqEqOps>(context);
927 }
928 
929 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
930   if (llvm::all_of(operands,
931                    [&](Attribute a) { return a && a == operands[0]; }))
932     return BoolAttr::get(getContext(), true);
933 
934   // Because a failing witness result here represents an eventual assertion
935   // failure, we do not try to replace it with a constant witness. Similarly, we
936   // cannot if there are any non-const inputs.
937   return nullptr;
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // ConstSizeOp
942 //===----------------------------------------------------------------------===//
943 
944 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
945                         int64_t value) {
946   build(builder, result, builder.getIndexAttr(value));
947 }
948 
949 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
950 
951 void ConstSizeOp::getAsmResultNames(
952     llvm::function_ref<void(Value, StringRef)> setNameFn) {
953   SmallString<4> buffer;
954   llvm::raw_svector_ostream os(buffer);
955   os << "c" << value();
956   setNameFn(getResult(), os.str());
957 }
958 
959 //===----------------------------------------------------------------------===//
960 // ConstWitnessOp
961 //===----------------------------------------------------------------------===//
962 
963 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
964 
965 //===----------------------------------------------------------------------===//
966 // CstrRequireOp
967 //===----------------------------------------------------------------------===//
968 
969 OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
970   return operands[0];
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // DivOp
975 //===----------------------------------------------------------------------===//
976 
977 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
978   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
979   if (!lhs)
980     return nullptr;
981   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
982   if (!rhs)
983     return nullptr;
984 
985   // Division in APInt does not follow floor(lhs, rhs) when the result is
986   // negative. Rather, APInt rounds toward zero.
987   APInt quotient, remainder;
988   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
989   if (quotient.isNegative() && !remainder.isNullValue()) {
990     quotient -= 1;
991   }
992 
993   Type indexTy = IndexType::get(getContext());
994   return IntegerAttr::get(indexTy, quotient);
995 }
996 
997 LogicalResult mlir::shape::DivOp::inferReturnTypes(
998     MLIRContext *context, Optional<Location> location, ValueRange operands,
999     DictionaryAttr attributes, RegionRange regions,
1000     SmallVectorImpl<Type> &inferredReturnTypes) {
1001   if (operands[0].getType().isa<SizeType>() ||
1002       operands[1].getType().isa<SizeType>())
1003     inferredReturnTypes.assign({SizeType::get(context)});
1004   else
1005     inferredReturnTypes.assign({IndexType::get(context)});
1006   return success();
1007 }
1008 
1009 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1010   // SizeType is compatible with IndexType.
1011   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // ShapeEqOp
1016 //===----------------------------------------------------------------------===//
1017 
1018 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1019   bool allSame = true;
1020   if (!operands.empty() && !operands[0])
1021     return {};
1022   for (Attribute operand : operands.drop_front(1)) {
1023     if (!operand)
1024       return {};
1025     allSame = allSame && operand == operands[0];
1026   }
1027   return BoolAttr::get(getContext(), allSame);
1028 }
1029 
1030 //===----------------------------------------------------------------------===//
1031 // IndexToSizeOp
1032 //===----------------------------------------------------------------------===//
1033 
1034 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
1035   // Constant values of both types, `shape.size` and `index`, are represented as
1036   // `IntegerAttr`s which makes constant folding simple.
1037   if (Attribute arg = operands[0])
1038     return arg;
1039   return {};
1040 }
1041 
1042 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1043                                                 MLIRContext *context) {
1044   patterns.add<SizeToIndexToSizeCanonicalization>(context);
1045 }
1046 
1047 //===----------------------------------------------------------------------===//
1048 // FromExtentsOp
1049 //===----------------------------------------------------------------------===//
1050 
1051 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1052   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1053     return nullptr;
1054   SmallVector<int64_t, 6> extents;
1055   for (auto attr : operands)
1056     extents.push_back(attr.cast<IntegerAttr>().getInt());
1057   Builder builder(getContext());
1058   return builder.getIndexTensorAttr(extents);
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // FunctionLibraryOp
1063 //===----------------------------------------------------------------------===//
1064 
1065 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1066                               StringRef name) {
1067   result.attributes.push_back(builder.getNamedAttr(
1068       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1069 }
1070 
1071 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1072   auto attr = mapping()
1073                   .get(op->getName().getIdentifier())
1074                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1075   if (!attr)
1076     return nullptr;
1077   return lookupSymbol<FuncOp>(attr);
1078 }
1079 
1080 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
1081                                    OperationState &result) {
1082   // Parse the op name.
1083   StringAttr nameAttr;
1084   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1085                              result.attributes))
1086     return failure();
1087 
1088   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1089     return failure();
1090 
1091   auto *bodyRegion = result.addRegion();
1092   if (parser.parseRegion(*bodyRegion))
1093     return failure();
1094 
1095   if (parser.parseKeyword("mapping"))
1096     return failure();
1097 
1098   DictionaryAttr mappingAttr;
1099   if (parser.parseAttribute(mappingAttr,
1100                             parser.getBuilder().getType<NoneType>(), "mapping",
1101                             result.attributes))
1102     return failure();
1103   return success();
1104 }
1105 
1106 void print(OpAsmPrinter &p, FunctionLibraryOp op) {
1107   p << ' ';
1108   p.printSymbolName(op.getName());
1109   p.printOptionalAttrDictWithKeyword(
1110       op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
1111   p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
1112                 /*printBlockTerminators=*/false);
1113   p << " mapping ";
1114   p.printAttributeWithoutType(op.mappingAttr());
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // GetExtentOp
1119 //===----------------------------------------------------------------------===//
1120 
1121 Optional<int64_t> GetExtentOp::getConstantDim() {
1122   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
1123     return constSizeOp.value().getLimitedValue();
1124   if (auto constantOp = dim().getDefiningOp<arith::ConstantOp>())
1125     return constantOp.value().cast<IntegerAttr>().getInt();
1126   return llvm::None;
1127 }
1128 
1129 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1130   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1131   if (!elements)
1132     return nullptr;
1133   Optional<int64_t> dim = getConstantDim();
1134   if (!dim.hasValue())
1135     return nullptr;
1136   if (dim.getValue() >= elements.getNumElements())
1137     return nullptr;
1138   return elements.getValue({(uint64_t)dim.getValue()});
1139 }
1140 
1141 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1142                         int64_t dim) {
1143   auto loc = result.location;
1144   auto dimAttr = builder.getIndexAttr(dim);
1145   if (shape.getType().isa<ShapeType>()) {
1146     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1147     build(builder, result, builder.getType<SizeType>(), shape, dim);
1148   } else {
1149     Value dim =
1150         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1151     build(builder, result, builder.getIndexType(), shape, dim);
1152   }
1153 }
1154 
1155 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1156     MLIRContext *context, Optional<Location> location, ValueRange operands,
1157     DictionaryAttr attributes, RegionRange regions,
1158     SmallVectorImpl<Type> &inferredReturnTypes) {
1159   inferredReturnTypes.assign({IndexType::get(context)});
1160   return success();
1161 }
1162 
1163 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1164                                                        TypeRange r) {
1165   // SizeType is compatible with IndexType.
1166   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1167 }
1168 
1169 //===----------------------------------------------------------------------===//
1170 // IsBroadcastableOp
1171 //===----------------------------------------------------------------------===//
1172 
1173 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1174                                                     MLIRContext *context) {
1175   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1176 }
1177 
1178 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1179   // Can always broadcast fewer than two shapes.
1180   if (operands.size() < 2) {
1181     return BoolAttr::get(getContext(), true);
1182   }
1183 
1184   return nullptr;
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // MeetOp
1189 //===----------------------------------------------------------------------===//
1190 
1191 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1192     MLIRContext *context, Optional<Location> location, ValueRange operands,
1193     DictionaryAttr attributes, RegionRange regions,
1194     SmallVectorImpl<Type> &inferredReturnTypes) {
1195   inferredReturnTypes.assign({operands[0].getType()});
1196   return success();
1197 }
1198 
1199 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1200   if (l.size() != 1 || r.size() != 1)
1201     return false;
1202   if (l == r)
1203     return true;
1204 
1205   Type lhs = l.front();
1206   Type rhs = r.front();
1207 
1208   if (lhs != rhs)
1209     return false;
1210 
1211   if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
1212     return true;
1213 
1214   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1215     return true;
1216   return false;
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // RankOp
1221 //===----------------------------------------------------------------------===//
1222 
1223 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1224   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1225   if (!shape)
1226     return {};
1227   int64_t rank = shape.getNumElements();
1228   Builder builder(getContext());
1229   return builder.getIndexAttr(rank);
1230 }
1231 
1232 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1233 /// Constant folding fails in cases where only the rank is constant, not the
1234 /// shape itself.
1235 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1236 ///
1237 /// Example:
1238 ///
1239 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1240 /// %rank = shape.rank %shape
1241 ///
1242 /// becomes
1243 ///
1244 /// %rank = shape.const_size 3
1245 
1246 namespace {
1247 struct RankShapeOfCanonicalizationPattern
1248     : public OpRewritePattern<shape::RankOp> {
1249   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1250 
1251   LogicalResult matchAndRewrite(shape::RankOp op,
1252                                 PatternRewriter &rewriter) const override {
1253     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
1254     if (!shapeOfOp)
1255       return failure();
1256     auto rankedTensorType =
1257         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1258     if (!rankedTensorType)
1259       return failure();
1260     int64_t rank = rankedTensorType.getRank();
1261     if (op.getType().isa<IndexType>()) {
1262       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1263                                                           rank);
1264     } else if (op.getType().isa<shape::SizeType>()) {
1265       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1266     } else {
1267       return failure();
1268     }
1269     return success();
1270   }
1271 };
1272 } // namespace
1273 
1274 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1275                                                 MLIRContext *context) {
1276   patterns.add<RankShapeOfCanonicalizationPattern>(context);
1277 }
1278 
1279 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1280     MLIRContext *context, Optional<Location> location, ValueRange operands,
1281     DictionaryAttr attributes, RegionRange regions,
1282     SmallVectorImpl<Type> &inferredReturnTypes) {
1283   if (operands[0].getType().isa<ShapeType>())
1284     inferredReturnTypes.assign({SizeType::get(context)});
1285   else
1286     inferredReturnTypes.assign({IndexType::get(context)});
1287   return success();
1288 }
1289 
1290 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1291   // SizeType is compatible with IndexType.
1292   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1293 }
1294 
1295 //===----------------------------------------------------------------------===//
1296 // NumElementsOp
1297 //===----------------------------------------------------------------------===//
1298 
1299 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1300 
1301   // Fold only when argument constant.
1302   Attribute shape = operands[0];
1303   if (!shape)
1304     return {};
1305 
1306   APInt product(64, 1);
1307   for (auto value : shape.cast<DenseIntElementsAttr>())
1308     product *= value;
1309   Builder builder(getContext());
1310   return builder.getIndexAttr(product.getLimitedValue());
1311 }
1312 
1313 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1314     MLIRContext *context, Optional<Location> location, ValueRange operands,
1315     DictionaryAttr attributes, RegionRange regions,
1316     SmallVectorImpl<Type> &inferredReturnTypes) {
1317   if (operands[0].getType().isa<ShapeType>())
1318     inferredReturnTypes.assign({SizeType::get(context)});
1319   else
1320     inferredReturnTypes.assign({IndexType::get(context)});
1321   return success();
1322 }
1323 
1324 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1325                                                          TypeRange r) {
1326   // SizeType is compatible with IndexType.
1327   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // MaxOp
1332 //===----------------------------------------------------------------------===//
1333 
1334 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1335   // If operands are equal, just propagate one.
1336   if (lhs() == rhs())
1337     return lhs();
1338   return nullptr;
1339 }
1340 
1341 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1342     MLIRContext *context, Optional<Location> location, ValueRange operands,
1343     DictionaryAttr attributes, RegionRange regions,
1344     SmallVectorImpl<Type> &inferredReturnTypes) {
1345   if (operands[0].getType() == operands[1].getType())
1346     inferredReturnTypes.assign({operands[0].getType()});
1347   else
1348     inferredReturnTypes.assign({SizeType::get(context)});
1349   return success();
1350 }
1351 
1352 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1353   if (l.size() != 1 || r.size() != 1)
1354     return false;
1355   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1356     return true;
1357   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1358     return true;
1359   return false;
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // MinOp
1364 //===----------------------------------------------------------------------===//
1365 
1366 OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1367   // If operands are equal, just propagate one.
1368   if (lhs() == rhs())
1369     return lhs();
1370   return nullptr;
1371 }
1372 
1373 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1374     MLIRContext *context, Optional<Location> location, ValueRange operands,
1375     DictionaryAttr attributes, RegionRange regions,
1376     SmallVectorImpl<Type> &inferredReturnTypes) {
1377   if (operands[0].getType() == operands[1].getType())
1378     inferredReturnTypes.assign({operands[0].getType()});
1379   else
1380     inferredReturnTypes.assign({SizeType::get(context)});
1381   return success();
1382 }
1383 
1384 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1385   if (l.size() != 1 || r.size() != 1)
1386     return false;
1387   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1388     return true;
1389   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1390     return true;
1391   return false;
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // MulOp
1396 //===----------------------------------------------------------------------===//
1397 
1398 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1399   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1400   if (!lhs)
1401     return nullptr;
1402   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1403   if (!rhs)
1404     return nullptr;
1405   APInt folded = lhs.getValue() * rhs.getValue();
1406   Type indexTy = IndexType::get(getContext());
1407   return IntegerAttr::get(indexTy, folded);
1408 }
1409 
1410 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1411     MLIRContext *context, Optional<Location> location, ValueRange operands,
1412     DictionaryAttr attributes, RegionRange regions,
1413     SmallVectorImpl<Type> &inferredReturnTypes) {
1414   if (operands[0].getType().isa<SizeType>() ||
1415       operands[1].getType().isa<SizeType>())
1416     inferredReturnTypes.assign({SizeType::get(context)});
1417   else
1418     inferredReturnTypes.assign({IndexType::get(context)});
1419   return success();
1420 }
1421 
1422 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1423   // SizeType is compatible with IndexType.
1424   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1425 }
1426 //===----------------------------------------------------------------------===//
1427 // ShapeOfOp
1428 //===----------------------------------------------------------------------===//
1429 
1430 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1431   auto type = getOperand().getType().dyn_cast<ShapedType>();
1432   if (!type || !type.hasStaticShape())
1433     return nullptr;
1434   Builder builder(getContext());
1435   return builder.getIndexTensorAttr(type.getShape());
1436 }
1437 
1438 namespace {
1439 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1440   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1441 
1442   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1443                                 PatternRewriter &rewriter) const override {
1444     if (!op.arg().getType().isa<ShapedType>())
1445       return failure();
1446     if (op.getType().isa<ShapedType>())
1447       return failure();
1448 
1449     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
1450     return success();
1451   }
1452 };
1453 
1454 // Canonicalize
1455 // ```
1456 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1457 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1458 // ```
1459 // to
1460 // ```
1461 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1462 // ```
1463 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1464   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1465 
1466   LogicalResult matchAndRewrite(tensor::CastOp op,
1467                                 PatternRewriter &rewriter) const override {
1468     auto ty = op.getType().dyn_cast<RankedTensorType>();
1469     if (!ty || ty.getRank() != 1)
1470       return failure();
1471 
1472     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1473     if (!shapeOfOp)
1474       return failure();
1475 
1476     // Argument type must be ranked and must not conflict.
1477     auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1478     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1479       return failure();
1480 
1481     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
1482     return success();
1483   }
1484 };
1485 } // namespace
1486 
1487 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1488                                             MLIRContext *context) {
1489   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1490                ExtractFromShapeOfExtentTensor>(context);
1491 }
1492 
1493 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1494     MLIRContext *context, Optional<Location> location, ValueRange operands,
1495     DictionaryAttr attributes, RegionRange regions,
1496     SmallVectorImpl<Type> &inferredReturnTypes) {
1497   if (operands[0].getType().isa<ValueShapeType>())
1498     inferredReturnTypes.assign({ShapeType::get(context)});
1499   else {
1500     auto shapedTy = operands[0].getType().cast<ShapedType>();
1501     int64_t rank =
1502         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1503     Type indexTy = IndexType::get(context);
1504     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1505     inferredReturnTypes.assign({extentTensorTy});
1506   }
1507   return success();
1508 }
1509 
1510 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1511   if (l.size() != 1 || r.size() != 1)
1512     return false;
1513   if (l == r)
1514     return true;
1515 
1516   Type lhs = l.front();
1517   Type rhs = r.front();
1518 
1519   if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
1520     return false;
1521 
1522   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
1523     // Shape type is compatible with all other valid return types.
1524     return true;
1525 
1526   if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1527     return true;
1528   return false;
1529 }
1530 
1531 //===----------------------------------------------------------------------===//
1532 // SizeToIndexOp
1533 //===----------------------------------------------------------------------===//
1534 
1535 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1536   // Constant values of both types, `shape.size` and `index`, are represented as
1537   // `IntegerAttr`s which makes constant folding simple.
1538   if (Attribute arg = operands[0])
1539     return arg;
1540   return impl::foldCastOp(*this);
1541 }
1542 
1543 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1544                                                 MLIRContext *context) {
1545   patterns.add<IndexToSizeToIndexCanonicalization>(context);
1546 }
1547 
1548 //===----------------------------------------------------------------------===//
1549 // YieldOp
1550 //===----------------------------------------------------------------------===//
1551 
1552 static LogicalResult verify(shape::YieldOp op) {
1553   auto *parentOp = op->getParentOp();
1554   auto results = parentOp->getResults();
1555   auto operands = op.getOperands();
1556 
1557   if (parentOp->getNumResults() != op.getNumOperands())
1558     return op.emitOpError() << "number of operands does not match number of "
1559                                "results of its parent";
1560   for (auto e : llvm::zip(results, operands))
1561     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1562       return op.emitOpError()
1563              << "types mismatch between yield op and its parent";
1564 
1565   return success();
1566 }
1567 
1568 //===----------------------------------------------------------------------===//
1569 // SplitAtOp
1570 //===----------------------------------------------------------------------===//
1571 
1572 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1573                               SmallVectorImpl<OpFoldResult> &results) {
1574   if (!operands[0] || !operands[1])
1575     return failure();
1576   auto shapeVec = llvm::to_vector<6>(
1577       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1578   auto shape = llvm::makeArrayRef(shapeVec);
1579   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1580   // Verify that the split point is in the correct range.
1581   // TODO: Constant fold to an "error".
1582   int64_t rank = shape.size();
1583   if (!(-rank <= splitPoint && splitPoint <= rank))
1584     return failure();
1585   if (splitPoint < 0)
1586     splitPoint += shape.size();
1587   Builder builder(operands[0].getContext());
1588   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1589   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1590   return success();
1591 }
1592 
1593 //===----------------------------------------------------------------------===//
1594 // ToExtentTensorOp
1595 //===----------------------------------------------------------------------===//
1596 
1597 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1598   if (!operands[0])
1599     return impl::foldCastOp(*this);
1600   Builder builder(getContext());
1601   auto shape = llvm::to_vector<6>(
1602       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1603   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1604                                     builder.getIndexType());
1605   return DenseIntElementsAttr::get(type, shape);
1606 }
1607 
1608 //===----------------------------------------------------------------------===//
1609 // ReduceOp
1610 //===----------------------------------------------------------------------===//
1611 
1612 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1613                      ValueRange initVals) {
1614   result.addOperands(shape);
1615   result.addOperands(initVals);
1616 
1617   Region *bodyRegion = result.addRegion();
1618   bodyRegion->push_back(new Block);
1619   Block &bodyBlock = bodyRegion->front();
1620   bodyBlock.addArgument(builder.getIndexType());
1621 
1622   Type elementType;
1623   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1624     elementType = tensorType.getElementType();
1625   else
1626     elementType = SizeType::get(builder.getContext());
1627   bodyBlock.addArgument(elementType);
1628 
1629   for (Type initValType : initVals.getTypes()) {
1630     bodyBlock.addArgument(initValType);
1631     result.addTypes(initValType);
1632   }
1633 }
1634 
1635 static LogicalResult verify(ReduceOp op) {
1636   // Verify block arg types.
1637   Block &block = op.region().front();
1638 
1639   // The block takes index, extent, and aggregated values as arguments.
1640   auto blockArgsCount = op.initVals().size() + 2;
1641   if (block.getNumArguments() != blockArgsCount)
1642     return op.emitOpError() << "ReduceOp body is expected to have "
1643                             << blockArgsCount << " arguments";
1644 
1645   // The first block argument is the index and must always be of type `index`.
1646   if (!block.getArgument(0).getType().isa<IndexType>())
1647     return op.emitOpError(
1648         "argument 0 of ReduceOp body is expected to be of IndexType");
1649 
1650   // The second block argument is the extent and must be of type `size` or
1651   // `index`, depending on whether the reduce operation is applied to a shape or
1652   // to an extent tensor.
1653   Type extentTy = block.getArgument(1).getType();
1654   if (op.shape().getType().isa<ShapeType>()) {
1655     if (!extentTy.isa<SizeType>())
1656       return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
1657                             "SizeType if the ReduceOp operates on a ShapeType");
1658   } else {
1659     if (!extentTy.isa<IndexType>())
1660       return op.emitOpError(
1661           "argument 1 of ReduceOp body is expected to be of IndexType if the "
1662           "ReduceOp operates on an extent tensor");
1663   }
1664 
1665   for (auto type : llvm::enumerate(op.initVals()))
1666     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
1667       return op.emitOpError()
1668              << "type mismatch between argument " << type.index() + 2
1669              << " of ReduceOp body and initial value " << type.index();
1670   return success();
1671 }
1672 
1673 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
1674   // Parse operands.
1675   SmallVector<OpAsmParser::OperandType, 3> operands;
1676   Type shapeOrExtentTensorType;
1677   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
1678                               OpAsmParser::Delimiter::Paren) ||
1679       parser.parseColonType(shapeOrExtentTensorType) ||
1680       parser.parseOptionalArrowTypeList(result.types))
1681     return failure();
1682 
1683   // Resolve operands.
1684   auto initVals = llvm::makeArrayRef(operands).drop_front();
1685   if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
1686                             result.operands) ||
1687       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1688                              result.operands))
1689     return failure();
1690 
1691   // Parse the body.
1692   Region *body = result.addRegion();
1693   if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
1694     return failure();
1695 
1696   // Parse attributes.
1697   if (parser.parseOptionalAttrDict(result.attributes))
1698     return failure();
1699 
1700   return success();
1701 }
1702 
1703 static void print(OpAsmPrinter &p, ReduceOp op) {
1704   p << '(' << op.shape() << ", " << op.initVals()
1705     << ") : " << op.shape().getType();
1706   p.printOptionalArrowTypeList(op.getResultTypes());
1707   p.printRegion(op.region());
1708   p.printOptionalAttrDict(op->getAttrs());
1709 }
1710 
1711 #define GET_OP_CLASSES
1712 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1713