1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/IR/Shape.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/Dialect/Traits.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 using namespace mlir;
25 using namespace mlir::shape;
26 
27 namespace {
28 #include "ShapeCanonicalization.inc"
29 }
30 
31 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
32   return RankedTensorType::get({rank}, IndexType::get(ctx));
33 }
34 
35 bool shape::isExtentTensorType(Type type) {
36   auto ranked = type.dyn_cast<RankedTensorType>();
37   return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
38 }
39 
40 LogicalResult shape::getShapeVec(Value input,
41                                  SmallVectorImpl<int64_t> &shapeValues) {
42   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
43     auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
44     if (!type.hasRank())
45       return failure();
46     shapeValues = llvm::to_vector<6>(type.getShape());
47     return success();
48   } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
49     shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
50     return success();
51   } else if (auto inputOp = input.getDefiningOp<ConstantOp>()) {
52     shapeValues = llvm::to_vector<6>(
53         inputOp.value().cast<DenseIntElementsAttr>().getValues<int64_t>());
54     return success();
55   } else {
56     return failure();
57   }
58 }
59 
60 static bool isErrorPropagationPossible(TypeRange operandTypes) {
61   return llvm::any_of(operandTypes, [](Type ty) {
62     return ty.isa<SizeType, ShapeType, ValueShapeType>();
63   });
64 }
65 
66 static LogicalResult verifySizeOrIndexOp(Operation *op) {
67   assert(op != nullptr && op->getNumResults() == 1);
68   Type resultTy = op->getResultTypes().front();
69   if (isErrorPropagationPossible(op->getOperandTypes())) {
70     if (!resultTy.isa<SizeType>())
71       return op->emitOpError()
72              << "if at least one of the operands can hold error values then "
73                 "the result must be of type `size` to propagate them";
74   }
75   return success();
76 }
77 
78 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
79   assert(op != nullptr && op->getNumResults() == 1);
80   Type resultTy = op->getResultTypes().front();
81   if (isErrorPropagationPossible(op->getOperandTypes())) {
82     if (!resultTy.isa<ShapeType>())
83       return op->emitOpError()
84              << "if at least one of the operands can hold error values then "
85                 "the result must be of type `shape` to propagate them";
86   }
87   return success();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // InlinerInterface
92 //===----------------------------------------------------------------------===//
93 
namespace {
/// This class defines the interface for inlining shape dialect ops.
/// Inlining is unconditionally allowed for both regions and individual ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace
115 
// Registers the dialect's ops (from the tablegen-generated list), types, and
// the inliner interface.
void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
128 
129 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
130                                              Attribute value, Type type,
131                                              Location loc) {
132   if (type.isa<ShapeType>() || isExtentTensorType(type))
133     return builder.create<ConstShapeOp>(loc, type,
134                                         value.cast<DenseIntElementsAttr>());
135   if (type.isa<SizeType>())
136     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
137   if (type.isa<WitnessType>())
138     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
139   if (ConstantOp::isBuildableWith(value, type))
140     return builder.create<ConstantOp>(loc, type, value);
141   return nullptr;
142 }
143 
144 /// Parse a type registered to this dialect.
145 Type ShapeDialect::parseType(DialectAsmParser &parser) const {
146   StringRef keyword;
147   if (parser.parseKeyword(&keyword))
148     return Type();
149 
150   if (keyword == "shape")
151     return ShapeType::get(getContext());
152   if (keyword == "size")
153     return SizeType::get(getContext());
154   if (keyword == "value_shape")
155     return ValueShapeType::get(getContext());
156   if (keyword == "witness")
157     return WitnessType::get(getContext());
158 
159   parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
160   return Type();
161 }
162 
163 /// Print a type registered to this dialect.
164 void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
165   TypeSwitch<Type>(type)
166       .Case<ShapeType>([&](Type) { os << "shape"; })
167       .Case<SizeType>([&](Type) { os << "size"; })
168       .Case<ValueShapeType>([&](Type) { os << "value_shape"; })
169       .Case<WitnessType>([&](Type) { os << "witness"; })
170       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
171 }
172 
// Verifies the dialect-specific `shape.lib` attribute. The attribute may only
// appear on a symbol-table op and must be either a single SymbolRefAttr or an
// array of SymbolRefAttrs, each resolving to a shape::FunctionLibraryOp whose
// op-to-shape-function mappings are globally unique.
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.first == "shape.lib") {
    // Symbol lookup below requires the op to define a symbol table.
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    // Case 1: a single symbol reference.
    if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    // Case 2: an array of symbol references.
    if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<Identifier> key;
      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        // Each op name may be mapped to a shape function at most once across
        // all referenced libraries.
        for (auto mapping : shapeFnLib.mapping()) {
          if (!key.insert(mapping.first).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.first << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}
222 
223 //===----------------------------------------------------------------------===//
224 // AnyOp
225 //===----------------------------------------------------------------------===//
226 
227 // TODO: Canonicalization should be implemented for shapes that can be
228 // determined through mixtures of the known dimensions of the inputs.
229 OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
230   // Only the last operand is checked because AnyOp is commutative.
231   if (operands.back())
232     return operands.back();
233 
234   return nullptr;
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // AssumingOp
239 //===----------------------------------------------------------------------===//
240 
// Parses `shape.assuming`: a witness operand, an optional `-> (types)` result
// list, the body region, and an optional attribute dict.
static ParseResult parseAssumingOp(OpAsmParser &parser,
                                   OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  // The single operand is always a witness value.
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
267 
268 static void print(OpAsmPrinter &p, AssumingOp op) {
269   bool yieldsResults = !op.results().empty();
270 
271   p << AssumingOp::getOperationName() << " " << op.witness();
272   if (yieldsResults) {
273     p << " -> (" << op.getResultTypes() << ")";
274   }
275   p.printRegion(op.doRegion(),
276                 /*printEntryBlockArgs=*/false,
277                 /*printBlockTerminators=*/yieldsResults);
278   p.printOptionalAttrDict(op->getAttrs());
279 }
280 
281 namespace {
282 // Removes AssumingOp with a passing witness and inlines the region.
283 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
284   using OpRewritePattern<AssumingOp>::OpRewritePattern;
285 
286   LogicalResult matchAndRewrite(AssumingOp op,
287                                 PatternRewriter &rewriter) const override {
288     auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
289     if (!witness || !witness.passingAttr())
290       return failure();
291 
292     AssumingOp::inlineRegionIntoParent(op, rewriter);
293     return success();
294   }
295 };
296 
// Rewrites an AssumingOp that yields unused results into an equivalent op that
// only yields the used ones, dropping the dead yield operands.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
    newOp.doRegion().takeBody(op.doRegion());

    // Use the new results to replace the previously used ones.
    // Unused old results are replaced with null, which is valid because they
    // have no uses; used results consume the new op's results in order.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
342 } // namespace
343 
// Registers AssumingOp canonicalizations: dropping unused results and inlining
// regions guarded by a constant-true witness.
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
348 
349 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
350 void AssumingOp::getSuccessorRegions(
351     Optional<unsigned> index, ArrayRef<Attribute> operands,
352     SmallVectorImpl<RegionSuccessor> &regions) {
353   // AssumingOp has unconditional control flow into the region and back to the
354   // parent, so return the correct RegionSuccessor purely based on the index
355   // being None or 0.
356   if (index.hasValue()) {
357     regions.push_back(RegionSuccessor(getResults()));
358     return;
359   }
360 
361   regions.push_back(RegionSuccessor(&doRegion()));
362 }
363 
// Splices the body of `op` into its parent block, replacing the op's results
// with the yielded values. Used when the witness is known to pass.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block at the insertion point so the assuming body can be
  // placed between the two halves.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
383 
384 void AssumingOp::build(
385     OpBuilder &builder, OperationState &result, Value witness,
386     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
387 
388   result.addOperands(witness);
389   Region *bodyRegion = result.addRegion();
390   bodyRegion->push_back(new Block);
391   Block &bodyBlock = bodyRegion->front();
392 
393   // Build body.
394   OpBuilder::InsertionGuard guard(builder);
395   builder.setInsertionPointToStart(&bodyBlock);
396   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
397   builder.create<AssumingYieldOp>(result.location, yieldValues);
398 
399   SmallVector<Type, 2> assumingTypes;
400   for (Value v : yieldValues)
401     assumingTypes.push_back(v.getType());
402   result.addTypes(assumingTypes);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // AssumingAllOp
407 //===----------------------------------------------------------------------===//
408 
409 namespace {
// Folds an `assuming_all` whose inputs are all `cstr_eq` ops into a single
// `cstr_eq` over the union of their shapes — but only when the equality
// classes actually overlap, since merging disjoint classes would strengthen
// the constraint incorrectly.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.inputs()) {
      // Every witness must come from a cstr_eq op.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // A cstr_eq sharing no shape with those collected so far belongs to a
      // different equality class; merging would be unsound.
      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
432 
433 template <typename OpTy>
434 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
435   using OpRewritePattern<OpTy>::OpRewritePattern;
436 
437   LogicalResult matchAndRewrite(OpTy op,
438                                 PatternRewriter &rewriter) const override {
439     // Find unique operands.
440     SmallVector<Value, 2> unique;
441     for (Value v : op.getOperands()) {
442       if (!llvm::is_contained(unique, v))
443         unique.push_back(v);
444     }
445 
446     // Reduce op to equivalent with unique operands.
447     if (unique.size() < op.getNumOperands()) {
448       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
449                                         op->getAttrs());
450       return success();
451     }
452 
453     return failure();
454   }
455 };
456 } // namespace
457 
// Registers AssumingAllOp canonicalizations: single-operand simplification
// (generated pattern), merging of cstr_eq inputs, and duplicate removal.
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
463 
// Folds `assuming_all`: a statically false input folds to false, all inputs
// statically true fold to true, and any statically known input is erased from
// the operand list along the way. Note this intentionally mutates the op's
// operands during folding.
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = operands.size() - 1; idx >= 0; idx--) {
    Attribute a = operands[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!a.cast<BoolAttr>().getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
484 
485 static LogicalResult verify(AssumingAllOp op) {
486   // Ensure that AssumingAllOp contains at least one operand
487   if (op.getNumOperands() == 0)
488     return op.emitOpError("no operands specified");
489 
490   return success();
491 }
492 
493 void AssumingAllOp::build(OpBuilder &b, OperationState &state,
494                           ValueRange inputs) {
495   build(b, state, b.getType<WitnessType>(), inputs);
496 }
497 
498 //===----------------------------------------------------------------------===//
499 // BroadcastOp
500 //===----------------------------------------------------------------------===//
501 
// Folds `broadcast`: a single operand of matching type folds to itself, and
// two constant operands fold to their broadcasted shape. More operands are
// left for canonicalization.
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  if (shapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (shapes().front().getType() != getType())
      return nullptr;
    return shapes().front();
  }

  // TODO: Support folding with more than 2 input shapes
  if (shapes().size() > 2)
    return nullptr;

  // Exactly two operands from here on; both must be constant to fold.
  if (!operands[0] || !operands[1])
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;

  // If the shapes are not compatible, we can't fold it.
  // TODO: Fold to an "error".
  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
    return nullptr;

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
530 
// A broadcast with error-carrying operands must produce a `shape` result.
static LogicalResult verify(BroadcastOp op) {
  return verifyShapeOrExtentTensorOp(op);
}
534 
535 namespace {
536 template <typename OpTy>
537 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
538   using OpRewritePattern<OpTy>::OpRewritePattern;
539 
540   LogicalResult matchAndRewrite(OpTy op,
541                                 PatternRewriter &rewriter) const override {
542     auto isPotentiallyNonEmptyShape = [](Value shape) {
543       if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
544         if (extentTensorTy.getDimSize(0) == 0)
545           return false;
546       }
547       if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
548         if (constShape.shape().empty())
549           return false;
550       }
551       return true;
552     };
553     auto newOperands = llvm::to_vector<8>(
554         llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
555 
556     // Reduce op to equivalent without empty shape operands.
557     if (newOperands.size() < op.getNumOperands()) {
558       rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
559                                         op->getAttrs());
560       return success();
561     }
562 
563     return failure();
564   }
565 };
566 
// Replaces a single-operand broadcast with its operand, inserting a cast when
// the operand's type differs from the result type.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.shapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        // Extent tensor operand, `shape` result: wrap in from_extent_tensor.
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        // Both types are extent tensors; a tensor.cast bridges the rank info.
        assert(!op.getType().isa<ShapeType>() &&
               !replacement.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
595 
// Folds all constant-shape operands of a broadcast into a single const_shape
// operand, keeping the non-constant operands unchanged.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.shapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        // Accumulate the constant operand into the running broadcasted shape;
        // if broadcasting fails, keep the operand as-is instead.
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    // Append the folded constants as one const_shape operand.
    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
633 
634 template <typename OpTy>
635 struct CanonicalizeCastExtentTensorOperandsPattern
636     : public OpRewritePattern<OpTy> {
637   using OpRewritePattern<OpTy>::OpRewritePattern;
638 
639   LogicalResult matchAndRewrite(OpTy op,
640                                 PatternRewriter &rewriter) const override {
641     // Canonicalize operands.
642     bool anyChange = false;
643     auto canonicalizeOperand = [&](Value operand) {
644       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
645         // Only eliminate the cast if it holds no shape information.
646         bool isInformationLoosingCast =
647             castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
648         if (isInformationLoosingCast) {
649           anyChange = true;
650           return castOp.source();
651         }
652       }
653       return operand;
654     };
655     auto newOperands = llvm::to_vector<8>(
656         llvm::map_range(op.getOperands(), canonicalizeOperand));
657 
658     // Rewrite op if any change required.
659     if (!anyChange)
660       return failure();
661     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
662     return success();
663   }
664 };
665 
// Narrows a broadcast's dynamically-sized extent tensor result type to a
// static rank when all operand ranks are statically known, inserting a
// tensor.cast back to the original result type.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    // The broadcasted rank is the maximum of the operand ranks.
    int64_t maxRank = 0;
    for (Value shape : op.shapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
695 } // namespace
696 
// Registers BroadcastOp canonicalizations defined above plus the shared
// duplicate/empty-shape operand cleanups.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
706 
707 //===----------------------------------------------------------------------===//
708 // ConcatOp
709 //===----------------------------------------------------------------------===//
710 
711 OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
712   if (!operands[0] || !operands[1])
713     return nullptr;
714   auto lhsShape = llvm::to_vector<6>(
715       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
716   auto rhsShape = llvm::to_vector<6>(
717       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
718   SmallVector<int64_t, 6> resultShape;
719   resultShape.append(lhsShape.begin(), lhsShape.end());
720   resultShape.append(rhsShape.begin(), rhsShape.end());
721   Builder builder(getContext());
722   return builder.getIndexTensorAttr(resultShape);
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // ConstShapeOp
727 //===----------------------------------------------------------------------===//
728 
729 static void print(OpAsmPrinter &p, ConstShapeOp &op) {
730   p << "shape.const_shape ";
731   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
732   p << "[";
733   interleaveComma(op.shape().getValues<int64_t>(), p,
734                   [&](int64_t i) { p << i; });
735   p << "] : ";
736   p.printType(op.getType());
737 }
738 
// Parses `shape.const_shape`: an optional attribute dict, a bracketed list of
// integer extents, and a colon-type. The extents are stored as an index
// tensor in the "shape" attribute.
static ParseResult parseConstShapeOp(OpAsmParser &parser,
                                     OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
  if (!extentsArray)
    return failure();
  // Each array element must be an integer extent.
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  // Trailing `: type` provides the result type.
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
768 
// A constant shape trivially folds to its "shape" attribute.
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
770 
// Registers the generated pattern that folds tensor.cast of a const_shape.
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}
775 
776 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
777     MLIRContext *context, Optional<Location> location, ValueRange operands,
778     DictionaryAttr attributes, RegionRange regions,
779     SmallVectorImpl<Type> &inferredReturnTypes) {
780   Builder b(context);
781   auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
782   if (!shape)
783     return emitOptionalError(location, "missing shape attribute");
784   inferredReturnTypes.assign({RankedTensorType::get(
785       {static_cast<int64_t>(shape.size())}, b.getIndexType())});
786   return success();
787 }
788 
789 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
790                                                         TypeRange r) {
791   if (l.size() != 1 || r.size() != 1)
792     return false;
793 
794   Type lhs = l.front();
795   Type rhs = r.front();
796 
797   if (lhs == rhs)
798     return true;
799 
800   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
801     // Shape type is compatible with all other valid return types.
802     return true;
803 
804   return succeeded(verifyCompatibleShapes(lhs, rhs));
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // CstrBroadcastableOp
809 //===----------------------------------------------------------------------===//
810 
void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point that
  // does not result in folding.
  // CstrBroadcastableEqOps is generated (see ShapeCanonicalization.inc); the
  // others are shared C++ pattern templates instantiated for this op.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}
821 
// Return true if at most one of the attributes represents a shape with a
// non-zero number of extents (i.e. a non-scalar). A null attribute stands for
// an unknown shape and is conservatively treated as non-scalar.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
835 
836 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
837   // No broadcasting is needed if all operands but one are scalar.
838   if (hasAtMostSingleNonScalar(operands))
839     return BoolAttr::get(getContext(), true);
840 
841   if ([&] {
842         SmallVector<SmallVector<int64_t, 6>, 6> extents;
843         for (const auto &operand : operands) {
844           if (!operand)
845             return false;
846           extents.push_back(llvm::to_vector<6>(
847               operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
848         }
849         return OpTrait::util::staticallyKnownBroadcastable(extents);
850       }())
851     return BoolAttr::get(getContext(), true);
852 
853   // Lastly, see if folding can be completed based on what constraints are known
854   // on the input shapes.
855   if ([&] {
856         SmallVector<SmallVector<int64_t, 6>, 6> extents;
857         for (auto shapeValue : shapes()) {
858           extents.emplace_back();
859           if (failed(getShapeVec(shapeValue, extents.back())))
860             return false;
861         }
862         return OpTrait::util::staticallyKnownBroadcastable(extents);
863       }())
864     return BoolAttr::get(getContext(), true);
865 
866   // Because a failing witness result here represents an eventual assertion
867   // failure, we do not replace it with a constant witness.
868   return nullptr;
869 }
870 
static LogicalResult verify(CstrBroadcastableOp op) {
  // Broadcastability is only meaningful for two or more shapes; reject fewer.
  // (The previous comment mentioned AssumingAllOp — a copy-paste artifact.)
  if (op.getNumOperands() < 2)
    return op.emitOpError("required at least 2 input shapes");
  return success();
}
877 
878 //===----------------------------------------------------------------------===//
879 // CstrEqOp
880 //===----------------------------------------------------------------------===//
881 
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness.
  // CstrEqEqOps is generated (see ShapeCanonicalization.inc).
  patterns.add<CstrEqEqOps>(context);
}
887 
888 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
889   if (llvm::all_of(operands,
890                    [&](Attribute a) { return a && a == operands[0]; }))
891     return BoolAttr::get(getContext(), true);
892 
893   // Because a failing witness result here represents an eventual assertion
894   // failure, we do not try to replace it with a constant witness. Similarly, we
895   // cannot if there are any non-const inputs.
896   return nullptr;
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // ConstSizeOp
901 //===----------------------------------------------------------------------===//
902 
// Convenience builder: wrap the raw integer in an index attribute and defer
// to the attribute-based builder.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}
907 
// Fold to the constant held in the op's `value` attribute.
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
909 
// Suggest SSA names of the form `%cN` for constant sizes, mirroring the
// constant's value in the printed IR.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << value();
  setNameFn(getResult(), os.str());
}
917 
918 //===----------------------------------------------------------------------===//
919 // ConstWitnessOp
920 //===----------------------------------------------------------------------===//
921 
// Fold to the constant held in the op's `passing` attribute.
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
923 
924 //===----------------------------------------------------------------------===//
925 // CstrRequireOp
926 //===----------------------------------------------------------------------===//
927 
// Fold to the predicate's constant value when it is known (null otherwise,
// which leaves the op in place).
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}
931 
932 //===----------------------------------------------------------------------===//
933 // DivOp
934 //===----------------------------------------------------------------------===//
935 
936 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
937   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
938   if (!lhs)
939     return nullptr;
940   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
941   if (!rhs)
942     return nullptr;
943 
944   // Division in APInt does not follow floor(lhs, rhs) when the result is
945   // negative. Rather, APInt rounds toward zero.
946   APInt quotient, remainder;
947   APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
948   if (quotient.isNegative() && !remainder.isNullValue()) {
949     quotient -= 1;
950   }
951 
952   Type indexTy = IndexType::get(getContext());
953   return IntegerAttr::get(indexTy, quotient);
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // ShapeEqOp
958 //===----------------------------------------------------------------------===//
959 
960 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
961   bool allSame = true;
962   if (!operands.empty() && !operands[0])
963     return {};
964   for (Attribute operand : operands.drop_front(1)) {
965     if (!operand)
966       return {};
967     allSame = allSame && operand == operands[0];
968   }
969   return BoolAttr::get(getContext(), allSame);
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // IndexToSizeOp
974 //===----------------------------------------------------------------------===//
975 
976 OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
977   // Constant values of both types, `shape.size` and `index`, are represented as
978   // `IntegerAttr`s which makes constant folding simple.
979   if (Attribute arg = operands[0])
980     return arg;
981   return {};
982 }
983 
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Cancels size_to_index/index_to_size round-trips (generated pattern, see
  // ShapeCanonicalization.inc).
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
988 
989 //===----------------------------------------------------------------------===//
990 // FromExtentsOp
991 //===----------------------------------------------------------------------===//
992 
993 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
994   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
995     return nullptr;
996   SmallVector<int64_t, 6> extents;
997   for (auto attr : operands)
998     extents.push_back(attr.cast<IntegerAttr>().getInt());
999   Builder builder(getContext());
1000   return builder.getIndexTensorAttr(extents);
1001 }
1002 
1003 //===----------------------------------------------------------------------===//
1004 // FunctionLibraryOp
1005 //===----------------------------------------------------------------------===//
1006 
// Builder that only installs the symbol name attribute; the body region and
// the `mapping` attribute are supplied by the caller/parser.
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
1012 
1013 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1014   auto attr = mapping()
1015                   .get(op->getName().getIdentifier())
1016                   .dyn_cast_or_null<FlatSymbolRefAttr>();
1017   if (!attr)
1018     return nullptr;
1019   return lookupSymbol<FuncOp>(attr);
1020 }
1021 
1022 ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
1023                                    OperationState &result) {
1024   // Parse the op name.
1025   StringAttr nameAttr;
1026   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1027                              result.attributes))
1028     return failure();
1029 
1030   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1031     return failure();
1032 
1033   auto *bodyRegion = result.addRegion();
1034   if (parser.parseRegion(*bodyRegion))
1035     return failure();
1036 
1037   if (parser.parseKeyword("mapping"))
1038     return failure();
1039 
1040   DictionaryAttr mappingAttr;
1041   if (parser.parseAttribute(mappingAttr,
1042                             parser.getBuilder().getType<NoneType>(), "mapping",
1043                             result.attributes))
1044     return failure();
1045   return success();
1046 }
1047 
// Custom printer; mirrors parseFunctionLibraryOp. The symbol name and
// `mapping` attributes are printed explicitly, so they are elided from the
// generic attribute dictionary.
void print(OpAsmPrinter &p, FunctionLibraryOp op) {
  p << op.getOperationName() << ' ';
  p.printSymbolName(op.getName());
  p.printOptionalAttrDictWithKeyword(
      op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(op.mappingAttr());
}
1058 
1059 //===----------------------------------------------------------------------===//
1060 // GetExtentOp
1061 //===----------------------------------------------------------------------===//
1062 
// Return the `dim` operand's value when it is produced by a constant op
// (either shape.const_size or std.constant); None otherwise.
Optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.value().getLimitedValue();
  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
    return constantOp.value().cast<IntegerAttr>().getInt();
  return llvm::None;
}
1070 
1071 OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
1072   auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1073   if (!elements)
1074     return nullptr;
1075   Optional<int64_t> dim = getConstantDim();
1076   if (!dim.hasValue())
1077     return nullptr;
1078   if (dim.getValue() >= elements.getNumElements())
1079     return nullptr;
1080   return elements.getValue({(uint64_t)dim.getValue()});
1081 }
1082 
1083 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1084                         int64_t dim) {
1085   auto loc = result.location;
1086   auto dimAttr = builder.getIndexAttr(dim);
1087   if (shape.getType().isa<ShapeType>()) {
1088     Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1089     build(builder, result, builder.getType<SizeType>(), shape, dim);
1090   } else {
1091     Value dim =
1092         builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
1093     build(builder, result, builder.getIndexType(), shape, dim);
1094   }
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // IsBroadcastableOp
1099 //===----------------------------------------------------------------------===//
1100 
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate operands do not affect broadcastability; drop them.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
1105 
1106 OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
1107   // Can always broadcast fewer than two shapes.
1108   if (operands.size() < 2) {
1109     return BoolAttr::get(getContext(), true);
1110   }
1111 
1112   return nullptr;
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // RankOp
1117 //===----------------------------------------------------------------------===//
1118 
1119 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
1120   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1121   if (!shape)
1122     return {};
1123   int64_t rank = shape.getNumElements();
1124   Builder builder(getContext());
1125   return builder.getIndexAttr(rank);
1126 }
1127 
1128 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1129 /// Constant folding fails in cases where only the rank is constant, not the
1130 /// shape itself.
1131 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1132 ///
1133 /// Example:
1134 ///
1135 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1136 /// %rank = shape.rank %shape
1137 ///
1138 /// becomes
1139 ///
1140 /// %rank = shape.const_size 3
1141 
1142 namespace {
1143 struct RankShapeOfCanonicalizationPattern
1144     : public OpRewritePattern<shape::RankOp> {
1145   using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1146 
1147   LogicalResult matchAndRewrite(shape::RankOp op,
1148                                 PatternRewriter &rewriter) const override {
1149     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
1150     if (!shapeOfOp)
1151       return failure();
1152     auto rankedTensorType =
1153         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1154     if (!rankedTensorType)
1155       return failure();
1156     int64_t rank = rankedTensorType.getRank();
1157     if (op.getType().isa<IndexType>()) {
1158       rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
1159     } else if (op.getType().isa<shape::SizeType>()) {
1160       rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1161     } else {
1162       return failure();
1163     }
1164     return success();
1165   }
1166 };
1167 } // namespace
1168 
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Rewrites rank(shape_of(%ranked_tensor)) to a constant; see the pattern's
  // documentation above.
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}
1173 
1174 //===----------------------------------------------------------------------===//
1175 // NumElementsOp
1176 //===----------------------------------------------------------------------===//
1177 
1178 OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
1179 
1180   // Fold only when argument constant.
1181   Attribute shape = operands[0];
1182   if (!shape)
1183     return {};
1184 
1185   APInt product(64, 1);
1186   for (auto value : shape.cast<DenseIntElementsAttr>())
1187     product *= value;
1188   Builder builder(getContext());
1189   return builder.getIndexAttr(product.getLimitedValue());
1190 }
1191 
1192 void NumElementsOp::build(OpBuilder &builder, OperationState &result,
1193                           Value shape) {
1194   if (shape.getType().isa<ShapedType>()) {
1195     auto type = builder.getIndexType();
1196     return build(builder, result, type, shape);
1197   }
1198   auto type = SizeType::get(builder.getContext());
1199   return build(builder, result, type, shape);
1200 }
1201 
1202 //===----------------------------------------------------------------------===//
1203 // MaxOp
1204 //===----------------------------------------------------------------------===//
1205 
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one. Constant operands are not
  // folded here (the `operands` attributes are unused).
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}
1212 
1213 //===----------------------------------------------------------------------===//
1214 // MinOp
1215 //===----------------------------------------------------------------------===//
1216 
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
  // If operands are equal, just propagate one. Constant operands are not
  // folded here (the `operands` attributes are unused).
  if (lhs() == rhs())
    return lhs();
  return nullptr;
}
1223 
1224 //===----------------------------------------------------------------------===//
1225 // MulOp
1226 //===----------------------------------------------------------------------===//
1227 
1228 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
1229   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
1230   if (!lhs)
1231     return nullptr;
1232   auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
1233   if (!rhs)
1234     return nullptr;
1235   APInt folded = lhs.getValue() * rhs.getValue();
1236   Type indexTy = IndexType::get(getContext());
1237   return IntegerAttr::get(indexTy, folded);
1238 }
1239 
1240 //===----------------------------------------------------------------------===//
1241 // ShapeOfOp
1242 //===----------------------------------------------------------------------===//
1243 
// Fold to a constant index tensor when the operand's type has a fully static
// shape; otherwise leave the op in place.
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())
    return nullptr;
  Builder builder(getContext());
  return builder.getIndexTensorAttr(type.getShape());
}
1251 
1252 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
1253   if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
1254     int64_t rank =
1255         shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
1256     Type indexTy = builder.getIndexType();
1257     Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1258     return ShapeOfOp::build(builder, result, extentTensorTy, arg);
1259   }
1260   Type shapeTy = builder.getType<ShapeType>();
1261   return ShapeOfOp::build(builder, result, shapeTy, arg);
1262 }
1263 
1264 namespace {
1265 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
1266   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1267 
1268   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1269                                 PatternRewriter &rewriter) const override {
1270     if (!op.arg().getType().isa<ShapedType>())
1271       return failure();
1272     if (op.getType().isa<ShapedType>())
1273       return failure();
1274 
1275     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
1276     return success();
1277   }
1278 };
1279 
1280 // Canonicalize
1281 // ```
1282 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1283 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1284 // ```
1285 // to
1286 // ```
1287 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1288 // ```
1289 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1290   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1291 
1292   LogicalResult matchAndRewrite(tensor::CastOp op,
1293                                 PatternRewriter &rewriter) const override {
1294     auto ty = op.getType().dyn_cast<RankedTensorType>();
1295     if (!ty || ty.getRank() != 1)
1296       return failure();
1297 
1298     auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
1299     if (!shapeOfOp)
1300       return failure();
1301 
1302     // Argument type must be ranked and must not conflict.
1303     auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
1304     if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1305       return failure();
1306 
1307     rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
1308     return success();
1309   }
1310 };
1311 } // namespace
1312 
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  // See the pattern definitions above for what each rewrite does.
  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
}
1317 
1318 //===----------------------------------------------------------------------===//
1319 // SizeToIndexOp
1320 //===----------------------------------------------------------------------===//
1321 
1322 OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
1323   // Constant values of both types, `shape.size` and `index`, are represented as
1324   // `IntegerAttr`s which makes constant folding simple.
1325   if (Attribute arg = operands[0])
1326     return arg;
1327   return impl::foldCastOp(*this);
1328 }
1329 
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Cancels index_to_size/size_to_index round-trips (generated pattern, see
  // ShapeCanonicalization.inc).
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1334 
1335 //===----------------------------------------------------------------------===//
1336 // YieldOp
1337 //===----------------------------------------------------------------------===//
1338 
1339 static LogicalResult verify(shape::YieldOp op) {
1340   auto *parentOp = op->getParentOp();
1341   auto results = parentOp->getResults();
1342   auto operands = op.getOperands();
1343 
1344   if (parentOp->getNumResults() != op.getNumOperands())
1345     return op.emitOpError() << "number of operands does not match number of "
1346                                "results of its parent";
1347   for (auto e : llvm::zip(results, operands))
1348     if (std::get<0>(e).getType() != std::get<1>(e).getType())
1349       return op.emitOpError()
1350              << "types mismatch between yield op and its parent";
1351 
1352   return success();
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // SplitAtOp
1357 //===----------------------------------------------------------------------===//
1358 
1359 LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
1360                               SmallVectorImpl<OpFoldResult> &results) {
1361   if (!operands[0] || !operands[1])
1362     return failure();
1363   auto shapeVec = llvm::to_vector<6>(
1364       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1365   auto shape = llvm::makeArrayRef(shapeVec);
1366   auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
1367   // Verify that the split point is in the correct range.
1368   // TODO: Constant fold to an "error".
1369   int64_t rank = shape.size();
1370   if (!(-rank <= splitPoint && splitPoint <= rank))
1371     return failure();
1372   if (splitPoint < 0)
1373     splitPoint += shape.size();
1374   Builder builder(operands[0].getContext());
1375   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1376   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1377   return success();
1378 }
1379 
1380 //===----------------------------------------------------------------------===//
1381 // ToExtentTensorOp
1382 //===----------------------------------------------------------------------===//
1383 
1384 OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
1385   if (!operands[0])
1386     return impl::foldCastOp(*this);
1387   Builder builder(getContext());
1388   auto shape = llvm::to_vector<6>(
1389       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
1390   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1391                                     builder.getIndexType());
1392   return DenseIntElementsAttr::get(type, shape);
1393 }
1394 
1395 //===----------------------------------------------------------------------===//
1396 // ReduceOp
1397 //===----------------------------------------------------------------------===//
1398 
// Build a reduce op over `shape` with the given initial values, creating its
// body block with the implied signature: (index, extent, init-vals...).
// The block-argument order established here must match what verify(ReduceOp)
// checks below.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // First block argument: the iteration index.
  bodyBlock.addArgument(builder.getIndexType());

  // Second block argument: the extent — `index` when reducing over an extent
  // tensor, `!shape.size` when reducing over a `!shape.shape`.
  Type elementType;
  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock.addArgument(elementType);

  // One trailing block argument and one op result per initial value.
  for (Type initValType : initVals.getTypes()) {
    bodyBlock.addArgument(initValType);
    result.addTypes(initValType);
  }
}
1421 
// Verify that the reduce body's block signature matches the one established
// by ReduceOp::build: (index, extent, init-vals...).
static LogicalResult verify(ReduceOp op) {
  // Verify block arg types.
  Block &block = op.region().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = op.initVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return op.emitOpError() << "ReduceOp body is expected to have "
                            << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!block.getArgument(0).getType().isa<IndexType>())
    return op.emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape or
  // to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (op.shape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
                            "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return op.emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  // Each trailing block argument must match its corresponding initial value.
  for (auto type : llvm::enumerate(op.initVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return op.emitOpError()
             << "type mismatch between argument " << type.index() + 2
             << " of ReduceOp body and initial value " << type.index();
  return success();
}
1459 
// Custom parser: `(operands...) : shape-or-extent-tensor-type -> result-types`
// followed by the body region and an optional attribute dictionary. The first
// operand is the shape; the rest are the initial values, typed by the result
// types.
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::OperandType, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands: the leading operand is the shape, the rest are the
  // initial values (their types are the result types).
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1489 
// Custom printer; mirrors parseReduceOp above.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
    << ") : " << op.shape().getType();
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.region());
  p.printOptionalAttrDict(op->getAttrs());
}
1497 
1498 #define GET_OP_CLASSES
1499 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1500