1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Traits.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 using namespace mlir;
28 using namespace mlir::shape;
29 
30 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
31 
32 namespace {
33 #include "ShapeCanonicalization.inc"
34 }
35 
36 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
37   return RankedTensorType::get({rank}, IndexType::get(ctx));
38 }
39 
40 bool shape::isExtentTensorType(Type type) {
41   auto ranked = type.dyn_cast<RankedTensorType>();
42   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
43 }
44 
45 LogicalResult shape::getShapeVec(Value input,
46                                  SmallVectorImpl<int64_t> &shapeValues) {
47   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
48     auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
49     if (!type.hasRank())
50       return failure();
51     shapeValues = llvm::to_vector<6>(type.getShape());
52     return success();
53   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
54     shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
55     return success();
56   } else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
57     shapeValues = llvm::to_vector<6>(
58         inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
59     return success();
60   } else {
61     return failure();
62   }
63 }
64 
65 static bool isErrorPropagationPossible(TypeRange operandTypes) {
66   return llvm::any_of(operandTypes, [](Type ty) {
67     return ty.isa<SizeType, ShapeType, ValueShapeType>();
68   });
69 }
70 
71 static LogicalResult verifySizeOrIndexOp(Operation *op) {
72   assert(op != nullptr && op->getNumResults() == 1);
73   Type resultTy = op->getResultTypes().front();
74   if (isErrorPropagationPossible(op->getOperandTypes())) {
75     if (!resultTy.isa<SizeType>())
76       return op->emitOpError()
77              << "if at least one of the operands can hold error values then "
78                 "the result must be of type `size` to propagate them";
79   }
80   return success();
81 }
82 
83 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84   assert(op != nullptr && op->getNumResults() == 1);
85   Type resultTy = op->getResultTypes().front();
86   if (isErrorPropagationPossible(op->getOperandTypes())) {
87     if (!resultTy.isa<ShapeType>())
88       return op->emitOpError()
89              << "if at least one of the operands can hold error values then "
90                 "the result must be of type `shape` to propagate them";
91   }
92   return success();
93 }
94 
95 template <typename... Ty>
96 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97   return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
98 }
99 
/// Variadic overload: true iff every given type range contains exactly one
/// type that is one of the `Ty...` alternatives.
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}
104 
105 //===----------------------------------------------------------------------===//
106 // InlinerInterface
107 //===----------------------------------------------------------------------===//
108 
109 namespace {
110 /// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  // Shape dialect regions impose no inlining restrictions.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect. Unconditionally allowed.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
129 } // namespace
130 
void ShapeDialect::initialize() {
  // Register all operations generated into ShapeOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Register the dialect's types and the inliner interface defined above.
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
143 
/// Materializes a constant of the shape dialect: shapes/extent tensors become
/// shape.const_shape, sizes become shape.const_size, witnesses become
/// shape.const_witness; anything arith can build is delegated to
/// arith.constant. Returns nullptr for unsupported type/attribute pairs.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Extent tensors are ranked index tensors, so this case must be checked
  // before falling through to the generic arith constant below.
  if (type.isa<ShapeType>() || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(loc, type,
                                        value.cast<DenseIntElementsAttr>());
  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())
    return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  return nullptr;
}
158 
159 /// Parse a type registered to this dialect.
160 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
161   StringRef keyword;
162   if (parser.parseKeyword(&keyword))
163     return Type();
164 
165   if (keyword == "shape")
166     return ShapeType::get(getContext());
167   if (keyword == "size")
168     return SizeType::get(getContext());
169   if (keyword == "value_shape")
170     return ValueShapeType::get(getContext());
171   if (keyword == "witness")
172     return WitnessType::get(getContext());
173 
174   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
175   return Type();
176 }
177 
178 /// Print a type registered to this dialect.
179 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
180   TypeSwitch<Type>(type)
181       .Case<ShapeType>([&](Type) { os << "shape"; })
182       .Case<SizeType>([&](Type) { os << "size"; })
183       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
184       .Case<WitnessType>([&](Type) { os << "witness"; })
185       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
186 }
187 
188 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
189                                                      NamedAttribute attribute) {
190   // Verify shape.lib attribute.
191   if (attribute.getName() == "shape.lib") {
192     if (!op->hasTrait<OpTrait::SymbolTable>())
193       return op->emitError(
194           "shape.lib attribute may only be on op implementing SymbolTable");
195 
196     if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
197       auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
198       if (!symbol)
199         return op->emitError("shape function library ")
200                << symbolRef << " not found";
201       return isa<shape::FunctionLibraryOp>(symbol)
202                  ? success()
203                  : op->emitError()
204                        << symbolRef << " required to be shape function library";
205     }
206 
207     if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
208       // Verify all entries are function libraries and mappings in libraries
209       // refer to unique ops.
210       DenseSet<StringAttr> key;
211       for (auto it : arr) {
212         if (!it.isa<SymbolRefAttr>())
213           return op->emitError(
214               "only SymbolRefAttr allowed in shape.lib attribute array");
215 
216         auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
217             SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
218         if (!shapeFnLib)
219           return op->emitError()
220                  << it << " does not refer to FunctionLibraryOp";
221         for (auto mapping : shapeFnLib.getMapping()) {
222           if (!key.insert(mapping.getName()).second) {
223             return op->emitError("only one op to shape mapping allowed, found "
224                                  "multiple for `")
225                    << mapping.getName() << "`";
226           }
227         }
228       }
229       return success();
230     }
231 
232     return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
233                          "allowed as shape.lib attribute");
234   }
235   return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // AnyOp
240 //===----------------------------------------------------------------------===//
241 
242 // TODO: Canonicalization should be implemented for shapes that can be
243 // determined through mixtures of the known dimensions of the inputs.
244 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
245   // Only the last operand is checked because AnyOp is commutative.
246   if (operands.back())
247     return operands.back();
248 
249   return nullptr;
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // AssumingOp
254 //===----------------------------------------------------------------------===//
255 
/// Parses the custom assembly of AssumingOp:
///   witness (`->` `(` result-types `)`)? region attr-dict?
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is the witness; its type is always `!shape.witness`
  // and therefore not spelled in the assembly.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
282 
283 static void print(OpAsmPrinter &p, AssumingOp op) {
284   bool yieldsResults = !op.getResults().empty();
285 
286   p << " " << op.getWitness();
287   if (yieldsResults) {
288     p << " -> (" << op.getResultTypes() << ")";
289   }
290   p.printRegion(op.getDoRegion(),
291                 /*printEntryBlockArgs=*/false,
292                 /*printBlockTerminators=*/yieldsResults);
293   p.printOptionalAttrDict(op->getAttrs());
294 }
295 
296 namespace {
297 // Removes AssumingOp with a passing witness and inlines the region.
298 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
299   using OpRewritePattern<AssumingOp>::OpRewritePattern;
300 
301   LogicalResult matchAndRewrite(AssumingOp op,
302                                 PatternRewriter &rewriter) const override {
303     auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
304     if (!witness || !witness.getPassingAttr())
305       return failure();
306 
307     AssumingOp::inlineRegionIntoParent(op, rewriter);
308     return success();
309   }
310 };
311 
// Rebuilds an AssumingOp without the results that have no uses, pruning the
// corresponding operands of the body's yield terminator.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    // A null replacement is legal here because the matching old result is
    // known to have no uses.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
357 } // namespace
358 
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // Prune unused results and inline the region when the witness is a
  // constant `true`.
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
363 
364 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
365 void AssumingOp::getSuccessorRegions(
366     Optional<unsigned> index, ArrayRef<Attribute> operands,
367     SmallVectorImpl<RegionSuccessor> &regions) {
368   // AssumingOp has unconditional control flow into the region and back to the
369   // parent, so return the correct RegionSuccessor purely based on the index
370   // being None or 0.
371   if (index.hasValue()) {
372     regions.push_back(RegionSuccessor(getResults()));
373     return;
374   }
375 
376   regions.push_back(RegionSuccessor(&getDoRegion()));
377 }
378 
/// Splices the single-block body of `op` into the parent block at the
/// rewriter's current insertion point, replacing the op's results with the
/// operands of the body's yield terminator.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split off everything after the op so the body can be spliced in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
398 
399 void AssumingOp::build(
400     OpBuilder &builder, OperationState &result, Value witness,
401     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
402 
403   result.addOperands(witness);
404   Region *bodyRegion = result.addRegion();
405   bodyRegion->push_back(new Block);
406   Block &bodyBlock = bodyRegion->front();
407 
408   // Build body.
409   OpBuilder::InsertionGuard guard(builder);
410   builder.setInsertionPointToStart(&bodyBlock);
411   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
412   builder.create<AssumingYieldOp>(result.location, yieldValues);
413 
414   SmallVector<Type, 2> assumingTypes;
415   for (Value v : yieldValues)
416     assumingTypes.push_back(v.getType());
417   result.addTypes(assumingTypes);
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // AddOp
422 //===----------------------------------------------------------------------===//
423 
424 LogicalResult mlir::shape::AddOp::inferReturnTypes(
425     MLIRContext *context, Optional<Location> location, ValueRange operands,
426     DictionaryAttr attributes, RegionRange regions,
427     SmallVectorImpl<Type> &inferredReturnTypes) {
428   if (operands[0].getType().isa<SizeType>() ||
429       operands[1].getType().isa<SizeType>())
430     inferredReturnTypes.assign({SizeType::get(context)});
431   else
432     inferredReturnTypes.assign({IndexType::get(context)});
433   return success();
434 }
435 
/// Declared and inferred result types match if both are a single size/index.
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
440 
441 OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
442   // add(x, 0) -> x
443   if (matchPattern(getRhs(), m_Zero()))
444     return getLhs();
445 
446   return constFoldBinaryOp<IntegerAttr>(operands,
447                                         [](APInt a, APInt b) { return a + b; });
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // AssumingAllOp
452 //===----------------------------------------------------------------------===//
453 
454 namespace {
// Merges the `shape.cstr_eq` witnesses feeding a `shape.assuming_all` into a
// single `shape.cstr_eq` over the union of all shapes. Only applies when each
// cstr_eq shares at least one shape with those collected so far, so that
// equality is transitively connected across all shapes.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      // Every input witness must be produced by a cstr_eq op.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // Bail out if this cstr_eq shares no shape with the ones seen so far.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
477 
478 template <typename OpTy>
479 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
480   using OpRewritePattern<OpTy>::OpRewritePattern;
481 
482   LogicalResult matchAndRewrite(OpTy op,
483                                 PatternRewriter &rewriter) const override {
484     // Find unique operands.
485     SmallVector<Value, 2> unique;
486     for (Value v : op.getOperands()) {
487       if (!llvm::is_contained(unique, v))
488         unique.push_back(v);
489     }
490 
491     // Reduce op to equivalent with unique operands.
492     if (unique.size() < op.getNumOperands()) {
493       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
494                                         op->getAttrs());
495       return success();
496     }
497 
498     return failure();
499   }
500 };
501 } // namespace
502 
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // `AssumingAllOneOp` comes from the declarative rewrites included via
  // ShapeCanonicalization.inc.
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
508 
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    // NOTE(review): mutating the op's operand list inside fold() is unusual;
    // it relies on the folder keeping the op alive while fold returns another
    // operand as the result — confirm this remains supported.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
529 
530 static LogicalResult verify(AssumingAllOp op) {
531   // Ensure that AssumingAllOp contains at least one operand
532   if (op.getNumOperands() == 0)
533     return op.emitOpError("no operands specified");
534 
535   return success();
536 }
537 
/// Convenience builder that supplies the single `!shape.witness` result type.
void AssumingAllOp::build(OpBuilder &b, OperationState &state,
                          ValueRange inputs) {
  build(b, state, b.getType<WitnessType>(), inputs);
}
542 
543 //===----------------------------------------------------------------------===//
544 // BroadcastOp
545 //===----------------------------------------------------------------------===//
546 
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  // Broadcasting a single shape is the identity (types permitting).
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (getShapes().size() > 2)
    return nullptr;

  // Both shapes must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
575 
// Result must be `!shape.shape` when any operand can carry an error value.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
579 
580 namespace {
581 template <typename OpTy>
582 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
583   using OpRewritePattern<OpTy>::OpRewritePattern;
584 
585   LogicalResult matchAndRewrite(OpTy op,
586                                 PatternRewriter &rewriter) const override {
587     auto isPotentiallyNonEmptyShape = [](Value shape) {
588       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
589         if (extentTensorTy.getDimSize(0) == 0)
590           return false;
591       }
592       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
593         if (constShape.getShape().empty())
594           return false;
595       }
596       return true;
597     };
598     auto newOperands = llvm::to_vector<8>(
599         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
600 
601     // Reduce op to equivalent without empty shape operands.
602     if (newOperands.size() < op.getNumOperands()) {
603       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
604                                         op->getAttrs());
605       return success();
606     }
607 
608     return failure();
609   }
610 };
611 
612 struct BroadcastForwardSingleOperandPattern
613     : public OpRewritePattern<BroadcastOp> {
614   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
615 
616   LogicalResult matchAndRewrite(BroadcastOp op,
617                                 PatternRewriter &rewriter) const override {
618     if (op.getNumOperands() != 1)
619       return failure();
620     Value replacement = op.getShapes().front();
621 
622     // Insert cast if needed.
623     if (replacement.getType() != op.getType()) {
624       auto loc = op.getLoc();
625       if (op.getType().isa<ShapeType>()) {
626         replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
627       } else {
628         assert(!op.getType().isa<ShapeType>() &&
629                !replacement.getType().isa<ShapeType>() &&
630                "expect extent tensor cast");
631         replacement =
632             rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
633       }
634     }
635 
636     rewriter.replaceOp(op, replacement);
637     return success();
638   }
639 };
640 
// Folds all broadcast-compatible constant shape operands of a broadcast into
// a single shape.const_shape operand (appended last); non-constant and
// incompatible operands are kept as-is.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // Only absorb this constant if broadcasting it with the accumulated
        // constant shape succeeds.
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
678 
679 template <typename OpTy>
680 struct CanonicalizeCastExtentTensorOperandsPattern
681     : public OpRewritePattern<OpTy> {
682   using OpRewritePattern<OpTy>::OpRewritePattern;
683 
684   LogicalResult matchAndRewrite(OpTy op,
685                                 PatternRewriter &rewriter) const override {
686     // Canonicalize operands.
687     bool anyChange = false;
688     auto canonicalizeOperand = [&](Value operand) {
689       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
690         // Only eliminate the cast if it holds no shape information.
691         bool isInformationLoosingCast =
692             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
693         if (isInformationLoosingCast) {
694           anyChange = true;
695           return castOp.source();
696         }
697       }
698       return operand;
699     };
700     auto newOperands = llvm::to_vector<8>(
701         llvm::map_range(op.getOperands(), canonicalizeOperand));
702 
703     // Rewrite op if any change required.
704     if (!anyChange)
705       return failure();
706     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
707     return success();
708   }
709 };
710 
// Concretizes a broadcast whose extent tensor result type has a dynamic size,
// using the maximum statically known operand rank, and casts back to the
// original result type.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
741 } // namespace
742 
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Simplify broadcasts by concretizing result types, folding constant
  // operands, forwarding single operands, looking through casts, and dropping
  // duplicate or known-empty shape operands.
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
752 
753 //===----------------------------------------------------------------------===//
754 // ConcatOp
755 //===----------------------------------------------------------------------===//
756 
757 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
758   if (!operands[0] || !operands[1])
759     return nullptr;
760   auto lhsShape = llvm::to_vector<6>(
761       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
762   auto rhsShape = llvm::to_vector<6>(
763       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
764   SmallVector<int64_t, 6> resultShape;
765   resultShape.append(lhsShape.begin(), lhsShape.end());
766   resultShape.append(rhsShape.begin(), rhsShape.end());
767   Builder builder(getContext());
768   return builder.getIndexTensorAttr(resultShape);
769 }
770 
771 //===----------------------------------------------------------------------===//
772 // ConstShapeOp
773 //===----------------------------------------------------------------------===//
774 
/// Prints a const_shape op: attr-dict `[` extents `]` `:` type.
static void print(OpAsmPrinter &p, ConstShapeOp &op) {
  p << " ";
  // The shape attribute is rendered as the bracketed extent list below, so it
  // is elided from the attribute dictionary.
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(op.getShape().getValues<int64_t>(), p,
                  [&](int64_t i) { p << i; });
  p << "] : ";
  p.printType(op.getType());
}
784 
/// Parses a const_shape op: attr-dict `[` extents `]` `:` type.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Every array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  // Store the extents canonically as an index tensor attribute.
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
814 
815 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
816 
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Absorb tensor casts of constant shapes into the constant itself.
  patterns.add<TensorCastConstShape>(context);
}
821 
822 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
823     MLIRContext *context, Optional<Location> location, ValueRange operands,
824     DictionaryAttr attributes, RegionRange regions,
825     SmallVectorImpl<Type> &inferredReturnTypes) {
826   Builder b(context);
827   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
828   if (!shape)
829     return emitOptionalError(location, "missing shape attribute");
830   inferredReturnTypes.assign({RankedTensorType::get(
831       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
832   return success();
833 }
834 
835 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
836                                                         TypeRange r) {
837   if (l.size() != 1 || r.size() != 1)
838     return false;
839 
840   Type lhs = l.front();
841   Type rhs = r.front();
842 
843   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
844     // Shape type is compatible with all other valid return types.
845     return true;
846   return lhs == rhs;
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // CstrBroadcastableOp
851 //===----------------------------------------------------------------------===//
852 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
863 
// Return true if at most one of the attributes represents a non-scalar
// broadcast (i.e. a non-empty shape). A null attribute — a non-constant
// operand — is conservatively treated as potentially non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      // Two (potentially) non-scalar shapes: broadcastability is not trivial.
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
877 
// Attempts to fold the broadcastability constraint to a passing witness using
// three successively stronger sources of information: scalar-ness of the
// operands, constant operand attributes, and shapes recovered from the
// defining ops of the operands.
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(operands))
    return BoolAttr::get(getContext(), true);

  // Next, check whether all operands are constant extent tensors that are
  // statically known to broadcast.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : operands) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are known
  // on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}
912 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Broadcastability is only meaningful between two or more shapes.
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
919 
920 //===----------------------------------------------------------------------===//
921 // CstrEqOp
922 //===----------------------------------------------------------------------===//
923 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}
929 
930 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
931   if (llvm::all_of(operands,
932                    [&](Attribute a) { return a && a == operands[0]; }))
933     return BoolAttr::get(getContext(), true);
934 
935   // Because a failing witness result here represents an eventual assertion
936   // failure, we do not try to replace it with a constant witness. Similarly, we
937   // cannot if there are any non-const inputs.
938   return nullptr;
939 }
940 
941 //===----------------------------------------------------------------------===//
942 // ConstSizeOp
943 //===----------------------------------------------------------------------===//
944 
// Convenience builder: wraps a plain integer in an index attribute and
// delegates to the attribute-based builder.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
949 
950 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
951 
// Suggests SSA names of the form %cN (e.g. %c3) for constant size results.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}
959 
960 //===----------------------------------------------------------------------===//
961 // ConstWitnessOp
962 //===----------------------------------------------------------------------===//
963 
// A constant witness folds to its "passing" attribute unconditionally.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
  return getPassingAttr();
}
967 
968 //===----------------------------------------------------------------------===//
969 // CstrRequireOp
970 //===----------------------------------------------------------------------===//
971 
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  // Fold to the predicate when it is a known constant; a null attribute
  // (non-constant operand) means no fold.
  return operands[0];
}
975 
976 //===----------------------------------------------------------------------===//
977 // DivOp
978 //===----------------------------------------------------------------------===//
979 
980 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
981   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
982   if (!lhs)
983     return nullptr;
984   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
985   if (!rhs)
986     return nullptr;
987 
988   // Division in APInt does not follow floor(lhs, rhs) when the result is
989   // negative. Rather, APInt rounds toward zero.
990   APInt quotient, remainder;
991   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
992   if (quotient.isNegative() && !remainder.isNullValue()) {
993     quotient -= 1;
994   }
995 
996   Type indexTy = IndexType::get(getContext());
997   return IntegerAttr::get(indexTy, quotient);
998 }
999 
1000 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1001     MLIRContext *context, Optional<Location> location, ValueRange operands,
1002     DictionaryAttr attributes, RegionRange regions,
1003     SmallVectorImpl<Type> &inferredReturnTypes) {
1004   if (operands[0].getType().isa<SizeType>() ||
1005       operands[1].getType().isa<SizeType>())
1006     inferredReturnTypes.assign({SizeType::get(context)});
1007   else
1008     inferredReturnTypes.assign({IndexType::get(context)});
1009   return success();
1010 }
1011 
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1016 
1017 //===----------------------------------------------------------------------===//
1018 // ShapeEqOp
1019 //===----------------------------------------------------------------------===//
1020 
1021 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
1022   bool allSame = true;
1023   if (!operands.empty() && !operands[0])
1024     return {};
1025   for (Attribute operand : operands.drop_front(1)) {
1026     if (!operand)
1027       return {};
1028     allSame = allSame && operand == operands[0];
1029   }
1030   return BoolAttr::get(getContext(), allSame);
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // IndexToSizeOp
1035 //===----------------------------------------------------------------------===//
1036 
OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  return {};
}
1044 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Fold size_to_index(index_to_size(x)) round-trips away.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1049 
1050 //===----------------------------------------------------------------------===//
1051 // FromExtentsOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
1055   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
1056     return nullptr;
1057   SmallVector<int64_t, 6> extents;
1058   for (auto attr : operands)
1059     extents.push_back(attr.cast<IntegerAttr>().getInt());
1060   Builder builder(getContext());
1061   return builder.getIndexTensorAttr(extents);
1062 }
1063 
1064 //===----------------------------------------------------------------------===//
1065 // FunctionLibraryOp
1066 //===----------------------------------------------------------------------===//
1067 
// Builds a function library op carrying only its symbol name; the body region
// and mapping attribute are expected to be populated by the caller/parser.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1073 
// Looks up the shape function registered for `op`'s name in this library's
// "mapping" attribute. Returns null when no entry exists, when the entry is
// not a flat symbol reference, or when the symbol does not resolve to a FuncOp.
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()
                  .get(op->getName().getIdentifier())
                  .dyn_cast_or_null<FlatSymbolRefAttr>();
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}
1082 
// Parses a `shape.function_library` op:
//   shape.function_library @name <attr-dict> { ... } mapping { ... }
ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // The single region holding the shape functions.
  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  // Dictionary mapping op names to shape-function symbols.
  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}
1108 
// Prints a `shape.function_library` op; mirrors parseFunctionLibraryOp above.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << ' ';
  p.printSymbolName(op.getName());
  // Symbol name and mapping are printed explicitly, so elide them here.
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.getMappingAttr());
}
1119 
1120 //===----------------------------------------------------------------------===//
1121 // GetExtentOp
1122 //===----------------------------------------------------------------------===//
1123 
// Returns the dimension operand as a constant when it is produced by either a
// shape.const_size or an arith.constant op; None otherwise.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}
1131 
1132 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1133   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1134   if (!elements)
1135     return nullptr;
1136   Optional<int64_t> dim = getConstantDim();
1137   if (!dim.hasValue())
1138     return nullptr;
1139   if (dim.getValue() >= elements.getNumElements())
1140     return nullptr;
1141   return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
1142 }
1143 
1144 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1145                         int64_t dim) {
1146   auto loc = result.location;
1147   auto dimAttr = builder.getIndexAttr(dim);
1148   if (shape.getType().isa<ShapeType>()) {
1149     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1150     build(builder, result, builder.getType<SizeType>(), shape, dim);
1151   } else {
1152     Value dim =
1153         builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1154     build(builder, result, builder.getIndexType(), shape, dim);
1155   }
1156 }
1157 
1158 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1159     MLIRContext *context, Optional<Location> location, ValueRange operands,
1160     DictionaryAttr attributes, RegionRange regions,
1161     SmallVectorImpl<Type> &inferredReturnTypes) {
1162   inferredReturnTypes.assign({IndexType::get(context)});
1163   return success();
1164 }
1165 
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1171 
1172 //===----------------------------------------------------------------------===//
1173 // IsBroadcastableOp
1174 //===----------------------------------------------------------------------===//
1175 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate operands never change broadcastability; drop them.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1180 
1181 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1182   // Can always broadcast fewer than two shapes.
1183   if (operands.size() < 2) {
1184     return BoolAttr::get(getContext(), true);
1185   }
1186 
1187   return nullptr;
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // MeetOp
1192 //===----------------------------------------------------------------------===//
1193 
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The meet result takes the type of the first operand.
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}
1201 
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // NOTE(review): with single-element ranges, `lhs != rhs` holds exactly when
  // `l != r`, so after the `l == r` check above this always returns false and
  // the remaining checks are unreachable — confirm whether this guard was
  // intended to be a looser compatibility test.
  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1221 
1222 //===----------------------------------------------------------------------===//
1223 // RankOp
1224 //===----------------------------------------------------------------------===//
1225 
1226 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1227   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1228   if (!shape)
1229     return {};
1230   int64_t rank = shape.getNumElements();
1231   Builder builder(getContext());
1232   return builder.getIndexAttr(rank);
1233 }
1234 
1235 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1236 /// Constant folding fails in cases where only the rank is constant, not the
1237 /// shape itself.
1238 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1239 ///
1240 /// Example:
1241 ///
1242 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1243 /// %rank = shape.rank %shape
1244 ///
1245 /// becomes
1246 ///
1247 /// %rank = shape.const_size 3
1248 
namespace {
// Implements the shape.rank(shape.shape_of(%ranked)) canonicalization
// described above: replace the rank computation with a constant whose kind
// (index vs. !shape.size) matches the rank op's result type.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the shape comes directly from a shape_of op.
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    // The shape_of argument must be ranked for the rank to be static.
    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    // Materialize the constant with a type matching the rank op's result.
    if (op.getType().isa<IndexType>()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (op.getType().isa<shape::SizeType>()) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace
1276 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // See the pattern documentation above for the shape_of canonicalization.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1281 
1282 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1283     MLIRContext *context, Optional<Location> location, ValueRange operands,
1284     DictionaryAttr attributes, RegionRange regions,
1285     SmallVectorImpl<Type> &inferredReturnTypes) {
1286   if (operands[0].getType().isa<ShapeType>())
1287     inferredReturnTypes.assign({SizeType::get(context)});
1288   else
1289     inferredReturnTypes.assign({IndexType::get(context)});
1290   return success();
1291 }
1292 
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1297 
1298 //===----------------------------------------------------------------------===//
1299 // NumElementsOp
1300 //===----------------------------------------------------------------------===//
1301 
// Folds to the product of all extents when the shape operand is constant.
OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {

  // Fold only when argument constant.
  Attribute shape = operands[0];
  if (!shape)
    return {};

  // Accumulate the product in 64 bits; an empty shape yields 1.
  // NOTE(review): the product wraps silently if it exceeds 64 bits.
  APInt product(64, 1);
  for (auto value : shape.cast<DenseIntElementsAttr>())
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}
1315 
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // !shape.shape operands produce a !shape.size; extent tensors produce index.
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1326 
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1332 
1333 //===----------------------------------------------------------------------===//
1334 // MaxOp
1335 //===----------------------------------------------------------------------===//
1336 
1337 OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
1338   // If operands are equal, just propagate one.
1339   if (getLhs() == getRhs())
1340     return getLhs();
1341   return nullptr;
1342 }
1343 
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Matching operand types propagate; otherwise fall back to !shape.size.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1354 
1355 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1356   if (l.size() != 1 || r.size() != 1)
1357     return false;
1358   if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
1359     return true;
1360   if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
1361     return true;
1362   return false;
1363 }
1364 
1365 //===----------------------------------------------------------------------===//
1366 // MinOp
1367 //===----------------------------------------------------------------------===//
1368 
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one (min(x, x) == x).
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}
1375 
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Matching operand types propagate; otherwise fall back to !shape.size.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}
1386 
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Compatible iff each side has exactly one result and the results are both
  // shapes or both sizes.
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}
1396 
1397 //===----------------------------------------------------------------------===//
1398 // MulOp
1399 //===----------------------------------------------------------------------===//
1400 
// Folds a multiplication of two constant integers into an index attribute.
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!lhs)
    return nullptr;
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!rhs)
    return nullptr;
  // NOTE(review): APInt multiplication wraps on overflow, so an overflowing
  // product folds to the wrapped value rather than failing to fold — confirm
  // whether that is acceptable for size arithmetic here.
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}
1412 
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result is !shape.size if either operand is a size; otherwise index.
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1424 
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
1429 //===----------------------------------------------------------------------===//
1430 // ShapeOfOp
1431 //===----------------------------------------------------------------------===//
1432 
1433 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
1434   auto type = getOperand().getType().dyn_cast<ShapedType>();
1435   if (!type || !type.hasStaticShape())
1436     return nullptr;
1437   Builder builder(getContext());
1438   return builder.getIndexTensorAttr(type.getShape());
1439 }
1440 
1441 namespace {
// Rebuilds a shape.shape_of whose operand is shaped but whose result is not a
// shaped (extent-tensor) type, letting the op's type inference choose the
// result type instead.
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getArg().getType().isa<ShapedType>())
      return failure();
    // Already produces a shaped result; nothing to do.
    if (op.getType().isa<ShapedType>())
      return failure();

    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
                                                  op.getArg());
    return success();
  }
};
1457 
// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The cast target must be a rank-1 (extent-tensor-shaped) tensor.
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    // The cast source must come directly from a shape_of op.
    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
1489 } // namespace
1490 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // See the pattern definitions above for the cast and tensor rewrites.
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
}
1496 
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // !shape.value_shape operands produce a !shape.shape; shaped operands
  // produce a 1-D index tensor sized by the operand's rank (dynamic when the
  // operand is unranked).
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}
1513 
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // Only !shape.shape and shaped (extent tensor) types are valid results.
  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    // Shape type is compatible with all other valid return types.
    return true;

  // Two shaped types are compatible when their shapes do not conflict.
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1534 
1535 //===----------------------------------------------------------------------===//
1536 // SizeToIndexOp
1537 //===----------------------------------------------------------------------===//
1538 
OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = operands[0])
    return arg;
  // Otherwise, try the generic cast folding (e.g. cast of a cast).
  return impl::foldCastOp(*this);
}
1546 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Fold index_to_size(size_to_index(x)) round-trips away.
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1551 
1552 //===----------------------------------------------------------------------===//
1553 // YieldOp
1554 //===----------------------------------------------------------------------===//
1555 
// Verifies that the yielded operands match the parent op's results in both
// count and type.
static LogicalResult verify(shape::YieldOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "number of operands does not match number of "
                               "results of its parent";
  // Pairwise type check between parent results and yielded operands.
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";

  return success();
}
1571 
1572 //===----------------------------------------------------------------------===//
1573 // SplitAtOp
1574 //===----------------------------------------------------------------------===//
1575 
// Folds a split of a constant shape at a constant split point into the two
// constant sub-shapes. A negative split point counts from the back.
LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  // Both the shape and the split point must be constant.
  if (!operands[0] || !operands[1])
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (!(-rank <= splitPoint && splitPoint <= rank))
    return failure();
  // Normalize a negative split point to an absolute position.
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(operands[0].getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}
1596 
1597 //===----------------------------------------------------------------------===//
1598 // ToExtentTensorOp
1599 //===----------------------------------------------------------------------===//
1600 
1601 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1602   if (!operands[0])
1603     return impl::foldCastOp(*this);
1604   Builder builder(getContext());
1605   auto shape = llvm::to_vector<6>(
1606       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1607   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1608                                     builder.getIndexType());
1609   return DenseIntElementsAttr::get(type, shape);
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // ReduceOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1617                      ValueRange initVals) {
1618   result.addOperands(shape);
1619   result.addOperands(initVals);
1620 
1621   Region *bodyRegion = result.addRegion();
1622   bodyRegion->push_back(new Block);
1623   Block &bodyBlock = bodyRegion->front();
1624   bodyBlock.addArgument(builder.getIndexType());
1625 
1626   Type elementType;
1627   if (auto tensorType = shape.getType().dyn_cast<TensorType>())
1628     elementType = tensorType.getElementType();
1629   else
1630     elementType = SizeType::get(builder.getContext());
1631   bodyBlock.addArgument(elementType);
1632 
1633   for (Type initValType : initVals.getTypes()) {
1634     bodyBlock.addArgument(initValType);
1635     result.addTypes(initValType);
1636   }
1637 }
1638 
// Verifies the body region of a `shape.reduce`: argument count, the implicit
// index/extent argument types, and agreement between the accumulator block
// arguments and the initial values. Check order determines which diagnostic
// is reported first.
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.getShape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // The remaining block arguments (offset by 2) must match the types of the
  // corresponding initial accumulator values exactly.
  for (auto type : llvm::enumerate(op.getInitVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1676 
// Parses the custom assembly for `shape.reduce`:
//   `shape.reduce` `(` operands `)` `:` shape-or-extent-tensor-type
//   (`->` result-types)? region attr-dict?
// The first parenthesized operand is the shape/extent tensor; the rest are
// the initial accumulator values, which are resolved against the declared
// result types.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands. The leading operand is the reduced shape; each init
  // value shares the type of the corresponding op result.
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body. Block arguments are spelled inside the region itself.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1706 
// Prints `shape.reduce` in the same form parseReduceOp accepts:
//   `(` shape `,` init-vals `)` `:` shape-type (`->` result-types)? region
//   attr-dict?
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << '(' << op.getShape() << ", " << op.getInitVals()
    << ") : " << op.getShape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.getRegion());
  p.printOptionalAttrDict(op->getAttrs());
}
1714 
1715 #define GET_OP_CLASSES
1716 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1717